File size: 3,276 Bytes
999c5c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
DeepLabCut Toolbox (deeplabcut.org)
© A. & M. Mathis Labs

Licensed under GNU Lesser General Public License v3.0
"""


import tensorflow as tf

# Rebind ``tf`` to the ``tf.compat.v1`` namespace on TensorFlow 2.x and on
# 1.x releases newer than 1.12, so the helpers below can use the 1.x graph API
# (``tf.GraphDef``, ``tf.Graph``, ``tf.Session``) on either major version.
# NOTE(review): any other major version (e.g. a future 3.x) keeps the plain
# ``tf`` module; confirm whether that is intended.
vers = tf.__version__.split(".")
if int(vers[0]) == 2 or (int(vers[0]) == 1 and int(vers[1]) > 12):
    tf = tf.compat.v1


def read_graph(file):
    """
    Loads a serialized graph definition from a protobuf file

    Parameters
    -----------
    file : string
        path to the protobuf file

    Returns
    --------
    graph_def : :class:`tensorflow.compat.v1.GraphDef`
        The graph definition of the DeepLabCut model stored at ``file``
    """

    # Read the raw bytes and release the file handle before parsing; the
    # parse step does not need the file to stay open.
    with tf.io.gfile.GFile(file, "rb") as f:
        serialized = f.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def


def finalize_graph(graph_def):
    """
    Import a graph definition into a new graph and finalize it

    Parameters
    -----------
    graph_def : :class:`tensorflow.compat.v1.GraphDef`
        The graph of the DeepLabCut model, read using the :func:`read_graph` method

    Returns
    --------
    graph : :class:`tensorflow.Graph`
        A new, finalized graph containing the DeepLabCut model. The imported
        operations are given the ``"DLC"`` name prefix (``"DLC/<op name>"``).
    """

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="DLC")
    # Finalizing makes the graph read-only, so no further operations can be
    # added to it once it is handed out for inference.
    graph.finalize()

    return graph


def get_output_nodes(graph):
    """
    Get the output node names from a graph

    The outputs are assumed to be the final operations of the graph: if the
    name of the last operation contains ``"concat_1"``, it is the only
    output; otherwise the last two operations are returned, the last one
    first.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output node names as a list of strings

    Raises
    ------
    IndexError
        if the graph has no operations, or has a single operation whose name
        does not contain ``"concat_1"``
    """

    names = [str(operation.name) for operation in graph.get_operations()]
    last_name = names[-1]

    if "concat_1" in last_name:
        return [last_name]

    return [last_name, names[-2]]


def get_output_tensors(graph):
    """
    Get the names of the output tensors from a graph

    Each output node found by :func:`get_output_nodes` is mapped to the name
    of its first output tensor (suffix ``":0"``).

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output tensor names as a list of strings
    """

    node_names = get_output_nodes(graph)
    return ["{}:0".format(node_name) for node_name in node_names]


def get_input_tensor(graph):
    """
    Get the name of the input tensor from a graph

    The first operation of the graph is taken to be the model input, and the
    name of its first output tensor (suffix ``":0"``) is returned.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    input_tensor : string
        the input tensor name

    Raises
    ------
    IndexError
        if the graph has no operations
    """

    first_operation = graph.get_operations()[0]
    return "{}:0".format(first_operation.name)


def extract_graph(graph, tf_config=None):
    """
    Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs

    Parameters
    -----------
    graph : :class:`tensorflow.Graph`
        a tensorflow graph containing the desired model, e.g. as returned by
        :func:`finalize_graph`
    tf_config : :class:`tensorflow.compat.v1.ConfigProto`, optional
        configuration passed as ``config`` to the session; ``None`` (the
        default) lets TensorFlow use its default session configuration

    Returns
    --------
    sess : :class:`tensorflow.Session`
        a tensorflow session with the specified graph definition; the caller
        is responsible for closing it
    inputs : :class:`tensorflow.Tensor`
        the input tensor of the model (see :func:`get_input_tensor`)
    outputs : list of :class:`tensorflow.Tensor`
        the output tensor(s) of the model (see :func:`get_output_tensors`)
    """

    input_tensor = get_input_tensor(graph)
    output_tensor = get_output_tensors(graph)

    # Resolve the tensors before opening the session: if a lookup fails, no
    # session has been created yet, so nothing is left unclosed.
    inputs = graph.get_tensor_by_name(input_tensor)
    outputs = [graph.get_tensor_by_name(out) for out in output_tensor]

    sess = tf.Session(graph=graph, config=tf_config)

    return sess, inputs, outputs