Model checkpointing using meta-graphs in TensorFlow

Dec 01, 2016

Deep learning models usually take a while to train, even on GPUs, so being able to checkpoint intermediate stages of training is really important. Frameworks like Keras usually offer functionality to store learned variables like weights and biases, but to resume training from a checkpoint, one might also need to restore the optimizer state (which is the case for any optimizer that keeps state of its own, like Adam). Ideally, we would like to store everything we need and not have to rebuild the entire graph in code, and this is possible in TensorFlow by using meta-graphs.

After checking the documentation on meta-graphs, I tried to quickly hack an example but there were a few gotchas:

I found an example on Stack Overflow that illustrates how to solve the first two issues. The last one, while trivial, is also a common use case when you train your models in environments you cannot control (like Amazon EC2, Microsoft Azure, your university cluster, or even your own computer if you have a power failure!), so I decided to prepare a simple example that covers that case as well. The example is based on the logistic regression tutorial by Aymeric Damien (which, by the way, is a great resource if you are just learning TensorFlow).

To test the example, do the following:

  1. Run the file without any arguments (python logistic_regression_with_checkpointing.py). It will run for 5 epochs and save checkpoints for each epoch.

  2. Run the file again, now passing --load True --max_epochs 10 as arguments. The script will detect that it has already trained for 5 epochs, and run for another 5 epochs.
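The resume logic in step 2 can be sketched without TensorFlow at all. The snippet below is only an illustration, not the code from the script: the helper names (str2bool, last_saved_epoch, epochs_to_run) are hypothetical, and it assumes that the checkpoint written after the k-th epoch is numbered k, so the highest number on disk equals the number of epochs already completed.

```python
import argparse
import glob
import os
import re


def str2bool(text):
    # argparse's type=bool would treat the string "False" as True,
    # so "--load True" / "--load False" need an explicit converter.
    return text.lower() in ("true", "1", "yes")


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--load", type=str2bool, default=False)
    parser.add_argument("--max_epochs", type=int, default=5)
    return parser.parse_args(argv)


def last_saved_epoch(checkpoint_dir, prefix="model.ckpt"):
    """Return the highest N among files named <prefix>-N.meta, or 0 if none."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.meta$")
    epochs = []
    for path in glob.glob(os.path.join(checkpoint_dir, prefix + "-*.meta")):
        match = pattern.search(os.path.basename(path))
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else 0


def epochs_to_run(args, checkpoint_dir):
    """Zero-based epoch indices that still have to be trained."""
    start = last_saved_epoch(checkpoint_dir) if args.load else 0
    return range(start, args.max_epochs)
```

With five checkpoints on disk, --load True --max_epochs 10 yields epoch indices 5 through 9, which matches the "another 5 epochs" behaviour described in step 2. Without --load the range starts from 0 again.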

You will end up with a bunch of files called model.ckpt-{epoch} and model.ckpt-{epoch}.meta. The former contain your variable values, and the latter (the ones ending in .meta) store the meta-graph. When running with --load True, the code determines when we last saved the model (lines 62-65), creates a new tf.train.Saver object from the exported meta-graph, and then restores the model parameters along with the placeholders and operations we need in order to continue training (in this case, the placeholders x and y and the operations to compute the cost function, update the model, and do inference with it).
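To make the save and restore round trip concrete, here is a minimal sketch of the same idea on a toy logistic regression. It is not the script discussed above. It assumes the TF 1.x graph API reached through tensorflow.compat.v1 (so it should also run on TensorFlow 2.x), the function names and the collection keys ("x", "y", "cost", "train_op") are my own choices for illustration, and the input data is random.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via compat.v1


def toy_batch():
    # Random stand-in data: 8 samples, 4 features, 2 one-hot classes.
    rng = np.random.RandomState(0)
    features = rng.rand(8, 4).astype(np.float32)
    labels = np.eye(2, dtype=np.float32)[rng.randint(0, 2, size=8)]
    return features, labels


def build_and_save(ckpt_prefix, step):
    """Build the graph, run one Adam update, and write a checkpoint."""
    features, labels = toy_batch()
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 4], name="x")
        y = tf.placeholder(tf.float32, [None, 2], name="y")
        W = tf.Variable(tf.zeros([4, 2]), name="W")
        b = tf.Variable(tf.zeros([2]), name="b")
        pred = tf.nn.softmax(tf.matmul(x, W) + b)
        cost = tf.reduce_mean(-tf.reduce_sum(y * tf.math.log(pred), axis=1))
        train_op = tf.train.AdamOptimizer(0.01).minimize(cost)

        # Collections are exported with the meta-graph, so the handles
        # can be looked up by key after import_meta_graph.
        for key, value in [("x", x), ("y", y), ("cost", cost),
                           ("train_op", train_op)]:
            tf.add_to_collection(key, value)

        saver = tf.train.Saver()  # also covers Adam's internal variables
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(train_op, feed_dict={x: features, y: labels})
            # Writes <prefix>-<step>.meta next to the variable values.
            return saver.save(sess, ckpt_prefix, global_step=step)


def restore_and_continue(ckpt_dir):
    """Rebuild the graph from the .meta file and run one more update."""
    features, labels = toy_batch()
    latest = tf.train.latest_checkpoint(ckpt_dir)
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(latest + ".meta")
        x = tf.get_collection("x")[0]
        y = tf.get_collection("y")[0]
        cost = tf.get_collection("cost")[0]
        train_op = tf.get_collection("train_op")[0]
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, latest)  # no initializer needed
            _, value = sess.run([train_op, cost],
                                feed_dict={x: features, y: labels})
    return latest, float(value)
```

Note that restore_and_continue never mentions W, b, or the optimizer: the graph structure comes from the .meta file and the values (including the optimizer state) come from the checkpoint, which is the whole point of using meta-graphs here.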

I wrote this example for myself but hope it will be useful for other people. If you are reading this, that is probably the case!

This entry was tagged as development
