visuals.py 6.29 KB
Newer Older
1
2
3
4
5
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
6
import numpy as np
7
8
9
10


class SequencePlot(object):
    """Base class for frame-sequence plots with a lower and an upper x-axis.

    The class holds the figure dimensions and the x positions (frame
    numbers) derived from a view. Subclasses are expected to create
    ``self.fig`` and ``self._ax`` before calling ``_vlines`` or
    ``saveplt``; this base class does not create them itself.
    """

    # Frame rate of the plotted sequence. Frame numbers are divided by this
    # value to obtain seconds for the tick labels.
    FPS = 4

    def __init__(self, view, width=40, height=3):
        """Store the figure size and pre-compute the x positions.

        Arguments:
            view -- object exposing ``_frames.start``, ``_frames.end`` and
                    ``_frame_step``
            width {number} -- figure width in inches (default: 40)
            height {number} -- figure height in inches (default: 3)
        """
        self.width = width
        self.height = height
        # TODO: move this into the Frames class and then delete it here
        self._x = self.get_xpos(view)

    @staticmethod
    def _positive_step(value):
        """Truncate ``value`` to an int usable as a ``range`` step (>= 1).

        ``range`` raises ``ValueError`` for a step of zero, which happens
        when a small frame or bin count is divided by a larger number and
        truncated.
        """
        return max(1, int(value))

    def get_xpos(self, view):
        """Calculate the frame numbers used as x positions.

        Returns:
            list -- frame numbers from ``view._frames.start`` up to (but not
                    including) ``view._frames.end`` in steps of
                    ``view._frame_step``
        """
        # TODO: the frame number range would be better off as a method of
        # Frames
        return list(range(view._frames.start, view._frames.end,
                          view._frame_step))

    # TODO: this now also lives in helpers and can therefore be removed from
    # the class
    def _timelabels(self, val, pos):
        """Convert a tick value (frame number) into a ``m:ss`` timecode.

        Used to let the x-axis show minutes instead of frame numbers.

        Arguments:
            val {int} -- default ticker value (frame number) as used by
                         matplotlib
            pos {int} -- current ticker position as passed by the
                         matplotlib.FuncFormatter (unused)

        Returns:
            str -- timecode for the given ticker label (frame number)
        """
        seconds_total = int(val) // self.FPS  # scale frame number to fps=1
        minutes, seconds = divmod(seconds_total, 60)
        return "{0}:{1:02d}".format(minutes, seconds)

    # TODO: styling and plotting are still somewhat mixed up here
    def ittenstyle(self, ax, view):
        """Apply the common styling to ``ax`` and add a twin upper x-axis.

        Arguments:
            ax -- matplotlib axes to style
            view -- object exposing ``_frames`` (``start``, ``end``,
                    ``frm_cnt``), ``_bins`` and ``_contrast``

        Returns:
            tuple -- ``(ax, axt)`` where ``axt`` is the twin axes created
                     with ``ax.twiny()``
        """
        plt.style.use('ggplot')

        # The number of x ticks is derived from the figure's aspect ratio.
        # NOTE(review): 0.3605405405 looks like an empirically chosen
        # constant -- confirm.
        fig_coef = self.width / self.height
        tick_cnt = fig_coef / 0.3605405405
        # Clamp to 1 so that short sequences do not produce a zero step.
        tick_step = self._positive_step(view._frames.frm_cnt / tick_cnt)

        # Lower x-axis: fixed ticks labelled with timecodes.
        loc = plticker.FixedLocator(range(0, view._frames.end, tick_step))
        fmt = plticker.FuncFormatter(self._timelabels)
        ax.xaxis.set_major_locator(loc)
        ax.xaxis.set_major_formatter(fmt)
        ax.set_xlim(view._frames.start - 20, view._frames.end + 20)

        # Labelling of the y-axis
        # TODO: does not work properly
        ax.set_ylim(-1, view._bins + 1)
        # Clamp to 1 so that fewer than 8 bins do not produce a zero step.
        bin_step = self._positive_step(view._bins / 8)
        loc = plticker.FixedLocator(range(0, view._bins + 1, bin_step))
        ax.yaxis.set_major_locator(loc)

        # TODO: maybe use iter_ticks to put the displayed seconds on base 60

        # Upper x-axis with ticks shifted by half a tick step
        axt = ax.twiny()
        axt.set_xlim(ax.get_xlim())
        axt.set_ylim(ax.get_ylim())

        loc = plticker.FixedLocator(range(int(tick_step / 2),
                                          view._frames.end, tick_step))
        fmt = plticker.FuncFormatter(self._timelabels)
        axt.xaxis.set_major_locator(loc)
        axt.xaxis.set_major_formatter(fmt)

        # INFO: this was set_axis_bgcolor before matplotlib 2
        ax.set_facecolor((1, 1, 1))

        chn_label = view._contrast
        ax.set_ylabel(chn_label, {'fontsize': 8})
        ax.set_xlabel('Time', {'fontsize': 8}, y=0.5)
        ax.yaxis.grid(False)
        axt.yaxis.grid(False)
        ax.tick_params(length=0)
        axt.tick_params(length=0)
        ax.xaxis.grid(c=(0.90, 0.90, 0.90))
        axt.xaxis.grid(c=(0.90, 0.90, 0.90))

        return (ax, axt)

    # TODO: can this be moved to a superclass?
    # It might rather be a feature of the View class.
    # TODO: self._x should be an array
    def _vlines(self, view, mark, mark_gt, mark_lt):
        """Draw vertical marker lines on ``self._ax``.

        Arguments:
            view -- values compared against the thresholds; assumed to be
                    array-like with one value per entry of ``self._x``
                    (TODO confirm)
            mark -- iterable of x positions to mark, or a falsy value
            mark_gt -- threshold; positions where ``view > mark_gt`` are
                       marked. A falsy value (including 0) disables this.
            mark_lt -- threshold; positions where ``view < mark_lt`` are
                       marked. A falsy value (including 0) disables this.
        """
        npx = np.array(self._x)

        if mark_gt:
            for pos in npx[view > mark_gt]:
                self._ax.axvline(pos, color='#a6e22e', alpha=0.4, linewidth=3)
        if mark_lt:
            for pos in npx[view < mark_lt]:
                self._ax.axvline(pos, color='#f92672', alpha=0.4, linewidth=3)
        if mark:
            for pos in mark:
                self._ax.axvline(pos, color='#66d9ef', alpha=0.4, linewidth=3)

    def saveplt(self, title=False, fname='plot.png'):
        """Save the figure to ``fname`` at 400 dpi.

        Arguments:
            title -- optional title set on ``self._ax`` (default: False,
                     meaning no title)
            fname {str} -- output file name (default: 'plot.png')
        """
        if title:
            self._ax.set_title(title, {'fontsize': 14}, y=1.18)

        # Set the final size first so tight_layout works on that size.
        self.fig.set_size_inches(self.width, self.height)
        self.fig.tight_layout()

        self.fig.savefig(fname, dpi=400)
