Transfer learning#
In this tutorial, we provide an example of using the transfer learning functionality. We require the rdkit
package, so make sure to pip install 'physicsml[rdkit]'
to follow along!
We will train an EGNN model on the lumo
energy of QM9 and then transfer this model to predict the
u0
energy.
Pre-trained model#
First, let’s load a truncated QM9 dataset with 1000 datapoints
import numpy as np
from molflux.datasets import load_dataset_from_store
dataset = load_dataset_from_store("gdb9_trunc.parquet")
print(dataset)
idxs = np.random.permutation(range(len(dataset)))
dataset = dataset.select(idxs[:1000])
print(dataset)
Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom'],
num_rows: 1000
})
Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom'],
num_rows: 1000
})
Note
The dataset above is a truncated version to run efficiently in the docs. For running this locally, load the entire dataset by doing
from molflux.datasets import load_dataset
dataset = load_dataset("gdb9", "rdkit")
You can see that there is the mol_bytes
column (which is the rdkit
serialisation of the 3d molecules) and the
remaining columns of computed properties.
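If you want to look at one of the molecules, you can deserialise a mol_bytes entry back into an rdkit molecule. Below is a minimal sketch, assuming the column stores the rdkit binary format (the output of Mol.ToBinary()):
from rdkit import Chem
# rebuild an rdkit molecule from its binary serialisation
mol = Chem.Mol(dataset[0]["mol_bytes"])
print(mol.GetNumAtoms())
print(Chem.MolToSmiles(mol))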
Next, we will featurise the dataset. In this example, we start by using the atomic numbers only (since we do not require
self energies for the lumo
energy).
import logging
logging.disable(logging.CRITICAL)
from molflux.core import featurise_dataset
featurisation_metadata = {
"version": 1,
"config": [
{
"column": "mol_bytes",
"representations": [
{
"name": "physicsml_features",
"config": {
"atomic_number_mapping": {
1: 0,
6: 1,
7: 2,
8: 3,
9: 4,
},
"backend": "rdkit",
},
"as": "{feature_name}"
}
]
}
]
}
# featurise the mols
featurised_dataset = featurise_dataset(
dataset,
featurisation_metadata=featurisation_metadata,
num_proc=4,
batch_size=100,
)
print(featurised_dataset)
Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
num_rows: 1000
})
You can see that we now have the extra columns:
- physicsml_atom_idxs: the indices of the atoms in the molecules
- physicsml_atom_numbers: the atomic numbers mapped using the mapping dictionary
- physicsml_coordinates: the coordinates of the atoms
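To get a feel for these features, you can inspect a single featurised datapoint. This is just a sanity check; the exact values depend on the molecule:
example = featurised_dataset[0]
print(example["physicsml_atom_idxs"])      # e.g. [0, 1, 2, ...]
print(example["physicsml_atom_numbers"])   # mapped atomic numbers, e.g. [1, 0, 0, ...]
print(example["physicsml_coordinates"])    # one (x, y, z) coordinate per atom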
Next, we need to split the dataset. For this, we use the simple shuffle_split
strategy (a random split) with 80% training and
20% test, using the split_dataset
function from molflux.datasets.
from molflux.datasets import split_dataset
from molflux.splits import load_from_dict as load_split_from_dict
shuffle_strategy = load_split_from_dict(
{
"name": "shuffle_split",
"presets": {
"train_fraction": 0.8,
"validation_fraction": 0.0,
"test_fraction": 0.2,
}
}
)
split_featurised_dataset = next(split_dataset(featurised_dataset, shuffle_strategy))
print(split_featurised_dataset)
DatasetDict({
train: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
num_rows: 800
})
validation: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
num_rows: 0
})
test: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
num_rows: 200
})
})
We can now turn to pre-training the model! To do so, we need to define the model config, the x_features,
and the y_features.
Once trained, we will save the model so that it can be used for transfer learning.
import json
from molflux.modelzoo import load_from_dict as load_model_from_dict
from molflux.core import save_model
model = load_model_from_dict(
{
"name": "egnn_model",
"config": {
"x_features": [
'physicsml_atom_idxs',
'physicsml_atom_numbers',
'physicsml_coordinates',
],
"y_features": ['lumo'],
"num_node_feats": 5,
"num_layers": 2,
"c_hidden": 12,
"y_graph_scalars_loss_config": {
"name": "MSELoss",
},
"optimizer": {
"name": "AdamW",
"config": {
"lr": 1e-3,
}
},
"datamodule": {
"y_graph_scalars": ['lumo'],
"num_elements": 5,
"cut_off": 5.0,
"pre_batch": "in_memory",
"train": {"batch_size": 64},
"validation": {"batch_size": 128},
},
"trainer": {
"max_epochs": 10,
"accelerator": "cpu",
"logger": False,
}
}
}
)
model.train(
train_data=split_featurised_dataset["train"]
)
save_model(model, "pre_trained_model", featurisation_metadata)
'pre_trained_model'
We now have a dummy pre-trained model.
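As an optional sanity check, you can generate predictions from the pre-trained model on the test split. A minimal sketch using the in-memory model (the key of the returned prediction dict is generated by molflux, so we simply take the first value):
preds = model.predict(split_featurised_dataset["test"])
# predict returns a dict of prediction columns; take the first (and only) one
lumo_preds = next(iter(preds.values()))
print(lumo_preds[:5])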
Transfer learning#
Finally, we come to the transfer learning. First, we need to re-featurise the dataset to include the atomic self energies
for 'u0',
and then split it for training (we ignore the fact that this is the same dataset used for pre-training; it is only
for demonstration).
import logging
logging.disable(logging.CRITICAL)
from molflux.core import featurise_dataset
from molflux.datasets import split_dataset
from molflux.splits import load_from_dict as load_split_from_dict
featurisation_metadata = {
"version": 1,
"config": [
{
"column": "mol_bytes",
"representations": [
{
"name": "physicsml_features",
"config": {
"atomic_number_mapping": {
1: 0,
6: 1,
7: 2,
8: 3,
9: 4,
},
"atomic_energies": {
1: -0.6019805629746086,
6: -38.07749583990695,
7: -54.75225433326539,
8: -75.22521603087064,
9: -99.85134426752529
},
"backend": "rdkit",
},
"as": "{feature_name}"
}
]
}
]
}
# featurise the mols
featurised_dataset = featurise_dataset(
dataset,
featurisation_metadata=featurisation_metadata,
num_proc=4,
batch_size=100,
)
shuffle_strategy = load_split_from_dict(
{
"name": "shuffle_split",
"presets": {
"train_fraction": 0.8,
"validation_fraction": 0.0,
"test_fraction": 0.2,
}
}
)
split_featurised_dataset = next(split_dataset(featurised_dataset, shuffle_strategy))
print(split_featurised_dataset)
DatasetDict({
train: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
num_rows: 800
})
validation: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
num_rows: 0
})
test: Dataset({
features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
num_rows: 200
})
})
Now we need to define the model config. This is almost the same as the pre-trained model config, but with u0
substituted for lumo,
physicsml_total_atomic_energy added to the x_features, and the transfer learning config included. For more information about the
transfer learning config, see here.
model_config = {
"name": "egnn_model",
"config": {
"x_features": [
'physicsml_atom_idxs',
'physicsml_atom_numbers',
'physicsml_coordinates',
'physicsml_total_atomic_energy',
],
"y_features": ['u0'],
"num_node_feats": 5,
"num_layers": 2,
"c_hidden": 12,
"y_graph_scalars_loss_config": {
"name": "MSELoss",
},
"optimizer": {
"name": "AdamW",
"config": {
"lr": 1e-3,
}
},
"datamodule": {
"y_graph_scalars": ['u0'],
"num_elements": 5,
"cut_off": 5.0,
"pre_batch": "in_memory",
"train": {"batch_size": 64},
"validation": {"batch_size": 128},
},
"trainer": {
"max_epochs": 10,
"accelerator": "cpu",
"logger": False,
},
"transfer_learning": {
"pre_trained_model_path": "pre_trained_model",
"modules_to_match": {
"egnn": "egnn",
},
"stages": [
{
"freeze_modules": ["egnn"],
"datamodule": {
"train": {"batch_size": 128},
},
"optimizer": {
"config": {
"lr": 1e-2
}
}
},
{
"trainer": {
"max_epochs": 4,
},
"optimizer": {
"config": {
"lr": 1e-4
}
}
}
]
}
}
}
As you can see, we first specify the path to the pre-trained model. The EGNN model contains two main submodules: the
egnn
backbone (message passing) and the pooling_head
(for pooling and generating predictions). In this
example, we choose to match only the backbone (since the lumo
and u0
tasks are different and the pooling_head
will not contain any useful learnt information).
Next, we specify a two-stage transfer learning procedure. In the first stage, we freeze the egnn
backbone and train the
pooling_head. Additionally, we override some kwargs (such as the batch size and learning rate). Notice that you
only need to specify the kwargs you want to override; everything else is taken from the
main config. In the second stage, we train the whole model (no frozen modules) at a lower learning rate for fewer epochs.
So let’s run the training!
import logging
logging.disable(logging.NOTSET)
from molflux.modelzoo import load_from_dict as load_model_from_dict
model = load_model_from_dict(model_config)
model.train(
train_data=split_featurised_dataset["train"]
)
WARNING:molflux.modelzoo.models.lightning.model:Matched module 'egnn' in new module to 'egnn' in old module.
WARNING:molflux.modelzoo.models.lightning.model:Freezing module: egnn.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.
| Name | Type | Params | Mode
-----------------------------------------------------
0 | egnn | EGNN | 3.5 K | train
1 | pooling_head | PoolingHead | 481 | train
-----------------------------------------------------
481 Trainable params
3.5 K Non-trainable params
4.0 K Total params
0.016 Total estimated model params size (MB)
63 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
| Name | Type | Params | Mode
-----------------------------------------------------
0 | egnn | EGNN | 3.5 K | train
1 | pooling_head | PoolingHead | 481 | train
-----------------------------------------------------
4.0 K Trainable params
0 Non-trainable params
4.0 K Total params
0.016 Total estimated model params size (MB)
63 Modules in train mode
0 Modules in eval mode
`Trainer.fit` stopped: `max_epochs=4` reached.
You can see the logs from the transfer learning above. First, the egnn
module is matched. Then the first stage
starts: it freezes the egnn
module (you can see the reduced number of trainable parameters) and trains for 10 epochs. Then the
second stage starts and trains all of the parameters for 4 epochs.
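As a final, optional step, you can evaluate the transferred model on the test split. A minimal sketch computing the mean absolute error with numpy (again taking the first value of the prediction dict returned by molflux):
import numpy as np
preds = model.predict(split_featurised_dataset["test"])
u0_preds = np.array(next(iter(preds.values())))
u0_refs = np.array(split_featurised_dataset["test"]["u0"])
# mean absolute error of the transferred model on the held-out test set
print("u0 MAE:", np.abs(u0_preds - u0_refs).mean())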
And that’s it!