20 import tensorflow
as tf
24 __all__ = [
'convert_from_tensorflow']
41 self.
op2code = {
'Conv2D':1,
'DepthToSpace':2}
45 graph = tf.get_default_graph()
46 tf.import_graph_def(self.
graph_def, name=
"")
48 tf.summary.FileWriter(
'/tmp/graph', graph)
55 next = self.
edges[node.name][0]
56 if next.op ==
'BiasAdd':
59 next = self.
edges[next.name][0]
63 return knode, bnode, activation
67 assert(node.op ==
'Conv2D')
72 dilation = node.attr[
'dilations'].list.i[0]
73 padding = node.attr[
'padding'].s
76 ktensor = knode.attr[
'value'].tensor
77 filter_height = ktensor.tensor_shape.dim[0].size
78 filter_width = ktensor.tensor_shape.dim[1].size
79 in_channels = ktensor.tensor_shape.dim[2].size
80 out_channels = ktensor.tensor_shape.dim[3].size
81 kernel = np.frombuffer(ktensor.tensor_content, dtype=np.float32)
82 kernel = kernel.reshape(filter_height, filter_width, in_channels, out_channels)
83 kernel = np.transpose(kernel, [3, 0, 1, 2])
85 np.array([self.
op2code[node.op], dilation, padding, self.
conv_activations[activation], in_channels, out_channels, filter_height], dtype=np.uint32).tofile(f)
88 btensor = bnode.attr[
'value'].tensor
89 if btensor.tensor_shape.dim[0].size == 1:
90 bias = struct.pack(
"f", btensor.float_val[0])
92 bias = btensor.tensor_content
97 assert(node.op ==
'DepthToSpace')
99 block_size = node.attr[
'block_size'].i
100 np.array([self.
op2code[node.op], block_size], dtype=np.uint32).tofile(f)
108 with open(
'/tmp/tmp.model',
'wb')
as f:
114 for node
in self.
nodes:
117 if node.op ==
'Conv2D':
119 elif node.op ==
'DepthToSpace':
125 with open(self.
outfile,
'wb')
as f:
126 np.array([self.
layer_number], dtype=np.uint32).tofile(f)
131 for node
in self.
nodes:
137 for node
in self.
nodes:
138 for input
in node.input:
139 used_names.append(input)
141 for node
in self.
nodes:
142 if node.name
not in used_names:
149 for node
in self.
nodes:
150 if node.op ==
'Identity':
152 input = node.input[0]
153 id_nodes.append(node)
160 id_dict[name] = input
162 for idnode
in id_nodes:
163 self.
nodes.remove(idnode)
165 for node
in self.
nodes:
166 for i
in range(
len(node.input)):
167 input = node.input[i]
169 node.input[i] = id_dict[input]
173 for node
in self.
nodes:
174 for input
in node.input:
175 if input
in self.
edges:
178 self.
edges[input] = [node]
194 with open(infile,
'rb')
as f:
196 graph_def = tf.GraphDef()
197 graph_def.ParseFromString(f.read())
198 nodes = graph_def.node