126

127
128
129

class MultivariatePlot(SequencePlot):
    """Scatterplot that shows n features per frame."""

    def __init__(self, view):
        """Create the figure and axes for the scatterplot."""
        super(MultivariatePlot, self).__init__(view)
        self.fig = plt.figure()
        self._ax = plt.axes()

    def plot(self, view, mark_gt=False, mark_lt=False, mark=False):
        """Draw the scatterplot for ``view``.

        Arguments:
            view -- two-dimensional data indexed as ``view[:, 0]`` (x),
                    ``view[:, 1]`` (y) and ``view[:, 2]`` (value); must also
                    expose ``_threshold``
            mark_gt -- accepted but currently not used by this plot
            mark_lt -- accepted but currently not used by this plot
            mark -- iterable of x positions to highlight with vertical
                    lines, or a falsy value

        Returns:
            tuple -- ``(fig, ax)``
        """
        x = view[:, 0]
        y = view[:, 1]
        value = view[:, 2]
        # Marker size grows with the distance of the value from the
        # threshold.
        thickness = [int((v - view._threshold) / 4000) for v in value]

        self._ax, axt = self.ittenstyle(self._ax, view)

        # Plotting
        axt.scatter(x, y, c='black', s=thickness, linewidths=0)
        # Alternative colourings that were tried:
        # axt.scatter(x, y, c=y, cmap='Greys_r', s=thickness, linewidths=0)
        # DOC: vmin/vmax control how the colours of the colormap are
        # distributed
        # ax.scatter(x, y, c=y, cmap='hsv', s=thickness, linewidths=0)

        if mark:
            self._vlines(view, mark, False, False)

        # Set the final size first so tight_layout works on that size
        # (same order as in saveplt).
        self.fig.set_size_inches(self.width, self.height)
        self.fig.tight_layout()

        return self.fig, self._ax
156
157
158
159


class UnivariatePlot(SequencePlot):
    """Line plot that shows one feature per frame."""

    def __init__(self, view):
        """Create figure and axes and apply the common styling."""
        super(UnivariatePlot, self).__init__(view)
        # TODO: these could possibly be moved into the super class
        self.fig = plt.figure()
        self._ax = plt.axes()
        styled_axes = self.ittenstyle(self._ax, view)
        self._ax, self._axt = styled_axes

    def plot(self, view,
             label=False,
             mark_gt=False,
             mark_lt=False,
             mark=False):
        """Add one line for ``view`` to the upper axes.

        Arguments:
            view -- values plotted against ``self._x``; must expose
                    ``feature`` when no label is given
            label -- legend label (default: ``view.feature``)
            mark_gt -- threshold passed on to ``_vlines``
            mark_lt -- threshold passed on to ``_vlines``
            mark -- x positions passed on to ``_vlines``

        Returns:
            the figure the line was drawn on
        """
        legend_text = label if label else view.feature

        # Interpolation with savitzky_golay does not work
        # contrast_points = savitzky_golay(np.array(contrast_points), 51, 7)
        self._axt.plot(self._x, view, label=legend_text)

        markers = (mark, mark_gt, mark_lt)
        if any(markers):
            self._vlines(view, mark, mark_gt, mark_lt)

        self._axt.legend()
        return self.fig