File size: 3,357 Bytes
999c5c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
DeepLabCut Toolbox (deeplabcut.org)
© A. & M. Mathis Labs

Licensed under GNU Lesser General Public License v3.0
"""


from tkinter import Tk, Label
import colorcet as cc
from PIL import Image, ImageTk, ImageDraw


class Display(object):
    """
    Simple object to display frames with DLC labels in a tkinter window.

    Parameters
    -----------
    cmap : string
        name of the colorcet colormap used to color body parts
        (looked up as ``getattr(colorcet, cmap)``).
    radius : int, optional
        radius, in pixels, of the circle drawn for each body part, by default 3
    pcutoff : float, optional
        likelihood threshold to display points; a point is drawn only when its
        likelihood is strictly greater than this value, by default 0.5
    """

    def __init__(self, cmap="bmy", radius=3, pcutoff=0.5):
        """ Constructor method
        """

        self.cmap = cmap
        self.colors = None
        self.radius = radius
        self.pcutoff = pcutoff
        self.window = None
        self.lab = None

    @staticmethod
    def _sample_colors(all_colors, n):
        """Return evenly spaced entries of ``all_colors``, at least ``n`` of them.

        When the colormap has fewer entries than ``n``, the sampled colors are
        repeated cyclically so that every body part index has a color.
        """
        # Guard against division by zero (no body parts) and a zero slice step
        # (more body parts than colormap entries).
        n = max(int(n), 1)
        step = max(len(all_colors) // n, 1)
        colors = list(all_colors[::step])
        if colors and len(colors) < n:
            colors = [colors[i % len(colors)] for i in range(n)]
        return colors

    @staticmethod
    def _circle_bounds(x, y, radius, im_size):
        """Return the ``[x0, y0, x1, y1]`` box of a circle, clipped to the image.

        ``im_size`` is ``(width, height)``: x is clipped to ``[0, width]`` and
        y to ``[0, height]``.
        """
        width, height = im_size
        x0 = x - radius if x - radius > 0 else 0
        x1 = x + radius if x + radius < width else width
        y0 = y - radius if y - radius > 0 else 0
        y1 = y + radius if y + radius < height else height
        return [x0, y0, x1, y1]

    def set_display(self, im_size, bodyparts):
        """ Create tkinter window to display image

        Parameters
        ----------
        im_size : tuple
            (width, height) of image (currently unused)
        bodyparts : int
            number of bodyparts; one color per body part is sampled from
            the colormap
        """

        self.window = Tk()
        self.window.title("DLC Live")
        self.lab = Label(self.window)
        self.lab.pack()

        self.colors = self._sample_colors(getattr(cc, self.cmap), bodyparts)

    def display_frame(self, frame, pose=None):
        """
        Display the image in the tkinter window, drawing DeepLabCut labels
        when a pose is given. The window is created on the first call.

        Parameters
        -----------
        frame :class:`numpy.ndarray`
            an image as a numpy array (passed to ``PIL.Image.fromarray``)

        pose :class:`numpy.ndarray`, optional
            the pose estimated by DeepLabCut for the image; row ``i`` is read
            as ``(x, y, likelihood)`` for body part ``i``. If None, the frame
            is shown without labels.
        """

        im_size = (frame.shape[1], frame.shape[0])
        n_points = pose.shape[0] if pose is not None else 0

        if self.window is None:
            self.set_display(im_size, n_points)
        elif n_points > len(self.colors):
            # The window was opened for fewer body parts (e.g. the first frame
            # had no pose); resample so every body part gets a color.
            self.colors = self._sample_colors(getattr(cc, self.cmap), n_points)

        img = Image.fromarray(frame)

        if pose is not None:
            draw = ImageDraw.Draw(img)
            for i in range(n_points):
                # written as `not >` so NaN likelihoods are skipped, as before
                if not pose[i, 2] > self.pcutoff:
                    continue
                coords = self._circle_bounds(
                    pose[i, 0], pose[i, 1], self.radius, im_size
                )
                if coords[0] > coords[2] or coords[1] > coords[3]:
                    # the point lies entirely outside the image
                    continue
                color = self.colors[i]
                try:
                    draw.ellipse(coords, fill=color, outline=color)
                except Exception as e:
                    # best effort: one bad point should not stop the display
                    print(e)

        img_tk = ImageTk.PhotoImage(image=img, master=self.window)
        self.lab.configure(image=img_tk)
        # Tk does not keep a Python reference to the image; hold one on the
        # label so the PhotoImage is not garbage-collected while displayed.
        self.lab.image = img_tk
        self.window.update()

    def destroy(self):
        """
        Destroy the tkinter display window, if one is open.

        After this call, the next ``display_frame`` call opens a new window.
        """

        if self.window is not None:
            self.window.destroy()
        self.window = None
        self.lab = None