自定义模型的保存与加载
TensorFlow2中自定义Model/Layer的保存与加载。
保存/加载带有自定义层的模型或子类化模型分成两步:
-
您应该重写
get_config
和from_config
(可选)方法。1
2
3
4
5
6
7
8
9def get_config(self):
'''将返回一个包含模型配置的 Python 字典。'''
config = super(Attention, self).get_config()
config.update({
'use_W': self.use_W,
'return_self_attend': self.return_self_attend,
'return_attend_weight': self.return_attend_weight
})
return config -
您应该注册自定义对象,以便 Keras 能够感知它
1
2
3custom_objects = {'Attention': Attention}
self.model = tf.keras.models.load_model(self.config.checkpoint_file, custom_objects=custom_objects)
print(self.model.summary())