Transfer learning#
One functionality that physicsml inherits from molflux[lightning] is the ability to do transfer learning. It has been shown empirically that, in many cases, training a model on a large, diverse dataset and then fine-tuning it on a small dataset leads to better performance. At its core, transfer learning improves performance by providing better weight initialisations for models to begin training from.

To help you perform transfer learning smoothly, we provide config-driven transfer learning functionality. In this section, we go through the config template and explain the function of each part.

For a tutorial on transfer learning, see the Transfer learning tutorial.
Transfer learning config#
The transfer learning config has the following form:

pre_trained_model_path
: Specifies the path to a pre-trained model.

modules_to_match
: Specifies how to match the weights.

stages
: Specifies how the transfer learning stages are run.
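Putting these together, the overall shape of the config is sketched below. The path is a placeholder, and stages is assumed here to be a list of stage configs as described later in this section:

{
    "pre_trained_model_path": "path/to/pre_trained_model",
    "modules_to_match": {...},
    "stages": [...],
}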
pre_trained_model_path#
This must be a path to a pre-trained physicsml model. Make sure that it points to a model with the same architecture
and hyperparameters as the model you are training; otherwise, matching the weights will fail.
modules_to_match#
This is a dictionary specifying how to match the weights from the pre-trained model to the new model. The dictionary has the following format:
{
    "module_name_in_pre_trained_model": "module_name_in_new_model",
}
You can also specify the names of child modules (in the format module.submodule.subsubmodule). Any
child modules of the higher-level modules you specify will also be matched. If you do not specify this dictionary (which
defaults to None), then all modules will be matched.
If the weights cannot be matched (because of a wrong module name or a mismatched parameter shape), matching will fail with an error specifying where it failed.
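For example, a modules_to_match dictionary that restricts matching to a couple of submodules might look like the sketch below. The module names here are purely illustrative; use the module names of your own model:

{
    "model.embedding": "model.embedding",
    "model.interaction_blocks.0": "model.interaction_blocks.0",
}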
stages#
Finally, stages specifies the structure of the transfer learning. In general, transfer learning is
done in consecutive stages. Each stage is essentially a separate training run whose final model weights are used to
initialise the next stage. In each of these stages, you can override any of the training kwargs for the trainer, datamodule,
optimizer, and scheduler. For each of these, any kwargs you specify will override the initial definitions of these configs,
and any unspecified kwargs will be taken from them. This allows you to control how the stages run. Importantly, you
can also specify a list of module names in freeze_modules whose weights you would like to freeze in that stage.

The complete stage config looks like:
{
    "freeze_modules": [...],
    "trainer": {},
    "datamodule": {},
    "optimizer": {},
    "scheduler": {},
}
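As an illustration, a two-stage setup that first trains with most of the model frozen and then fine-tunes everything at a lower learning rate might be sketched as below. The module name ("backbone"), the stage kwargs (max_epochs, lr), and the nesting of the optimizer kwargs are assumptions for illustration only; check your model's module names and the kwargs accepted by your trainer and optimizer configs:

"stages": [
    {
        # stage 1: keep the matched backbone frozen and train the remaining weights (module name assumed)
        "freeze_modules": ["backbone"],
        "trainer": {"max_epochs": 20},
        "optimizer": {"name": "Adam", "config": {"lr": 1e-3}},
    },
    {
        # stage 2: unfreeze everything and fine-tune with a smaller learning rate
        "freeze_modules": [],
        "trainer": {"max_epochs": 50},
        "optimizer": {"name": "Adam", "config": {"lr": 1e-4}},
    },
]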