In our last post we gave a basic introduction to TensorFlow 1.0. Now we want to build on that foundation. One of the most important parts of deep learning is understanding what is going on while the code is running. As our problems get more complicated and our datasets get larger, training time can grow from minutes to days. If we’ve picked poor hyperparameters, or just a bad model in general, we don’t want to wait hours before we can make an adjustment. And if we have great hyperparameters and a great model but didn’t tell it to train for enough steps, we don’t want to start from scratch. Or do we…
We at Bitfusion won’t tell you how to do you, but if you want to understand how your model is doing while it is running, and you want to save the model weights as they are being learned, monitoring and checkpointing are for you. If you want more details, as always, you can consult TensorFlow’s documentation.
Quick Note on Following the Code
In the intro to TensorFlow blog post, we talked about reusing code from previous posts in new ones. Here we copy the code from the 01-intro-to-tensorflow folder into 02-monitoring-and-checkpointing as a starting point, then make changes and commit them. So if you want a true step-by-step, you can look at the commit history for the files in 02-monitoring-and-checkpointing. For this example it is a bit of overkill, but it should help you get used to the method for future posts.
Checkpointing Our MNIST Neural Network
The first thing we want to do is change our code so that the models we are running periodically save their state to checkpoint files. This is foundational if you want to:
- Apply the model to new data
- Stop and continue the training of the network
- Use transfer learning to apply the learned weights to new applications
How do we change our code to handle all this saving and restoring? For now, we will just add two lines of code and let tf.contrib.learn take care of the rest. This line:
classifier = learn.Estimator(model_fn=fully_connected_model)
becomes:
classifier = learn.Estimator(model_fn=fully_connected_model,
                             model_dir='./output',
                             config=learn.RunConfig(save_checkpoints_secs=10))
You’re done. Now the model will save checkpoint files into the output directory (and create the directory if it does not yet exist) every 10 seconds.
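If it helps to see what “save every 10 seconds and pick up where you left off” means mechanically, here is a toy, pure-Python sketch. It is not how tf.contrib.learn is implemented: the JSON checkpoint format, the ckpt-<step>.json file names, and the save, latest_checkpoint, and train helpers are all hypothetical. Like the Estimator in our logs, it writes a checkpoint on the first step, then at most once per save_every_secs, and it restores the newest checkpoint when it starts.

```python
import json
import os
import tempfile
import time


def save(model_dir, state):
    """Write the current state to <model_dir>/ckpt-<global_step>.json (hypothetical format)."""
    path = os.path.join(model_dir, "ckpt-%d.json" % state["global_step"])
    with open(path, "w") as f:
        json.dump(state, f)
    return path


def latest_checkpoint(model_dir):
    """Return the path of the highest-step checkpoint in model_dir, or None."""
    if not os.path.isdir(model_dir):
        return None
    steps = [int(name[len("ckpt-"):-len(".json")])
             for name in os.listdir(model_dir)
             if name.startswith("ckpt-") and name.endswith(".json")]
    if not steps:
        return None
    return os.path.join(model_dir, "ckpt-%d.json" % max(steps))


def train(model_dir, steps, save_every_secs=10.0):
    """Toy loop: resume from the newest checkpoint, then train `steps` more steps."""
    os.makedirs(model_dir, exist_ok=True)  # create the output directory if needed
    state = {"global_step": 0, "weight": 0.0}
    path = latest_checkpoint(model_dir)
    if path is not None:
        with open(path) as f:
            state = json.load(f)  # restore instead of starting from scratch
    last_save = None
    for _ in range(steps):
        state["global_step"] += 1
        state["weight"] += 0.01  # stand-in for a real gradient update
        now = time.monotonic()
        if last_save is None or now - last_save >= save_every_secs:
            save(model_dir, state)  # first step, then at most every save_every_secs
            last_save = now
    save(model_dir, state)  # final checkpoint when this toy loop stops
    return state


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as model_dir:
        print(train(model_dir, 2500)["global_step"])  # 2500
        print(train(model_dir, 2500)["global_step"])  # 5000: resumed, not restarted
```

Calling train twice with 2500 steps ends at global step 5000 rather than starting over, which is the same restore-and-continue behavior we are after with the Estimator.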
Monitoring Our MNIST Neural Network
The next thing we want to do in this tutorial is print some helpful validation metrics as our model trains. Our code will periodically take the model and the weights trained so far and apply them to our validation dataset. To enable this, we first need to create a validation monitor with the following code:
validation_monitor = learn.monitors.ValidationMonitor(
The validation monitor needs to know where the validation data is and which metrics to compute (as you saw in the last post, by default the output only shows global_steps/sec, step, and loss). We add accuracy as a metric and tell the monitor to run every 500 steps. The last thing we need to do is tell the fit call to use the validation monitor, which takes one change to the classifier.fit code: passing monitors=[validation_monitor].
And with that, we have added a validation monitor and model checkpointing.
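One detail worth knowing before reading the run logs later in this post: according to the tf.contrib.learn docstring, ValidationMonitor evaluates the most recently saved checkpoint rather than the weights in memory, so the checkpoint it reports is usually older than the current training step (as we read the source, it also skips a checkpoint it has already evaluated). The toy class below models that scheduling in plain Python. ToyValidationMonitor, step_end, simulate, and the every-700-steps save schedule are all hypothetical, and the fake evaluate function just echoes the checkpoint step instead of computing accuracy.

```python
class ToyValidationMonitor:
    """Hypothetical stand-in for learn.monitors.ValidationMonitor (not TensorFlow code)."""

    def __init__(self, evaluate_fn, every_n_steps=500):
        self.evaluate_fn = evaluate_fn  # evaluates the weights saved at a given step
        self.every_n_steps = every_n_steps
        self.last_active_step = 0
        self.last_evaluated_ckpt = None
        self.history = []  # (training step, checkpoint step, metrics)

    def step_end(self, step, latest_ckpt_step):
        """Call after every training step with the step of the newest saved checkpoint."""
        if step - self.last_active_step < self.every_n_steps:
            return  # not time to validate yet
        self.last_active_step = step
        if latest_ckpt_step is None or latest_ckpt_step == self.last_evaluated_ckpt:
            return  # nothing saved yet, or no new checkpoint since the last evaluation
        self.last_evaluated_ckpt = latest_ckpt_step
        metrics = self.evaluate_fn(latest_ckpt_step)
        self.history.append((step, latest_ckpt_step, metrics))


def simulate(total_steps=2500, save_every=700):
    """Pretend a checkpoint is written at step 1 and every `save_every` steps."""
    monitor = ToyValidationMonitor(lambda ckpt: {"global_step": ckpt}, every_n_steps=500)
    latest_ckpt = None
    for step in range(1, total_steps + 1):
        if step == 1 or step % save_every == 0:
            latest_ckpt = step  # a checkpoint "appears" on disk at this step
        monitor.step_end(step, latest_ckpt)
    return monitor.history


for step, ckpt, _ in simulate():
    print("Validation (step %d): evaluated checkpoint from step %d" % (step, ckpt))
```

In this simulation the check at step 2000 is skipped because no new checkpoint appeared after step 1500, which shows why checkpoint frequency and validation frequency interact.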
How to Run Our Changes
To set up the code for this post, please follow the “MNIST Neural Net in TensorFlow 1.0 - Getting Started” section of the first blog post. The only change is to use the 02-monitoring-and-checkpointing directory rather than 01-intro-to-tensorflow.
For now we support both running in the Jupyter notebook and running from the command line. To run our code for a specific number of steps (which helps illustrate restoring from a checkpoint), we add a TensorFlow flag to our code. In the model.py file, we add the following code:
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('num_steps', 10000, 'Denotes the number of steps for the model to train for.')
This creates a TensorFlow application flag, num_steps, that defaults to 10,000. We also need to change the classifier.fit call so that it uses the new runtime flag as its number of training steps, i.e. steps=FLAGS.num_steps.
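If you would rather not depend on TensorFlow’s flags module, roughly the same command-line option can be declared with Python’s standard argparse. This is a sketch, not the code in our repository; parse_flags is a hypothetical helper. It uses parse_known_args so that extra arguments, such as the ones a Jupyter kernel passes to the process, are ignored instead of raising an error.

```python
import argparse


def parse_flags(argv=None):
    """Hypothetical argparse analog of tf.app.flags.DEFINE_integer('num_steps', ...)."""
    parser = argparse.ArgumentParser(description="Train the MNIST model.")
    parser.add_argument("--num_steps", type=int, default=10000,
                        help="Denotes the number of steps for the model to train for.")
    # parse_known_args returns (known, unknown) and does not fail on unknown arguments.
    flags, _unparsed = parser.parse_known_args(argv)
    return flags


if __name__ == "__main__":
    FLAGS = parse_flags()
    print("Training for %d steps" % FLAGS.num_steps)
```

Running it as python model.py --num_steps=2500 would then give FLAGS.num_steps == 2500, and omitting the option falls back to 10,000.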
All done. Now we can train for a specific number of steps by running the following command:
python model.py --num_steps=2500
Running this code will yield the following:
Saving checkpoints for 1 into ./output/model.ckpt.
loss = 2.34152, step = 1
loss = 1.66204, step = 101
loss = 1.01449, step = 201
loss = 0.731237, step = 301
loss = 0.609014, step = 401
Starting evaluation at 2017-03-16-22:31:49
Finished evaluation at 2017-03-16-22:31:49
Saving dict for global step 1: accuracy = 0.1088, global_step = 1, loss = 2.34772
Validation (step 500): loss = 2.34772, global_step = 1, accuracy = 0.1088
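We can see that the code is saving checkpoints into the correct directory and running validation. Note that the validation at step 500 reports global_step = 1 and an accuracy of only 0.1088: the monitor evaluates the most recently saved checkpoint, which at that point was still the one from step 1, so this is roughly chance-level accuracy for the untrained network. Now we run model.py from the CLI again for 2500 steps and see the following: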
Saving dict for global step 2501: accuracy = 0.9304, global_step = 2501, loss = 0.248158
Validation (step 2501): loss = 0.248158, global_step = 2501, accuracy = 0.9304
loss = 0.280427, step = 2601
loss = 0.289, step = 2701
loss = 0.193591, step = 2801
loss = 0.354609, step = 2901
loss = 0.236409, step = 3001
loss = 0.331802, step = 3101
loss = 0.281791, step = 3201
Saving checkpoints for 3222 into ./output/model.ckpt.
loss = 0.296924, step = 3301
loss = 0.199314, step = 3401
loss = 0.291136, step = 3501
Starting evaluation at 2017-03-16-22:32:42
Finished evaluation at 2017-03-16-22:32:43
Saving dict for global step 3222: accuracy = 0.9375, global_step = 3222, loss = 0.226606
Validation (step 3501): loss = 0.226606, global_step = 3222, accuracy = 0.9375
Great! It restored from the checkpoint file and continued training and checking against the validation set.
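Why did the second run continue past step 2500 instead of stopping immediately? In tf.contrib.learn, Estimator.fit’s steps argument is incremental (two calls to fit(steps=10) train 20 steps in total), while max_steps is an absolute limit on the global step. Assuming our fit call passes FLAGS.num_steps as steps, which is consistent with the logs above, the second run should end around global step 5000. The hypothetical helper below just spells out that arithmetic; it is not part of TensorFlow.

```python
def final_global_step(restored_step, steps=None, max_steps=None):
    """Toy model of where fit() stops, given the global step restored from a checkpoint."""
    if steps is not None and max_steps is not None:
        raise ValueError("set only one of steps and max_steps")
    if steps is not None:
        return restored_step + steps  # incremental: train `steps` more steps
    if max_steps is not None:
        return max(restored_step, max_steps)  # absolute cap: may train zero steps
    raise ValueError("this toy does not model training forever")


# First run, then a second run with the same flag value.
print(final_global_step(0, steps=2500))         # 2500
print(final_global_step(2500, steps=2500))      # 5000
print(final_global_step(2500, max_steps=2500))  # 2500: nothing left to train
```

If you want a run to stop at a fixed total regardless of restarts, max_steps is the argument with that behavior.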
So to summarize what we learned:
- How to save our model as it trains
- How to show how our model is doing while it is training
- How to use tf.app.flags to add command-line options
The next post will take this concept a little further. We will start playing with TensorBoard and show how you can use it to visualize what the network is learning.