使用 Tensorflow 检查点在 C++ 中恢复模型

2021-12-27 00:00:00 python tensorflow deep-learning c++

我已经训练了一个我使用 Python 用 Tensorflow 实现的网络.最后,我用 tf.train.Saver() 保存了模型.现在我想使用 C++ 来使用这个预先训练好的网络进行预测.

I've trained a network that I implemented with Tensorflow using Python. In the end, I saved the model with tf.train.Saver(). And now I would like to use C++ to make predictions using this pre trained network.

我该怎么做?有没有办法转换检查点,以便我可以将它与 tiny-dnn 或 Tensorflow C++ 一起使用?

How can I do that ? Is there a way to convert checkpoint so I can use it with tiny-dnn or Tensorflow C++ ?

欢迎提出任何想法:) 谢谢!

Any idea is welcome :) thank you !

推荐答案

您可能应该在 SavedModel 格式,封装了计算图和保存的变量(tf.train.Saver只保存变量,所以你无论如何都必须保存图表).

You probably should export the model in the SavedModel format, which encapsulates the computational graph and the saved variables (tf.train.Saver only saves the variables, so you'd have to save the graph anyway).

然后您可以使用 LoadSavedModel.

You can then load the saved model in C++ using LoadSavedModel.

确切的调用取决于模型的输入和输出是什么.但是 Python 代码看起来像这样:

The exact invocation would depend on what the inputs and outputs of your model are. But the Python code would look something like so:

# You'd adjust the arguments here according to your model
signature = tf.saved_model.signature_def_utils.predict_signature_def(                                                                        
  inputs={'image': input_tensor}, outputs={'scores': output_tensor})                                                                         


builder = tf.saved_model.builder.SavedModelBuilder('/tmp/my_saved_model')                                                                    

builder.add_meta_graph_and_variables(                                                                                                        
   sess=sess,                                                                                                                    
   tags=[tf.saved_model.tag_constants.SERVING],                                                                                             
   signature_def_map={                                                                                                       
 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:                                                                
        signature                                                                                                                        
})                                                                                                                                       

builder.save()

然后在 C++ 中你会做这样的事情:

And then in C++ you'd do something like this:

tensorflow::SavedModelBundle model;
auto status = tensorflow::LoadSavedModel(session_options, run_options, "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &model);
if (!status.ok()) {
   std::cerr << "Failed: " << status;
   return;
}
// At this point you can use model.session

(请注意,使用 SavedModel 格式还允许您使用 TensorFlow Serving 提供模型,如果这对您的应用程序有意义)

(Note that using the SavedModel format will also allow you to serve models using TensorFlow Serving, if that makes sense for your application)

希望有所帮助.

相关文章