Tensorflow: different ways to export and run a graph in C++

2021-12-09 00:00:00 python tensorflow c++

To import your trained network into C++, you first need to export it. After searching a lot and finding almost no information about it, it became clear that we should use freeze_graph() to do this.

Thanks to the new 0.7 version of Tensorflow, documentation for it has been added.

After looking into the documentation, I found that there are a few similar methods. Can you tell me the difference between freeze_graph() and tf.train.export_meta_graph? The latter has similar parameters, but it seems it can also be used for importing models into C++ (I just guess the difference is that, for the file output by this method, you can only use import_graph_def(), or is it something else?)
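
For reference, a minimal sketch (not from the original question) of how the two calls are typically invoked in the 0.x/1.x API; all paths and the output node name below are placeholders. freeze_graph() folds the checkpointed variable values into the GraphDef as constants and writes a single self-contained .pb, whereas export_meta_graph() keeps variables as variables and only serializes the graph structure plus saver info:

import tensorflow as tf
from tensorflow.python.tools import freeze_graph

# freeze_graph(): bake checkpoint values into the graph as constants
freeze_graph.freeze_graph(
    input_graph='models/graph.pbtxt',      # text GraphDef from write_graph()
    input_saver='',
    input_binary=False,
    input_checkpoint='models/my-model',
    output_node_names='output',            # hypothetical output node name
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph='models/frozen.pb',
    clear_devices=True,
    initializer_nodes='')

# export_meta_graph(): serialize graph structure + saver info, no weights
tf.train.export_meta_graph(filename='models/my-model.meta')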

Also one question about how to use write_graph(): in the documentation graph_def is given as sess.graph_def, but in the freeze_graph() examples it is sess.graph.as_graph_def(). What is the difference between these two?
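
For reference, a minimal sketch of the write_graph() call being discussed (TF 0.x/1.x sessions; directory and filename are placeholders). To my knowledge, sess.graph_def is just a convenience property that returns the same GraphDef protobuf as sess.graph.as_graph_def():

import tensorflow as tf

with tf.Session() as sess:
    # ... build your graph here ...
    # both spellings hand write_graph() a GraphDef of the current graph
    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pbtxt')
    tf.train.write_graph(sess.graph.as_graph_def(), 'models/', 'graph.pbtxt')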

This question is related to this one.

Thanks!

Recommended answer

Here's my solution utilizing the V2 checkpoints introduced in TF 0.12.

There's no need to convert all variables to constants or freeze the graph.

Just for clarity, a V2 checkpoint looks like this in my models directory:

checkpoint  # some information on the name of the files in the checkpoint
my-model.data-00000-of-00001  # the saved weights
my-model.index  # probably definition of data layout in the previous file
my-model.meta  # protobuf of the graph (nodes and topology info)

Python part (saving)

with tf.Session() as sess:
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

If you create the Saver with tf.trainable_variables(), you can save yourself some headache and storage space. But maybe some more complicated models need all data to be saved; in that case remove this argument from the Saver, and just make sure you create the Saver after your graph is created. It is also very wise to give all variables/layers unique names, otherwise you can run into different problems.
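
A minimal sketch of the all-variables alternative described above (the layer/variable names here are placeholders): the var_list argument is dropped so the Saver tracks every saveable variable, and the Saver is created only once the graph is fully built:

import tensorflow as tf

# build the graph first, with explicit unique names
x = tf.placeholder(tf.float32, shape=[None, 10], name='input')
w = tf.Variable(tf.zeros([10, 2]), name='dense/weights')
y = tf.matmul(x, w, name='output')

# no var_list argument => all saveable variables are checkpointed
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'models/my-model')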

Python part (inference)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    outputTensors = sess.run(outputOps, feed_dict=feedDict)
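
The snippet above leaves feedDict and outputOps undefined; here is a minimal, self-contained sketch of how they can be built by looking tensors up by name, assuming the graph contains tensors named input:0 and output:0 (placeholder names, not from the original answer):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    # look up the tensors by the names they were given at save time
    inputTensor = sess.graph.get_tensor_by_name('input:0')
    outputOp = sess.graph.get_tensor_by_name('output:0')
    feedDict = {inputTensor: np.zeros((1, 10), dtype=np.float32)}  # dummy batch
    outputTensors = sess.run([outputOp], feed_dict=feedDict)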

C++ part (inference)

Note that checkpointPath isn't a path to any of the existing files, just their common prefix. If you mistakenly put the path to the .index file there, TF won't tell you that was wrong, but it will die during inference due to uninitialized variables.

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta";
const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
    throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
    throw runtime_error("Error creating graph: " + status.ToString());
}

// Read weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
        {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
        {},
        {graph_def.saver_def().restore_op_name()},
        nullptr);
if (!status.ok()) {
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
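
The last two lines are placeholders in the original answer. A minimal sketch of what feedDict and outputOps can look like, again assuming an input tensor named input:0 and an output tensor named output:0 (hypothetical names; substitute the real node names from your graph):

// Hypothetical tensor names -- replace with the real ones from your graph.
Tensor input(DT_FLOAT, TensorShape({1, 10}));   // dummy 1x10 input batch
input.flat<float>().setZero();

std::vector<std::pair<string, Tensor>> feedDict = {{"input:0", input}};
std::vector<string> outputOps = {"output:0"};

std::vector<Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
if (!status.ok()) {
    throw runtime_error("Error running inference: " + status.ToString());
}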
