You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

244 lines
8.6KB

  1. # Copyright (c) 2019 Guo Yejun
  2. #
  3. # This file is part of FFmpeg.
  4. #
  5. # FFmpeg is free software; you can redistribute it and/or
  6. # modify it under the terms of the GNU Lesser General Public
  7. # License as published by the Free Software Foundation; either
  8. # version 2.1 of the License, or (at your option) any later version.
  9. #
  10. # FFmpeg is distributed in the hope that it will be useful,
  11. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  13. # Lesser General Public License for more details.
  14. #
  15. # You should have received a copy of the GNU Lesser General Public
  16. # License along with FFmpeg; if not, write to the Free Software
  17. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  18. # ==============================================================================
  19. import tensorflow as tf
  20. import numpy as np
  21. import sys, struct
  22. __all__ = ['convert_from_tensorflow']
  23. class TFConverter:
  24. def __init__(self, graph_def, nodes, outfile, dump4tb):
  25. self.graph_def = graph_def
  26. self.nodes = nodes
  27. self.outfile = outfile
  28. self.dump4tb = dump4tb
  29. self.layer_number = 0
  30. self.output_names = []
  31. self.name_node_dict = {}
  32. self.edges = {}
  33. self.conv_activations = {'Relu':0, 'Tanh':1, 'Sigmoid':2, 'None':3, 'LeakyRelu':4}
  34. self.conv_paddings = {'VALID':0, 'SAME':1}
  35. self.converted_nodes = set()
  36. self.conv2d_scope_names = set()
  37. self.op2code = {'Conv2D':1, 'DepthToSpace':2, 'MirrorPad':3}
  38. self.mirrorpad_mode = {'CONSTANT':0, 'REFLECT':1, 'SYMMETRIC':2}
  39. def dump_for_tensorboard(self):
  40. graph = tf.get_default_graph()
  41. tf.import_graph_def(self.graph_def, name="")
  42. tf.summary.FileWriter('/tmp/graph', graph)
  43. print('graph saved, run "tensorboard --logdir=/tmp/graph" to see it')
  44. def get_conv2d_params(self, conv2d_scope_name):
  45. knode = self.name_node_dict[conv2d_scope_name + '/kernel']
  46. bnode = self.name_node_dict[conv2d_scope_name + '/bias']
  47. if conv2d_scope_name + '/dilation_rate' in self.name_node_dict:
  48. dnode = self.name_node_dict[conv2d_scope_name + '/dilation_rate']
  49. else:
  50. dnode = None
  51. # the BiasAdd name is possible be changed into the output name,
  52. # if activation is None, and BiasAdd.next is the last op which is Identity
  53. if conv2d_scope_name + '/BiasAdd' in self.edges:
  54. activation = self.edges[conv2d_scope_name + '/BiasAdd'][0]
  55. activation = activation.op
  56. else:
  57. activation = 'None'
  58. return knode, bnode, dnode, activation
  59. def dump_conv2d_to_file(self, node, f):
  60. assert(node.op == 'Conv2D')
  61. self.layer_number = self.layer_number + 1
  62. self.converted_nodes.add(node.name)
  63. scope_name = TFConverter.get_scope_name(node.name)
  64. #knode for kernel, bnode for bias, dnode for dilation
  65. knode, bnode, dnode, activation = self.get_conv2d_params(scope_name)
  66. if dnode is not None:
  67. dilation = struct.unpack('i', dnode.attr['value'].tensor.tensor_content[0:4])[0]
  68. else:
  69. dilation = 1
  70. padding = node.attr['padding'].s.decode("utf-8")
  71. # conv2d with dilation > 1 generates tens of nodes, not easy to parse them, so use tricky.
  72. if dilation > 1 and scope_name + '/stack' in self.name_node_dict:
  73. if self.name_node_dict[scope_name + '/stack'].op == "Const":
  74. padding = 'SAME'
  75. padding = self.conv_paddings[padding]
  76. ktensor = knode.attr['value'].tensor
  77. filter_height = ktensor.tensor_shape.dim[0].size
  78. filter_width = ktensor.tensor_shape.dim[1].size
  79. in_channels = ktensor.tensor_shape.dim[2].size
  80. out_channels = ktensor.tensor_shape.dim[3].size
  81. kernel = np.frombuffer(ktensor.tensor_content, dtype=np.float32)
  82. kernel = kernel.reshape(filter_height, filter_width, in_channels, out_channels)
  83. kernel = np.transpose(kernel, [3, 0, 1, 2])
  84. np.array([self.op2code[node.op], dilation, padding, self.conv_activations[activation], in_channels, out_channels, filter_height], dtype=np.uint32).tofile(f)
  85. kernel.tofile(f)
  86. btensor = bnode.attr['value'].tensor
  87. if btensor.tensor_shape.dim[0].size == 1:
  88. bias = struct.pack("f", btensor.float_val[0])
  89. else:
  90. bias = btensor.tensor_content
  91. f.write(bias)
  92. def dump_depth2space_to_file(self, node, f):
  93. assert(node.op == 'DepthToSpace')
  94. self.layer_number = self.layer_number + 1
  95. block_size = node.attr['block_size'].i
  96. np.array([self.op2code[node.op], block_size], dtype=np.uint32).tofile(f)
  97. self.converted_nodes.add(node.name)
  98. def dump_mirrorpad_to_file(self, node, f):
  99. assert(node.op == 'MirrorPad')
  100. self.layer_number = self.layer_number + 1
  101. mode = node.attr['mode'].s
  102. mode = self.mirrorpad_mode[mode.decode("utf-8")]
  103. np.array([self.op2code[node.op], mode], dtype=np.uint32).tofile(f)
  104. pnode = self.name_node_dict[node.input[1]]
  105. self.converted_nodes.add(pnode.name)
  106. paddings = pnode.attr['value'].tensor.tensor_content
  107. f.write(paddings)
  108. self.converted_nodes.add(node.name)
  109. def dump_layers_to_file(self, f):
  110. for node in self.nodes:
  111. if node.name in self.converted_nodes:
  112. continue
  113. # conv2d with dilation generates very complex nodes, so handle it in special
  114. scope_name = TFConverter.get_scope_name(node.name)
  115. if scope_name in self.conv2d_scope_names:
  116. if node.op == 'Conv2D':
  117. self.dump_conv2d_to_file(node, f)
  118. continue
  119. if node.op == 'DepthToSpace':
  120. self.dump_depth2space_to_file(node, f)
  121. elif node.op == 'MirrorPad':
  122. self.dump_mirrorpad_to_file(node, f)
  123. def dump_to_file(self):
  124. with open(self.outfile, 'wb') as f:
  125. self.dump_layers_to_file(f)
  126. np.array([self.layer_number], dtype=np.uint32).tofile(f)
  127. def generate_name_node_dict(self):
  128. for node in self.nodes:
  129. self.name_node_dict[node.name] = node
  130. def generate_output_names(self):
  131. used_names = []
  132. for node in self.nodes:
  133. for input in node.input:
  134. used_names.append(input)
  135. for node in self.nodes:
  136. if node.name not in used_names:
  137. self.output_names.append(node.name)
  138. def remove_identity(self):
  139. id_nodes = []
  140. id_dict = {}
  141. for node in self.nodes:
  142. if node.op == 'Identity':
  143. name = node.name
  144. input = node.input[0]
  145. id_nodes.append(node)
  146. # do not change the output name
  147. if name in self.output_names:
  148. self.name_node_dict[input].name = name
  149. self.name_node_dict[name] = self.name_node_dict[input]
  150. del self.name_node_dict[input]
  151. else:
  152. id_dict[name] = input
  153. for idnode in id_nodes:
  154. self.nodes.remove(idnode)
  155. for node in self.nodes:
  156. for i in range(len(node.input)):
  157. input = node.input[i]
  158. if input in id_dict:
  159. node.input[i] = id_dict[input]
  160. def generate_edges(self):
  161. for node in self.nodes:
  162. for input in node.input:
  163. if input in self.edges:
  164. self.edges[input].append(node)
  165. else:
  166. self.edges[input] = [node]
  167. @staticmethod
  168. def get_scope_name(name):
  169. index = name.rfind('/')
  170. if index == -1:
  171. return ""
  172. return name[0:index]
  173. def generate_conv2d_scope_names(self):
  174. for node in self.nodes:
  175. if node.op == 'Conv2D':
  176. scope = TFConverter.get_scope_name(node.name)
  177. self.conv2d_scope_names.add(scope)
  178. def run(self):
  179. self.generate_name_node_dict()
  180. self.generate_output_names()
  181. self.remove_identity()
  182. self.generate_edges()
  183. self.generate_conv2d_scope_names()
  184. if self.dump4tb:
  185. self.dump_for_tensorboard()
  186. self.dump_to_file()
  187. def convert_from_tensorflow(infile, outfile, dump4tb):
  188. with open(infile, 'rb') as f:
  189. # read the file in .proto format
  190. graph_def = tf.GraphDef()
  191. graph_def.ParseFromString(f.read())
  192. nodes = graph_def.node
  193. converter = TFConverter(graph_def, nodes, outfile, dump4tb)
  194. converter.run()