详情页标题前

阿里云人工智能平台PAITensorFlow模型如何导出为SavedModel-云淘科技

详情页1

本文为您介绍如何将TensorFlow模型导出为SavedModel格式。

SavedModel格式

使用EAS预置官方Processor将TensorFlow模型部署为在线服务,必须先将模型导出为官方定义的SavedModel格式(TensorFlow官方推荐的导出模型格式)。SavedModel模型格式的目录结构如下。

assets/
variables/
    variables.data-00000-of-00001
    variables.index
saved_model.pb|saved_model.pbtxt

其中:

  • assets表示一个可选目录,用于存储预测时的辅助文档信息。

  • variables存储tf.train.Saver保存的变量信息。

  • saved_model.pbsaved_model.pbtxt存储MetaGraphDef(存储训练预测模型的程序逻辑)和SignatureDef(用于标记预测时的输入和输出)。

导出SavedModel

使用TensorFlow导出SavedModel格式的模型请参见Saving and Restoring。如果模型比较简单,则可以使用如下方式快速导出SavedModel。

tf.saved_model.simple_save(
  session,
  "./savedmodel/",
  inputs={"image": x},   ## x表示模型的输入变量。
  outputs={"scores": y}  ## y表示模型的输出。
)

请求在线预测服务时,请求中需要指定模型signature_name,使用simple_save()方法导出的模型中,signature_name默认为serving_default

如果模型比较复杂,则可以使用手工方式导出SavedModel,代码示例如下。

print('Exporting trained model to', export_path)
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
  prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'images': tensor_info_x},
          outputs={'scores': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          'predict_images':
              prediction_signature,
      },
      legacy_init_op=legacy_init_op)
  builder.save()
  print('Done exporting!')

其中:

  • export_path表示导出模型的路径。

  • prediction_signature表示模型为输入和输出构建的SignatureDef,详情请参见SignatureDef。示例中的signature_name为predict_images

  • builder.add_meta_graph_and_variables方法表示导出模型的参数。

说明

  • 导出预测所需的模型时,必须指定导出模型的Tag为tf.saved_model.tag_constants.SERVING。

  • 有关TensorFlow模型的更多信息,请参见TensorFlow SavedModel。

Keras模型转换为SavedModel

使用Keras的model.save()方法会将Keras模型导出为H5格式,需要将其转换为SavedModel才能进行在线预测。您可以先调用load_model()方法加载H5模型,再将其导出为SavedModel格式,代码示例如下。

import tensorflow as tf
with tf.device("/cpu:0"):
    model = tf.keras.models.load_model('./mnist.h5')
    tf.saved_model.simple_save(
      tf.keras.backend.get_session(),
      "./h5_savedmodel/",
      inputs={"image": model.input},
      outputs={"scores": model.output}
    )

Checkpoint转换为Savedmodel

训练过程中使用tf.train.Saver()方法保存的模型格式为checkpoint,需要将其转换为SavedModel才能进行在线预测。您可以先调用saver.restore()方法将Checkpoint加载为tf.Session,再将其导出为SavedModel格式,代码示例如下。

import tensorflow as tf
# variable define ...
saver = tf.train.Saver()
with tf.Session() as sess:
  # Initialize v1 since the saver will not.
    saver.restore(sess, "./lr_model/model.ckpt")
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    tf.saved_model.simple_save(
      sess,
      "./savedmodel/",
      inputs={"image": tensor_info_x},
      outputs={"scores": tensor_info_y}
    )

内容没看懂? 不太想学习?想快速解决? 有偿解决: 联系专家

阿里云企业补贴进行中: 马上申请

腾讯云限时活动1折起,即将结束: 马上收藏

同尘科技为腾讯云授权服务中心。

购买腾讯云产品享受折上折,更有现金返利:同意关联,立享优惠

转转请注明出处:https://www.yunxiaoer.com/165353.html

(0)
上一篇 2023年12月10日
下一篇 2023年12月10日
详情页2

相关推荐

联系我们

400-800-8888

在线咨询: QQ交谈

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息

关注微信
本站为广大会员提供阿里云、腾讯云、华为云、百度云等一线大厂的购买,续费优惠,保证底价,买贵退差。