Spaces:
Runtime error
Runtime error
""" | |
DeepLabCut Toolbox (deeplabcut.org) | |
© A. & M. Mathis Labs | |
Licensed under GNU Lesser General Public License v3.0 | |
""" | |
import tensorflow as tf | |
vers = (tf.__version__).split(".") | |
if int(vers[0]) == 2 or int(vers[0]) == 1 and int(vers[1]) > 12: | |
tf = tf.compat.v1 | |
else: | |
tf = tf | |
def read_graph(file): | |
""" | |
Loads the graph from a protobuf file | |
Parameters | |
----------- | |
file : string | |
path to the protobuf file | |
Returns | |
-------- | |
graph_def :class:`tensorflow.tf.compat.v1.GraphDef` | |
The graph definition of the DeepLabCut model found at the object's path | |
""" | |
with tf.io.gfile.GFile(file, "rb") as f: | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(f.read()) | |
return graph_def | |
def finalize_graph(graph_def): | |
""" | |
Finalize the graph and get inputs to model | |
Parameters | |
----------- | |
graph_def :class:`tensorflow.compat.v1.GraphDef` | |
The graph of the DeepLabCut model, read using the :func:`read_graph` method | |
Returns | |
-------- | |
graph :class:`tensorflow.compat.v1.GraphDef` | |
The finalized graph of the DeepLabCut model | |
inputs :class:`tensorflow.Tensor` | |
Input tensor(s) for the model | |
""" | |
graph = tf.Graph() | |
with graph.as_default(): | |
tf.import_graph_def(graph_def, name="DLC") | |
graph.finalize() | |
return graph | |
def get_output_nodes(graph): | |
""" | |
Get the output node names from a graph | |
Parameters | |
----------- | |
graph :class:`tensorflow.Graph` | |
The graph of the DeepLabCut model | |
Returns | |
-------- | |
output : list | |
the output node names as a list of strings | |
""" | |
op_names = [str(op.name) for op in graph.get_operations()] | |
if "concat_1" in op_names[-1]: | |
output = [op_names[-1]] | |
else: | |
output = [op_names[-1], op_names[-2]] | |
return output | |
def get_output_tensors(graph): | |
""" | |
Get the names of the output tensors from a graph | |
Parameters | |
----------- | |
graph :class:`tensorflow.Graph` | |
The graph of the DeepLabCut model | |
Returns | |
-------- | |
output : list | |
the output tensor names as a list of strings | |
""" | |
output_nodes = get_output_nodes(graph) | |
output_tensor = [out + ":0" for out in output_nodes] | |
return output_tensor | |
def get_input_tensor(graph): | |
input_tensor = str(graph.get_operations()[0].name) + ":0" | |
return input_tensor | |
def extract_graph(graph, tf_config=None): | |
""" | |
Initializes a tensorflow session with the specified graph and extracts the model's inputs and outputs | |
Parameters | |
----------- | |
graph :class:`tensorflow.Graph` | |
a tensorflow graph containing the desired model | |
tf_config :class:`tensorflow.ConfigProto` | |
Returns | |
-------- | |
sess :class:`tensorflow.Session` | |
a tensorflow session with the specified graph definition | |
outputs :class:`tensorflow.Tensor` | |
the output tensor(s) for the model | |
""" | |
input_tensor = get_input_tensor(graph) | |
output_tensor = get_output_tensors(graph) | |
sess = tf.Session(graph=graph, config=tf_config) | |
inputs = graph.get_tensor_by_name(input_tensor) | |
outputs = [graph.get_tensor_by_name(out) for out in output_tensor] | |
return sess, inputs, outputs | |