Generic Training
Deep.Net contains a powerful, generic function to train your model. Together with the dataset handler it provides the following functionality:
- initialization of the model's parameters
- mini-batch training
- logging of losses on the training, validation and test sets
- automatic scheduling of the learning rate
- termination of training when
  - a desired validation loss is reached
  - a set number of iterations have been performed
  - there is no loss improvement on the validation set within a set number of iterations
- checkpointing, which allows the training state to be saved to disk and training to be restarted later (useful when running on unreliable hardware or on a compute cluster that pauses jobs or moves them between nodes)
Example model
To demonstrate its use we return to our two-layer neural network model for classifying MNIST digits.
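A minimal preamble for this example might look as follows; the exact namespace references are assumptions based on the Deep.Net samples and may differ in your setup.

```fsharp
// reference the Deep.Net namespaces used in this example
// (namespace names are assumptions and may differ in your setup)
open ArrayNDNS
open SymTensor
open SymTensor.Compiler.Cuda
open Models
open Datasets
open Optimizers
```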
We load the MNIST dataset with the Mnist.load function, using a validation-to-training ratio of 0.1.
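A sketch of the loading code; the dataset directory is a placeholder.

```fsharp
// load MNIST and split off 10% of the training samples for validation
let mnist = Mnist.load (__SOURCE_DIRECTORY__ + "/../Data/MNIST") 0.1
```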
Next, we define and instantiate a model using the MLP (multi-layer perceptron, i.e. multi-layer neural network) component.
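A sketch of the model definition follows. The hyper-parameter record (Layers, LossMeasure), the layer fields (NInput, NOutput, TransferFunc) and the concrete sizes are illustrative assumptions; consult the model components documentation for the exact API.

```fsharp
// create a model builder and declare symbolic sizes
let mb = ModelBuilder<single> "NeuralNetModel"
let nBatch  = mb.Size "nBatch"
let nInput  = mb.Size "nInput"
let nClass  = mb.Size "nClass"
let nHidden = mb.Size "nHidden"

// MLP with one tanh hidden layer and a softmax output layer,
// trained with a cross-entropy loss
let mlp =
    MLP.pars (mb.Module "MLP")
        { Layers = [ { NeuralLayer.defaultHyperPars with
                         NInput = nInput; NOutput = nHidden
                         TransferFunc = NeuralLayer.Tanh }
                     { NeuralLayer.defaultHyperPars with
                         NInput = nHidden; NOutput = nClass
                         TransferFunc = NeuralLayer.SoftMax } ]
          LossMeasure = LossLayer.CrossEntropy }

// symbolic variables for the input images and target labels
let input  : ExprT<single> = mb.Var "Input"  [nBatch; nInput]
let target : ExprT<single> = mb.Var "Target" [nBatch; nClass]

// substitute concrete sizes and instantiate the model on the GPU
mb.SetSize nInput 784     // 28 x 28 pixels per MNIST image
mb.SetSize nClass 10      // ten digit classes
mb.SetSize nHidden 100
let mi = mb.Instantiate DevCuda

// loss expression; input and target are transposed (see note below)
let loss = MLP.loss mlp input.T target.T
```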
Note that the input and target matrices must be transposed, since the neural network model expects each sample to be a column in the matrix while the dataset provides a matrix where each row is a sample.
We instantiate the Adam optimizer to minimize the loss and use its default configuration.
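A sketch of the optimizer instantiation; the constructor arguments are assumptions for illustration.

```fsharp
// Adam optimizer minimizing the loss w.r.t. the model's parameter vector
let opt = Adam (loss, mi.ParameterVector, DevCuda)
let optCfg = opt.DefaultCfg
```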
In previous examples we wrote a simple optimization loop by hand. Here, we instead employ the generic training function provided by Deep.Net.
Defining a Trainable
The generic training function works on any object that implements the Train.ITrainable<'Smpl, 'T> interface, where 'Smpl is a sample record type (see dataset handling) and 'T is the data type of the model parameters, e.g. single.
The easiest way to create an ITrainable from a symbolic loss expression is to use the Train.trainableFromLossExpr function.
This function has the signature
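A sketch of the signature; the concrete type names (ModelInstance, ExprT, VarEnvT) are assumptions.

```fsharp
val trainableFromLossExpr : modelInstance : ModelInstance<'T>
                         -> loss          : ExprT<'T>
                         -> varEnvBuilder : ('Smpl -> VarEnvT)
                         -> optimizer     : IOptimizer<'T, 'OptCfg, 'OptState>
                         -> optCfg        : 'OptCfg
                         -> ITrainable<'Smpl, 'T>
```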
The arguments have the following meaning.
- modelInstance is the model instance containing the parameters of the model to be trained.
- loss is the loss expression to be minimized.
- varEnvBuilder is a user-provided function that takes an instance of the user-provided type 'Smpl and returns a variable environment for evaluating the loss expression on this sample (or batch of samples). The example below shows how to build a variable environment from a sample.
- optimizer is an instance of an optimizer. All optimizers in Deep.Net implement the IOptimizer interface.
- optCfg is the optimizer configuration to use. The learning rate in the specified optimizer configuration will be overwritten.
Let us build a trainable for our model. First, we need to define a function that creates a variable environment from a sample.
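A sketch of this function; the sample type MnistT with the fields Img and Lbl is an assumption based on the dataset handling chapter.

```fsharp
// build a variable environment from a (batch of) MNIST sample(s)
let smplVarEnv (smpl: MnistT) =
    VarEnv.empty
    |> VarEnv.add input smpl.Img
    |> VarEnv.add target smpl.Lbl
```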
The value of the symbolic variable input is set to the image of the MNIST sample and the symbolic variable target is set to the label in one-hot encoding.
We are now ready to construct the trainable.
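Under the assumptions above, the trainable is constructed like this.

```fsharp
let trainable =
    Train.trainableFromLossExpr mi loss smplVarEnv opt optCfg
```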
Training configuration
Next, we need to specify the training configuration using the Train.Cfg record type. For illustration purposes we write down the whole record instance; in practice you would copy Train.defaultCfg and change fields as necessary.
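A sketch of the record instance; the field names are those described below, while the concrete values are illustrative assumptions.

```fsharp
let trainCfg : Train.Cfg = {
    Seed               = 100
    BatchSize          = 10000
    LossRecordInterval = 10
    Termination        = Train.ItersWithoutImprovement 100
    MinImprovement     = 1e-7
    TargetLoss         = None
    MinIters           = Some 100
    MaxIters           = None
    LearningRates      = [1e-3; 1e-4; 1e-5]
    CheckpointDir      = None
    DiscardCheckpoint  = false
}
```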
The meaning of the fields is as follows.
- Seed is the random seed for model parameter initialization.
- BatchSize is the size of mini-batches used for training and evaluating the losses.
- LossRecordInterval is the number of iterations to perform between evaluating the loss on the validation and test sets.
- Termination is the termination criterion and can have the following values:
  - Train.ItersWithoutImprovement cnt to stop training after cnt iterations without improvement,
  - Train.IterGain gain to train for \(\mathrm{gain} \cdot \mathrm{bestIter}\) iterations, where \(\mathrm{bestIter}\) is the iteration with the best validation loss. Usually one would use \(\mathrm{gain} \approx 2.0\).
  - Train.Forever disables the termination criterion.
- MinImprovement is the minimum loss change to count as an improvement and should be a small number.
- TargetLoss can be used to specify a target validation loss that stops training when reached. Use Some loss or None.
- MinIters can specify a minimum number of training iterations to perform in the form Some iters, or None.
- MaxIters can specify a hard limit on the number of training iterations in the form Some iters, or None.
- LearningRates is a list of learning rates to use. Training starts with the first element and moves to the next one when the termination criterion (specified by the field Termination) is triggered.
- CheckpointDir may specify a checkpoint directory in the form Some dir (see the checkpointing section for details), or None.
- DiscardCheckpoint prohibits loading of an existing checkpoint if it is true.
Performing the training
Now training can be performed by calling the Train.train function.
It takes three arguments: a trainable, the dataset to use and the training configuration.
The dataset was already loaded above.
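A sketch of the call.

```fsharp
let result = Train.train trainable mnist trainCfg
```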
Training then prints a progress log that shows, for each loss evaluation, the iteration number and the losses on the training, validation and test sets.
While training is running, you can press the q key to stop training immediately and the d key to switch to the next learning rate specified in the configuration.
During training, the parameters that produce the best validation loss are saved each time the losses are evaluated (as set by the LossRecordInterval field in the training configuration).
When the validation loss does not improve for the set number of iterations (field Termination in the training configuration), the best parameters are restored and the next learning rate (field LearningRates) from the configuration is used.
This explains why the iteration count jumps back by 100 steps each time the loss stops improving.
The best validation loss is achieved around iteration 400; afterwards the model starts to overfit. Decreasing the learning rate does not help in this case, so training is terminated after the list of learning rates is exhausted.
Training result and log
The return value of Train.train is a record of type TrainingResult that contains the training results and the training log.
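A sketch of how the result might be inspected; the field names TerminationReason, Best and History are assumptions.

```fsharp
// print parts of the training result
printfn "Termination reason: %A" result.TerminationReason      // assumed field
printfn "Best log entry:     %A" result.Best                   // assumed field
printfn "Log length:         %d" (List.length result.History)  // assumed field
```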
This prints a short summary of the training run.
It is possible to save the training result as a JSON file by calling result.Save.
This is useful when you use software or scripts to gather and analyze the results of multiple experiments.
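For example (assuming Save takes the output path):

```fsharp
// save the training result as JSON; the file name is a placeholder
result.Save "result.json"
```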
Checkpointing
Checkpointing allows the training process to be interrupted and resumed later.
To enable checkpoint support, set the CheckpointDir field of the configuration record to a suitable directory. This directory has to be unique for each process.
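For example, starting from the configuration above (the directory name is a placeholder):

```fsharp
// enable checkpointing; the directory must be unique for each process
let checkpointCfg =
    { trainCfg with CheckpointDir     = Some "checkpoints/mnist-run1"
                    DiscardCheckpoint = false }
```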
When checkpoint support is enabled, the training function traps the CTRL+C and CTRL+BREAK signals. When such a signal is received, the training state (including the model parameters) is stored in the specified directory and the process is terminated with exit code 10. In this case, the training function does not return to the user code.
When the program is executed again and the training function is called, it checks for a valid checkpoint. If one is found, it is loaded and training resumes where it was interrupted.
To discard an existing checkpoint (for example, if the training configuration or model parameters were changed), set DiscardCheckpoint to true. This will delete any existing checkpoint from disk and restart training from the beginning.
Summary
With the generic training function you can train any model that has a loss expression. The main effort is to write a small wrapper function that maps a training sample to a variable environment. Various termination criteria, common in machine learning, are implemented.