In the restoring model section above, if I understand correctly, you build the model and then restore the variables. I believe rebuilding the model is not necessary, so long as you add the relevant tensors/placeholders to a collection with tf.add_to_collection() when saving. For example:
tf.add_to_collection('cost_op', cost_op)
Then later you can restore the saved graph and get access to cost_op:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')
    new_saver.restore(sess, 'model')
    cost_op = tf.get_collection('cost_op')[0]
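tf.get_collection() returns a list of everything stored under a key, which is why the snippet indexes it with [0]. The pure-Python toy model below only illustrates that bookkeeping. The class and its method names are hypothetical stand-ins that mirror the TensorFlow function names; this is not TensorFlow's implementation.

```python
from collections import defaultdict


class CollectionRegistry:
    """Toy stand-in for a graph's named collections (not TensorFlow's code)."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        # Append rather than overwrite: repeated calls accumulate values.
        self._collections[name].append(value)

    def get_collection(self, name):
        # Return a copy of the stored list; unknown keys give an empty list.
        return list(self._collections.get(name, []))


registry = CollectionRegistry()
registry.add_to_collection('cost_op', 'cost_op_tensor')
first = registry.get_collection('cost_op')[0]  # same pattern as tf.get_collection(...)[0]
print(first)
```

Because values accumulate, adding a second object under 'cost_op' would not replace the first one; [0] keeps returning the object that was stored first.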
Even if you don't run tf.add_to_collection(), you can still retrieve your tensors, but the process is a bit more cumbersome, and you may have to do some digging to find the right names for things. For example:
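Suppose that in a script that builds a TensorFlow graph, we define a set of tensors named lab_squeeze:
with tf.variable_scope("inputs"):
    # 'labels' and 'dummy' are illustrative assumptions; the original input tensor is not shown
    labels = tf.placeholder(tf.float32, [None, 3], name='labels')
    split_labels = tf.split(labels, 3, axis=1)
    split_labels = [tf.squeeze(i, name='lab_squeeze') for i in split_labels]
    dummy = tf.Variable(0.0, name='dummy')  # a Saver needs at least one variable to save
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, './checkpoint.chk')
We can recall them later on (e.g. in another script) as follows:
tf.reset_default_graph()  # only needed when restoring in the same script
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')
    new_saver.restore(sess, './checkpoint.chk')
    split_label_0 = tf.get_default_graph().get_tensor_by_name('inputs/lab_squeeze:0')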
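There are a number of ways to find the name of a tensor: you can find it in your graph in TensorBoard, or you can search for it with something like:
g = tf.get_default_graph()
x = g.get_all_collection_keys()
[i.name for j in x for i in g.get_collection(j)]  # lists the items stored in collections (variables, train ops, ...)
[op.name for op in g.get_operations()]  # lists every operation in the graph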
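When you know how a tensor was created, you can often predict its name instead of searching for it: ops created inside tf.variable_scope("inputs") get the prefix inputs/, a repeated op name gets a numeric suffix (_1, _2, ...), and a tensor's name is its op's name plus :<output index>. The helpers below are a hypothetical pure-Python sketch of that convention, assuming no other op in the scope already uses the base name; the resulting strings are what you would pass to the graph's get_tensor_by_name().

```python
def unique_op_names(scope, base_name, count):
    """Hypothetical helper: mimic how repeated op names are made unique
    inside a scope (base, base_1, base_2, ...)."""
    names = []
    for i in range(count):
        suffix = '' if i == 0 else '_%d' % i
        names.append('%s/%s%s' % (scope, base_name, suffix))
    return names


def tensor_names(op_names, output_index=0):
    # A tensor is named "<op name>:<output index>"; tf.squeeze has a single output.
    return ['%s:%d' % (name, output_index) for name in op_names]


names = tensor_names(unique_op_names('inputs', 'lab_squeeze', 3))
print(names)
```

If a prediction turns out wrong (for example because the scope was entered twice and became inputs_1), listing the operations as shown above is the reliable fallback.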
Saving a model in TensorFlow is pretty easy.
Let's say you have a linear model with input x and want to predict an output y. The loss here is the mean squared error (MSE). The batch size is 16.
import numpy as np
import tensorflow as tf
# Define the model
x = tf.placeholder(tf.float32, [16, 10])  # input
y = tf.placeholder(tf.float32, [16, 1])  # output
w = tf.Variable(tf.zeros([10, 1]), dtype=tf.float32)
res = tf.matmul(x, w)
loss = tf.reduce_mean(tf.square(res - y))  # mean squared error
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
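To make concrete what one update of tf.train.GradientDescentOptimizer does for a linear model with a mean squared error loss, here is a pure-Python sketch on a made-up batch of 4 examples with 2 features. The data, the sizes and the larger learning rate of 0.05 are illustrative assumptions. TensorFlow derives this gradient automatically; the code only shows the math, not how TensorFlow computes it internally.

```python
def predict(X, w):
    # res = X @ w for a single-output linear model (no bias, as in the model above)
    return [sum(xi * wi for xi, wi in zip(row, w)) for row in X]


def mse(X, y, w):
    preds = predict(X, w)
    return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)


def gradient_step(X, y, w, lr):
    """One plain gradient-descent update on the MSE loss.

    dL/dw_j = (2/n) * sum_i (pred_i - y_i) * X[i][j]
    """
    n = len(y)
    preds = predict(X, w)
    grads = [2.0 / n * sum((p - t) * row[j] for p, t, row in zip(preds, y, X))
             for j in range(len(w))]
    return [wj - lr * gj for wj, gj in zip(w, grads)]


X = [[1.0, 2.0], [2.0, 0.5], [0.0, 1.0], [3.0, 1.0]]  # toy batch of 4 examples
y = [5.0, 3.0, 2.0, 5.0]                               # generated by w = [1, 2]
w = [0.0, 0.0]                                         # zero init, like tf.zeros
before = mse(X, y, w)
for _ in range(200):
    w = gradient_step(X, y, w, lr=0.05)
after = mse(X, y, w)
print(before, after, w)
```

With tf.reduce_sum instead of tf.reduce_mean, the gradient would simply be n times larger, which acts like a larger learning rate.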
Here comes the tf.train.Saver object, which accepts several parameters (cf. the documentation).
# Define the tf.train.Saver object
# (cf. params section for all the parameters)
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
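To see how these two parameters interact, here is a simplified pure-Python model of the retention policy: max_to_keep bounds a rolling window of recent checkpoints, and keep_checkpoint_every_n_hours additionally preserves an evicted checkpoint once enough wall-clock time has passed since the saver was created. The function, its arguments and the exact time comparison are an approximation for illustration, not TensorFlow's implementation.

```python
def surviving_checkpoints(save_events, max_to_keep=5, keep_every_n_hours=1.0,
                          start_time=0.0):
    """Simplified model of a max_to_keep / keep-every-n-hours policy.

    save_events: list of (step, timestamp_in_seconds) in save order.
    Returns the steps whose checkpoints would still be on disk.
    """
    recent = []     # rolling window of the most recent checkpoints
    preserved = []  # checkpoints kept permanently because of the hourly rule
    next_keep_time = start_time + keep_every_n_hours * 3600
    for step, t in save_events:
        recent.append((step, t))
        if len(recent) > max_to_keep:
            old_step, old_t = recent.pop(0)
            if old_t > next_keep_time:
                preserved.append(old_step)
                next_keep_time += keep_every_n_hours * 3600
            # otherwise the oldest checkpoint is deleted
    return preserved + [s for s, _ in recent]


# 1000 fast training steps, saving every 100 steps within a few seconds:
events = [(step, step * 0.01) for step in range(0, 1000, 100)]
kept = surviving_checkpoints(events)
print(kept)
```

In a short run like this one, no checkpoint is old enough for the hourly rule, so only the last five survive.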
Finally, we train the model in a tf.Session() for 1000 iterations, saving the model only every 100 iterations.
# Start a session
max_steps = 1000
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.global_variables_initializer())
    for step in range(max_steps):
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
        # Save the model every 100 iterations
        if step % 100 == 0:
            saver.save(sess, "./model", global_step=step)
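After running this code, you should see only the last 5 checkpoints (steps 500 to 900) in your directory, plus a small state file named checkpoint. With the default (V2) checkpoint format, each saved step consists of a data file, an index file and a graph file:
model-500.data-00000-of-00001, model-500.index and model-500.meta
model-600.data-00000-of-00001, model-600.index and model-600.meta
model-700.data-00000-of-00001, model-700.index and model-700.meta
model-800.data-00000-of-00001, model-800.index and model-800.meta
model-900.data-00000-of-00001, model-900.index and model-900.meta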
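Because the saver appends -<global_step> to the save path, checkpoints can be told apart by step. As an illustration, here is a hypothetical helper that picks the newest checkpoint prefix from a directory listing by parsing the .meta file names. In real code, tf.train.latest_checkpoint(checkpoint_dir) reads the checkpoint state file instead; this sketch only demonstrates the naming scheme.

```python
import re


def latest_checkpoint_prefix(filenames, prefix='model'):
    """Hypothetical helper: return e.g. 'model-900' for the highest step
    found among names like 'model-900.meta', or None if there are none."""
    pattern = re.compile(r'^%s-(\d+)\.meta$' % re.escape(prefix))
    steps = [int(m.group(1)) for m in map(pattern.match, filenames) if m]
    if not steps:
        return None
    # Compare steps numerically: a plain string sort would put 'model-900' after 'model-1000'.
    return '%s-%d' % (prefix, max(steps))


listing = ['checkpoint', 'model-500.meta', 'model-900.meta', 'model-700.meta']
latest = latest_checkpoint_prefix(listing)
print(latest)
```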
Note that in this example, while the saver saves both the current values of the variables (the checkpoint) and the structure of the graph (*.meta), no specific care was taken with respect to how to retrieve, e.g., the placeholders x and y once the model is restored. If the restoring is done anywhere other than this training script, it can be cumbersome to retrieve x and y from the restored graph (especially in more complicated models). To avoid that, always give names to your variables / placeholders / ops, or consider using graph collections (tf.add_to_collection) as shown in one of the remarks.
Restoring is also quite nice and easy.
Here's a handy helper function:
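import os
import tensorflow as tf
def restore_vars(saver, sess, chkpt_dir):
    """ Restore the variables managed by `saver` from the latest checkpoint in
    chkpt_dir, OR create the checkpoint directory for later storage.
    Returns True if a checkpoint was restored, False otherwise. """
    checkpoint_dir = chkpt_dir
    if not os.path.exists(checkpoint_dir):
        print("making checkpoint_dir")
        try:
            os.makedirs(checkpoint_dir)
        except OSError:
            # tolerate the directory having been created in the meantime
            if not os.path.isdir(checkpoint_dir):
                raise
        return False
    path = tf.train.get_checkpoint_state(checkpoint_dir)
    print("path = ", path)
    if path is None:
        return False
    saver.restore(sess, path.model_checkpoint_path)
    return True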
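tf.train.get_checkpoint_state() returns None when the directory has no checkpoint state file, which is why restore_vars() can fall back to training from scratch. That file, named checkpoint, is a small text file that saver.save() keeps up to date, with a model_checkpoint_path: "..." line followed by all_model_checkpoint_paths: "..." entries. TensorFlow parses it as a text-format protocol buffer; the reader below is a simplified, hypothetical sketch that only extracts the model_checkpoint_path entry, and the file contents it writes are an example.

```python
import os
import re
import tempfile


def read_latest_checkpoint_path(checkpoint_dir):
    """Simplified, hypothetical reader for the 'checkpoint' state file.

    Only extracts model_checkpoint_path; it does not handle the full
    text-protobuf format. Returns None when there is nothing to restore.
    """
    state_file = os.path.join(checkpoint_dir, 'checkpoint')
    if not os.path.isfile(state_file):
        return None  # fresh directory: the caller should start from scratch
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                return match.group(1)
    return None


with tempfile.TemporaryDirectory() as tmp_dir:
    fresh = read_latest_checkpoint_path(tmp_dir)
    # Write an example state file shaped like the one the saver maintains
    with open(os.path.join(tmp_dir, 'checkpoint'), 'w') as f:
        f.write('model_checkpoint_path: "./model-900"\n'
                'all_model_checkpoint_paths: "./model-800"\n'
                'all_model_checkpoint_paths: "./model-900"\n')
    found = read_latest_checkpoint_path(tmp_dir)

print(fresh, found)
```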
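Main code:
path_to_saved_model = './'
max_steps = 1
# Start a session
with tf.Session() as sess:
    # ... define the model here (x, y, loss, train_op as above) ...
    print("define the param saver")
    saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
    # initialize first, so training also works when there is no checkpoint to restore
    sess.run(tf.global_variables_initializer())
    # restore session if there is a saved checkpoint (overwrites the initial values)
    print("restoring model")
    restored = restore_vars(saver, sess, path_to_saved_model)
    print("model restored ", restored)
    # Now continue training if you so choose
    for step in range(max_steps):
        # do an update on the model (optional)
        feed_dict = {x: np.random.randn(16, 10), y: np.random.randn(16, 1)}  # dummy input
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
    # Now save the model
    saver.save(sess, "./model", global_step=step)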
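The main code restarts step at 0 after restoring, so the next saver.save(..., global_step=step) reuses low step numbers. Since the saver names checkpoints <prefix>-<global_step>, one way to keep counting is to recover the step from the restored path (path.model_checkpoint_path). The helper below is a hypothetical sketch of that parsing; alternatively, you could keep a global-step variable (e.g. one from tf.train.get_or_create_global_step()) that is saved and restored with the model.

```python
import os


def step_from_checkpoint_path(path):
    """Hypothetical helper: recover the global step from a checkpoint prefix
    such as './model-900'. Returns 0 if the name has no '-<step>' suffix."""
    base = os.path.basename(path)  # ignore dashes in directory names
    _, sep, tail = base.rpartition('-')
    return int(tail) if sep and tail.isdigit() else 0


restored_path = './model-900'  # e.g. path.model_checkpoint_path after restoring
start_step = step_from_checkpoint_path(restored_path) + 1
max_steps = 3
steps = list(range(start_step, start_step + max_steps))
print(steps)
```

Looping over range(start_step, start_step + max_steps) instead of range(max_steps) then produces checkpoint names that continue where the previous run stopped.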