自定义模型的保存与加载

TensorFlow2中自定义Model/Layer的保存与加载。

保存/加载带有自定义层的模型或子类化模型分成两步:

  1. 您应该重写 get_configfrom_config(可选)方法。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    def get_config(self):
    '''将返回一个包含模型配置的 Python 字典。'''
    config = super(Attention, self).get_config()
    config.update({
    'use_W': self.use_W,
    'return_self_attend': self.return_self_attend,
    'return_attend_weight': self.return_attend_weight
    })
    return config
  2. 您应该注册自定义对象,以便 Keras 能够感知它

    1
    2
    3
    custom_objects = {'Attention': Attention}
    self.model = tf.keras.models.load_model(self.config.checkpoint_file, custom_objects=custom_objects)
    print(self.model.summary())