
Getting started with TensorFlow and using tf.train.Saver() to save the model

不言
Release: 2018-04-24 14:15:07
Original

This article introduces how to save a model with tf.train.Saver() when getting started with TensorFlow. I'm sharing it here as a reference; let's take a look.

Some thoughts on model saving

# Keep at most the 3 most recent checkpoints (the default is 5); older ones are deleted automatically.
saver = tf.train.Saver(max_to_keep=3)

When defining the saver, you usually set the maximum number of checkpoints to keep (max_to_keep). If the model itself is large, you need to take disk space into account. If you plan to fine-tune on top of the model you are currently training, keep as many checkpoints as possible: the later fine-tuning will not necessarily start from the best ckpt, because starting from it may quickly lead to overfitting. On the other hand, saving too many files puts pressure on the disk. If you only want to keep the best model, evaluate accuracy or F1 on the validation set every fixed number of steps, and save a new checkpoint only when the result is better than the previous best; otherwise, skip saving.
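A minimal sketch of the "save only when the validation score improves" strategy is below. The class `BestCheckpointKeeper` and its `save_fn` callback are hypothetical names used for illustration; in a real TF 1.x training loop, `save_fn` would wrap a call such as `saver.save(sess, ckpt_path, global_step=step)`, and the score would be your validation accuracy or F1.

```python
class BestCheckpointKeeper:
    """Hypothetical helper: call save_fn only when the validation score improves."""

    def __init__(self, save_fn, higher_is_better=True):
        self.save_fn = save_fn                    # e.g. wraps saver.save(...)
        self.higher_is_better = higher_is_better  # set False for metrics like loss
        self.best_score = None
        self.best_step = None

    def _improved(self, score):
        if self.best_score is None:
            return True
        if self.higher_is_better:
            return score > self.best_score
        return score < self.best_score

    def update(self, step, score):
        """Record a validation score; save and return True only on a new best."""
        if not self._improved(score):
            return False
        self.best_score = score
        self.best_step = step
        self.save_fn(step)
        return True


# Simulated validation F1 every 100 steps; save_fn just records the step here.
saved_steps = []
keeper = BestCheckpointKeeper(save_fn=saved_steps.append)
for step, val_f1 in [(100, 0.71), (200, 0.75), (300, 0.74), (400, 0.78)]:
    keeper.update(step, val_f1)
print(saved_steps)  # [100, 200, 400]
```

Pairing this with `tf.train.Saver(max_to_keep=1)` leaves only the most recent save on disk, and since saves happen only on improvement, that file is the best model so far.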

If you want to fuse models saved at different epochs, 3 to 5 models are enough. Call the fused model M and the best single model m_best. On its own, M can indeed do better than m_best. But when M is further fused with models of other architectures, it turns out to be less effective than m_best, because M is essentially an average, which smooths away the model's distinctive "characteristics".
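To make "M is equivalent to an average operation" concrete, here is a minimal sketch of one common fusion scheme: uniformly averaging the class probabilities predicted by several checkpoints of the same model. It assumes each checkpoint's predictions have already been computed (for example by restoring it and running the softmax output); the function names and all numbers are made up for illustration.

```python
def average_predictions(per_model_probs):
    """Uniformly average class probabilities from several models.

    per_model_probs[m][i][c] is model m's probability of class c for example i.
    """
    if not per_model_probs:
        raise ValueError("need at least one model")
    n_models = len(per_model_probs)
    fused = []
    for i in range(len(per_model_probs[0])):
        n_classes = len(per_model_probs[0][i])
        fused.append([
            sum(model[i][c] for model in per_model_probs) / n_models
            for c in range(n_classes)
        ])
    return fused


def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)


# Hypothetical outputs of three checkpoints: 2 examples x 2 classes each.
ckpt_a = [[0.9, 0.1], [0.4, 0.6]]
ckpt_b = [[0.6, 0.4], [0.3, 0.7]]
ckpt_c = [[0.3, 0.7], [0.2, 0.8]]
fused = average_predictions([ckpt_a, ckpt_b, ckpt_c])
print([argmax(row) for row in fused])  # [0, 1]
```

Note how ckpt_c's confident vote for class 1 on the first example is averaged away: each fused probability is pulled toward the members' consensus, which is one way to read the remark that M loses some of a single model's distinctive behavior.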

There is also a newer fusion method that adjusts the learning rate to find multiple local optima: when the loss stops decreasing, save a ckpt, then increase the learning rate to escape and search for the next local optimum, and finally fuse these ckpts. I haven't tried it yet. It should improve the single model, but I don't know whether, as described above, the gain disappears when it is combined with other models.
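The "save a ckpt at each plateau, then raise the learning rate" idea can be sketched as a small scheduler. This is a minimal sketch under assumptions of our own: the class name, the "no improvement for `patience` evaluations" plateau test, the `boost` factor and the gradual `decay` back to the base rate are illustrative choices, not settings from the original text. In a TF 1.x loop you would feed the returned rate through a placeholder and call `saver.save(...)` whenever `should_save` is True.

```python
import math


class PlateauSnapshotScheduler:
    """Hypothetical scheduler: signal a snapshot and raise the LR on a loss plateau."""

    def __init__(self, base_lr, boost=10.0, decay=0.5, patience=3):
        self.base_lr = base_lr
        self.boost = boost          # LR multiplier used to escape the current optimum
        self.decay = decay          # per-step decay back toward base_lr
        self.patience = patience    # steps without improvement that count as a plateau
        self.lr = base_lr
        self.best = math.inf
        self.bad_steps = 0

    def step(self, loss):
        """Return (learning_rate_for_next_step, should_save_snapshot)."""
        if loss < self.best:
            self.best = loss
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        if self.bad_steps >= self.patience:
            # Plateau: snapshot the current weights, then jump to a high LR and
            # forget the old best so the next basin is judged on its own.
            self.best = math.inf
            self.bad_steps = 0
            self.lr = self.base_lr * self.boost
            return self.lr, True
        # Otherwise decay the LR back toward the base rate.
        self.lr = max(self.base_lr, self.lr * self.decay)
        return self.lr, False


losses = [1.0, 0.8, 0.7, 0.7, 0.71, 0.72, 0.9, 0.6, 0.5, 0.5, 0.5, 0.5]
sched = PlateauSnapshotScheduler(base_lr=0.01, boost=10.0, decay=0.5, patience=3)
snapshot_steps = []
for step, loss in enumerate(losses):
    lr, should_save = sched.step(loss)
    if should_save:
        snapshot_steps.append(step)  # here you would call saver.save(...)
print(snapshot_steps)  # [5, 11]
```

The snapshots collected this way are the ckpts that would then be fused, for example by averaging their predictions.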

How to use tf.train.Saver() to save the model

I kept getting errors earlier, mainly because of annoying encoding issues, so make sure the file path contains no Chinese (non-ASCII) characters.

import os
import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Saver.save() does not create the parent directory, so make sure it exists.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

# Initialize the variables, do some work, then save them to disk.
sess.run(init_op)
# global_step=1 appends "-1" to the checkpoint prefix.
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)

Model saved in file: ./ckpt/test-model.ckpt-1

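Note: after saving the model above, restart the kernel before running the restore code below (or call tf.reset_default_graph() and create a new session). Otherwise "v1" is defined twice in the same graph; TensorFlow renames the second one to "v1_1", that name does not exist in the checkpoint, and restoring fails.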

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Re-create the variables with the same names as in the checkpoint.
# The initial values do not matter; restore() overwrites them.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Restore variables from disk. No initializer is needed for restored variables.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-' + str(1))
print("Model restored.")

print(sess.run(v1))
print(sess.run(v2))

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.              2.29999995]
55.5
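The restore code above hard-codes the step suffix (`'-' + str(1)`). TF 1.x provides `tf.train.latest_checkpoint(checkpoint_dir)`, which reads the `checkpoint` state file written by the saver. As a self-contained alternative that only looks at files on disk, the sketch below picks the prefix with the largest global step by scanning for the `.index` files that the V2 checkpoint format writes next to each prefix. The function `find_latest_checkpoint` is hypothetical, and it assumes prefixes of the form `<name>-<step>`, as produced by `global_step=`.

```python
import os
import re
import tempfile


def find_latest_checkpoint(ckpt_dir, name):
    """Return the '<name>-<step>' prefix with the largest step in ckpt_dir, or None.

    Hypothetical helper. Assumes the saver wrote '<name>-<step>.index' files,
    i.e. saver.save(sess, os.path.join(ckpt_dir, name), global_step=step).
    """
    pattern = re.compile(re.escape(name) + r"-(\d+)\.index$")
    best_step, best_prefix = -1, None
    for fname in os.listdir(ckpt_dir):
        match = pattern.match(fname)
        # Compare steps as integers, so "-10" beats "-2".
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = os.path.join(ckpt_dir, fname[: -len(".index")])
    return best_prefix


# Demo on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as demo_dir:
    for step in (1, 2, 10):
        open(os.path.join(demo_dir, "test-model.ckpt-%d.index" % step), "w").close()
    latest = find_latest_checkpoint(demo_dir, "test-model.ckpt")
    print(os.path.basename(latest))  # test-model.ckpt-10
```

With real checkpoints, the result can be passed straight to `saver.restore(sess, find_latest_checkpoint('./ckpt', 'test-model.ckpt'))`.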

Before restoring the model, you must redefine the variables in the graph.

However, you don't have to redefine all of them; defining only the variables you need is enough.

In other words, every variable you define must exist in the checkpoint, but not every variable in the checkpoint has to be redefined.
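The rule above can be stated as a simple set check. The sketch below is only a plain-Python model of that rule, not TensorFlow code: `check_restorable` is a hypothetical helper, and the variable names are the ones from this example. In real TF 1.x code, `tf.train.list_variables(ckpt_path)` returns the (name, shape) pairs stored in a checkpoint, which could feed such a check.

```python
def check_restorable(graph_var_names, ckpt_var_names):
    """Plain-Python model of the restore rule (hypothetical helper, not a TF API).

    Every variable the saver covers must exist in the checkpoint; extra
    checkpoint variables are allowed and are returned as 'ignored'.
    """
    graph_vars = set(graph_var_names)
    ckpt_vars = set(ckpt_var_names)
    missing = graph_vars - ckpt_vars
    if missing:
        # TensorFlow itself reports this as a NotFoundError at restore time.
        raise ValueError("not found in checkpoint: %s" % sorted(missing))
    return sorted(ckpt_vars - graph_vars)


ckpt_vars = ["v1", "v2"]                      # contents of test-model.ckpt-1
print(check_restorable(["v1"], ckpt_vars))    # ['v2']  (ignored; restore works)
try:
    # Defining v1/v2 twice in one graph yields "v1_1"/"v2_1", which the checkpoint lacks.
    check_restorable(["v1", "v2", "v1_1", "v2_1"], ckpt_vars)
except ValueError as err:
    print(err)  # not found in checkpoint: ['v1_1', 'v2_1']
```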

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Re-create only v1 (run this in a fresh kernel as well).
# v2 is still in the checkpoint but is not needed here.
v1 = tf.Variable([11.0, 16.3], name="v1")

# The saver now covers only v1, so v2 in the checkpoint is simply ignored.
saver = tf.train.Saver()

# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-' + str(1))
print("Model restored.")

print(sess.run(v1))

INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1.         2.29999995]

tf.train.Saver(var_list) accepts a list of the variables to save and restore (or a dict mapping checkpoint names to variables). If var_list is not given, it defaults to all saveable variables in the graph. tf.train.Saver can often be combined neatly with tf.variable_scope(), for example to save or restore only the variables under one scope. For reference, see: [Transfer Learning] Add new variables to an already saved model and fine-tune it
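As a sketch of combining var_list with tf.variable_scope(), the helpers below pick the variables to hand to tf.train.Saver by scope prefix and, optionally, map them back to the names used in an older checkpoint (the dict form of var_list). They work on plain name strings so they run without TensorFlow; `select_by_scope`, `build_restore_map` and the scope names `pretrained` and `new_head` are hypothetical. With real variables you would filter `tf.global_variables()` by `v.op.name` in the same way.

```python
def select_by_scope(var_names, scope):
    """Keep names that live under 'scope/' (whole path component, not a loose prefix)."""
    prefix = scope.rstrip("/") + "/"
    return [n for n in var_names if n.startswith(prefix)]


def build_restore_map(var_names, scope, strip_scope=False):
    """Map checkpoint names -> graph names for the variables under `scope`.

    With strip_scope=True, 'pretrained/v1' is looked up as 'v1' in the
    checkpoint, e.g. when an old model has been wrapped in a new scope.
    """
    prefix = scope.rstrip("/") + "/"
    selected = select_by_scope(var_names, scope)
    if strip_scope:
        return {n[len(prefix):]: n for n in selected}
    return {n: n for n in selected}


graph_vars = ["pretrained/v1", "pretrained/v2", "new_head/w", "new_head/b"]
print(select_by_scope(graph_vars, "pretrained"))
# ['pretrained/v1', 'pretrained/v2']
print(build_restore_map(graph_vars, "pretrained", strip_scope=True))
# {'v1': 'pretrained/v1', 'v2': 'pretrained/v2'}
```

In TF 1.x, the same mapping built from variable objects instead of names could be passed as `tf.train.Saver(var_list=...)` to restore the old weights, while the new `new_head` variables are initialized separately, e.g. with `tf.variables_initializer(new_vars)`, before fine-tuning.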

source:php.cn
Statement of this Website
The content of this article is voluntarily contributed by netizens, and the copyright belongs to the original author. This site does not assume corresponding legal responsibility. If you find any content suspected of plagiarism or infringement, please contact admin@php.cn