Spaces:
Runtime error
Runtime error
File size: 3,276 Bytes
999c5c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""
DeepLabCut Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
Licensed under GNU Lesser General Public License v3.0
"""
import tensorflow as tf
vers = (tf.__version__).split(".")
# The helpers below use the TF1 graph/session API (GraphDef, Graph, Session,
# import_graph_def). On releases newer than 1.12 (including every 2.x and
# later) that API is reached through tf.compat.v1, so rebind ``tf`` to it.
# Older 1.x releases keep using the top-level module as-is.
if int(vers[0]) >= 2 or (int(vers[0]) == 1 and int(vers[1]) > 12):
    tf = tf.compat.v1
def read_graph(file):
    """
    Load a serialized graph definition from a protobuf file.

    Parameters
    -----------
    file : string
        path to the protobuf file

    Returns
    --------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph definition parsed from ``file``
    """
    # Pull the raw bytes out first; the file handle is closed before parsing.
    with tf.io.gfile.GFile(file, "rb") as handle:
        serialized = handle.read()
    parsed = tf.GraphDef()
    parsed.ParseFromString(serialized)
    return parsed
def finalize_graph(graph_def):
    """
    Import a graph definition into a new graph and finalize it.

    Parameters
    -----------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph of the DeepLabCut model, read using the :func:`read_graph` method

    Returns
    --------
    graph :class:`tensorflow.Graph`
        A new, finalized graph containing the imported model. Every imported
        operation name is prefixed with ``"DLC/"``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # name="DLC" makes import_graph_def prepend "DLC/" to every imported op name.
        tf.import_graph_def(graph_def, name="DLC")
    # Finalizing makes the graph read-only, so no ops can be added to it afterwards.
    graph.finalize()
    return graph
def get_output_nodes(graph):
    """
    Get the output node names from a graph

    The outputs are read from the end of the graph's operation list. If the
    last operation's name contains ``"concat_1"``, it is the only output.
    Otherwise the last two operations are returned, last one first.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output node names as a list of strings
    """
    operations = graph.get_operations()
    last_name = str(operations[-1].name)
    if "concat_1" in last_name:
        return [last_name]
    # NOTE(review): an empty graph, or a single-op graph whose op is not a
    # "concat_1" node, raises IndexError here (unchanged behavior).
    return [last_name, str(operations[-2].name)]
def get_output_tensors(graph):
    """
    Get the names of the output tensors from a graph

    Each output node chosen by :func:`get_output_nodes` is mapped to the name
    of its first output tensor, ``"<node name>:0"``.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output tensor names as a list of strings
    """
    return [node_name + ":0" for node_name in get_output_nodes(graph)]
def get_input_tensor(graph):
    """
    Get the name of the input tensor from a graph

    The model input is taken to be the first output of the graph's first
    operation.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    input_tensor : string
        the input tensor name, ``"<first op name>:0"``
    """
    first_op = graph.get_operations()[0]
    return "%s:0" % (first_op.name,)
def extract_graph(graph, tf_config=None):
    """
    Initialize a tensorflow session for ``graph`` and look up the model's input and output tensors.

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        a tensorflow graph containing the desired model (e.g. as returned by :func:`finalize_graph`)
    tf_config :class:`tensorflow.ConfigProto`, optional
        session configuration, passed to ``tf.Session`` as ``config``

    Returns
    --------
    sess :class:`tensorflow.Session`
        a tensorflow session bound to ``graph``; it is returned open, so the caller is responsible for closing it
    inputs :class:`tensorflow.Tensor`
        the model's input tensor, as named by :func:`get_input_tensor`
    outputs : list of :class:`tensorflow.Tensor`
        the model's output tensor(s), as named by :func:`get_output_tensors`
    """
    input_tensor = get_input_tensor(graph)
    output_tensor = get_output_tensors(graph)
    # Resolve the tensors before opening the session: if a lookup fails, no
    # session has been created yet, so nothing is left open and leaked.
    inputs = graph.get_tensor_by_name(input_tensor)
    outputs = [graph.get_tensor_by_name(out) for out in output_tensor]
    sess = tf.Session(graph=graph, config=tf_config)
    return sess, inputs, outputs
|