Alex Chan
initial commit
999c5c9
"""
DeepLabCut Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
Licensed under GNU Lesser General Public License v3.0
"""
import tensorflow as tf
# Route TensorFlow 2.x, and 1.x releases newer than 1.12, through the
# ``tf.compat.v1`` namespace so the rest of this module can use the
# graph/session API under a single name. Any other version keeps the
# plain ``tf`` module as imported.
vers = tf.__version__.split(".")
_tf_major = int(vers[0])
if _tf_major == 2 or (_tf_major == 1 and int(vers[1]) > 12):
    tf = tf.compat.v1
def read_graph(file):
    """
    Read a serialized graph definition from a protobuf file

    Parameters
    -----------
    file : string
        path to the protobuf file

    Returns
    --------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph definition parsed from the bytes stored in the file
    """

    # Pull the raw bytes out first; the file handle is only needed for the read.
    with tf.io.gfile.GFile(file, "rb") as pb_file:
        serialized = pb_file.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def
def finalize_graph(graph_def):
    """
    Import a graph definition into a new graph and finalize it

    Parameters
    -----------
    graph_def :class:`tensorflow.compat.v1.GraphDef`
        The graph definition of the DeepLabCut model, read using the :func:`read_graph` method

    Returns
    --------
    graph :class:`tensorflow.Graph`
        A new, finalized graph holding the imported definition; every imported
        node is placed under the "DLC" name scope
    """

    dlc_graph = tf.Graph()
    with dlc_graph.as_default():
        tf.import_graph_def(graph_def, name="DLC")

    # Finalizing marks the graph as read-only, so no further ops can be added.
    dlc_graph.finalize()
    return dlc_graph
def get_output_nodes(graph):
    """
    Pick the output node names out of a graph's operation list

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output node names as a list of strings: only the last operation
        when its name contains "concat_1", otherwise the last operation
        followed by the second-to-last one
    """

    names = [str(operation.name) for operation in graph.get_operations()]
    final_name = names[-1]

    # A trailing "concat_1" op is the single output of the model.
    if "concat_1" in final_name:
        return [final_name]

    # Otherwise the last two operations are both treated as outputs.
    return [final_name, names[-2]]
def get_output_tensors(graph):
    """
    Build the output tensor names for a graph

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    output : list
        the output tensor names as a list of strings, one per output node
        returned by :func:`get_output_nodes`
    """

    tensor_names = []
    for node_name in get_output_nodes(graph):
        # ":0" selects the first output of the node.
        tensor_names.append(node_name + ":0")
    return tensor_names
def get_input_tensor(graph):
    """
    Build the input tensor name for a graph

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        The graph of the DeepLabCut model

    Returns
    --------
    input_tensor : string
        the name of the first operation in the graph with ":0" appended,
        i.e. the name of that operation's first output
    """

    first_op = graph.get_operations()[0]
    return str(first_op.name) + ":0"
def extract_graph(graph, tf_config=None):
    """
    Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs

    Parameters
    -----------
    graph :class:`tensorflow.Graph`
        a tensorflow graph containing the desired model
    tf_config :class:`tensorflow.ConfigProto`
        optional session configuration, passed unchanged as the ``config``
        argument of the session; defaults to None

    Returns
    --------
    sess :class:`tensorflow.Session`
        a tensorflow session with the specified graph definition
    inputs :class:`tensorflow.Tensor`
        the input tensor for the model, named by :func:`get_input_tensor`
    outputs : list of :class:`tensorflow.Tensor`
        the output tensor(s) for the model, named by :func:`get_output_tensors`
    """

    # Resolve every tensor before opening the session. If a lookup fails, the
    # exception then propagates without leaving behind an open session that
    # the caller never received and therefore could not close.
    inputs = graph.get_tensor_by_name(get_input_tensor(graph))
    outputs = [
        graph.get_tensor_by_name(tensor_name)
        for tensor_name in get_output_tensors(graph)
    ]

    sess = tf.Session(graph=graph, config=tf_config)
    return sess, inputs, outputs