如何保存包含所有权重的TensorFlow 2目标检测模型?

How to save Tensorflow 2 Object Detection Model including all weights?(如何保存包含所有权重的TensorFlow 2目标检测模型?)
本文介绍了如何保存包含所有权重的TensorFlow 2目标检测模型?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用PYTHON中的TensorFlow 2API进行对象检测。到目前为止,这个方法运行得很好。然而,如果我想保存模型,我使用的是exporter_main_v2.py,它导出一个图形(.pb)和一个检查点(checkpointckpt-0.datackpt-0.index))。图表不包括任何权重,我必须始终使用检查点来处理保存的模型。 是否有办法将所有权重保存到Protobuf(.pb)文件中?

以下是我尝试过的内容:

  • 保存冻结模型:TF2显然不再支持冻结图形。将冻结包含所有权重的图形的export_inference_graph.py在TF2下不起作用。
  • freeze_graph.py相同:只能使用TF1

推荐答案

您仍然可以使用tf2中的冻结技术,使用compat.v1模块:

在下面的代码片段中,我假设您有一个预先训练好的模型,其权重以TF2方式保存,tf.saved_model.save

graph = tf.Graph()
with graph.as_default():
    sess = tf.compat.v1.Session()
    with sess.as_default():
        # creating the model/loading it from a TF2 pb file
        # (If you have a keras model, you can use 
        #`tf.keras.models.load_model` instead). 
        model = tf.saved_model.load("/path/to/model")

# the default signature might be different.
sign = model.signatures["serving_default"]
# if using keras, just use model.outputs
tensor_out_names = [out.name.split(":")[0] for out in sign.outputs]
    
graphdef = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess, graph.as_graph_def(), tensor_out_names
)
# the following is optional, use only if no more training is required
graphdef = tf.compat.v1.graph_util.remove_training_nodes(graphdef)
tf.python.framework.graph_io.write_graph(graphdef, "./", "/path/to/frozengraph", as_text=False)
但是,除非是出于与旧工具的兼容性原因,否则我不会这么做。compat模块可能有一天会被弃用,据我所知,只有一个文件包含图形+权重,而不是拆分它们,不会有很大的值。

这篇关于如何保存包含所有权重的TensorFlow 2目标检测模型?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!

本站部分内容来源互联网,如果有图片或者内容侵犯您的权益请联系我们删除!

相关文档推荐

Leetcode 234: Palindrome LinkedList(Leetcode 234:回文链接列表)
How do I read an Excel file directly from Dropbox#39;s API using pandas.read_excel()?(如何使用PANDAS.READ_EXCEL()直接从Dropbox的API读取Excel文件?)
subprocess.Popen tries to write to nonexistent pipe(子进程。打开尝试写入不存在的管道)
I want to realize Popen-code from Windows to Linux:(我想实现从Windows到Linux的POpen-code:)
Reading stdout from a subprocess in real time(实时读取子进程中的标准输出)
How to call type safely on a random file in Python?(如何在Python中安全地调用随机文件上的类型?)