TensorFlow入门使用 tf.train.Saver()保存模型

关于模型保存的一点心得

tf.train.Saver

注意:saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。


saver = tf.train.Saver(max_to_keep=3)

Saver

Defined in
tensorflow/python/training/saver.py.
See the guides: Exporting and Importing a MetaGraph > Exporting a
Complete Model to
MetaGraph
,
Exporting and Importing a
MetaGraph
,
Variables > Saving and Restoring
Variables

Saves and restores variables.
See
Variables
for an overview of variables, saving and restoring.

Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值

一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。

  • saver()可以选择global_step参数来为ckpt文件名添加数字标记:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
  • keep_checkpoint_every_n_hoursmax_to_keep类似,定义每n小时保存一个ckpt文件。

...
# Create a saver.
saver = tf.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:
        # Append the step number to the checkpoint name:
        saver.save(sess, 'my-model', global_step=step)

一个简单的例子:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

import tensorflow as tf
import time

time.clock()

x = tf.placeholder(tf.float32 ,[None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)

# 为了计算交叉熵,我们需要添加一个新的占位符用于输入正确值。
y_ = tf.placeholder(tf.float32, [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
# 在此,我们要求TF使用梯度下降算法,并以0.01的学习速率最小化交叉熵。
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 创建Saver节点,并设置自动保存最近n=1次模型
saver = tf.train.Saver(max_to_keep=1)
saver_max_acc = 0 
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys})
    correct_prediction = tf.equal(tf.argmax(y,1), tf.arg_max(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
    if (i+1)%10 == 0:
        print('{0:0>2d}:{1:.4f}'.format((i+1),accuracy.eval(session=sess, feed_dict={x: mnist.test.images, y_: mnist.test.labels})))
    # 添加判断语句,选择保存精度最高的模型
    if accuracy > saver_max_acc:
        saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
        saver_max_acc = accuracy
sess.close()
print(time.clock())

在定义 saver
的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行
fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt
进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次
accuracy 或者 f1
值,如果本次结果比上次好才保存新的模型,否则没必要保存。

Restore

restore(sess, save_path)
# sess: A Session to use to restore the parameters.
# save_path: Path where parameters were previously saved.
  • sess: 保存参数的会话。
  • save_path: 保存参数的路径。
  • 当从文件中恢复变量时,不需要事先对他们进行初始化,因为“恢复”自身就是一种初始化变量的方法。
  • 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)

如果你想用不同 epoch 保存下来的模型进行融合的话,3到5
个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为
m_best, 这样融合的话对于M 确实可以比 m_best
更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有
m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。

参考资料:

  1. tensorflow 1.0
    学习:模型的保存与恢复(Saver)
  2. 莫烦 Tensorflow 19 Saver 保存读取 (神经网络
    教学教程tutorial)
  3. TensorFlow手把手入门之 —
    TensorFlow保存还原模型的正确方式,Saver的save和restore方法,亲测可用

但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当
loss 降不下了,保存一个 ckpt,
然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt
来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。

如何使用 tf.train.Saver() 来保存模型

之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

注意,在上面保存完了模型之后。应该把 kernel restart
之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]
55.5

导入模型之前,必须重新再定义一遍变量。

但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。

也就是说,你所定义的变量一定要在 checkpoint
中存在;但不是所有在checkpoint中的变量,你都要重新定义。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.          2.29999995]

tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的
tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的
tensors。一般来说,tf.Saver 可以和 tf.variable_scope()
巧妙搭配,可以参考:
【迁移学习】往一个已经保存好的模型添加新的变量并进行微调

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

您可能感兴趣的文章: