Transfer learning

In this tutorial, we provide an example of using the transfer learning functionality. We require the rdkit package, so make sure to pip install 'physicsml[rdkit]' to follow along!

We will train an EGNN model on the lumo energy of QM9 and then transfer this model to predict the u0 energy.

Pre-trained model

First, let’s load a truncated QM9 dataset with 1000 datapoints.

import numpy as np
from molflux.datasets import load_dataset_from_store

dataset = load_dataset_from_store("gdb9_trunc.parquet")

print(dataset)

idxs = np.random.permutation(range(len(dataset)))
dataset = dataset.select(idxs[:1000])

print(dataset)
Dataset({
    features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom'],
    num_rows: 1000
})
Dataset({
    features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom'],
    num_rows: 1000
})

Note

The dataset above is a truncated version so that it runs efficiently in the docs. To run this locally, load the entire dataset by doing

from molflux.datasets import load_dataset

dataset = load_dataset("gdb9", "rdkit")

You can see that there is the mol_bytes column (which is the rdkit serialisation of the 3D molecules) and that the remaining columns are computed properties.
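
For example, you can deserialise one of the mol_bytes entries to inspect the underlying molecule. This is a minimal sketch, assuming the column stores the RDKit binary representation (which rdkit.Chem.Mol can load directly):

from rdkit import Chem

# deserialise the binary RDKit mol stored in the first row (assumes RDKit's ToBinary format)
mol = Chem.Mol(dataset[0]["mol_bytes"])
print(mol.GetNumAtoms(), mol.GetNumConformers())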

Next, we will featurise the dataset. In this example, we start by using the atomic numbers only (since we do not require self energies for the lumo energy).

import logging
logging.disable(logging.CRITICAL)

from molflux.core import featurise_dataset

featurisation_metadata = {
    "version": 1,
    "config": [
        {
            "column": "mol_bytes",
            "representations": [
                {
                    "name": "physicsml_features",
                    "config": {
                        "atomic_number_mapping": {
                            1: 0,
                            6: 1,
                            7: 2,
                            8: 3,
                            9: 4,
                        },
                        "backend": "rdkit",
                    },
                    "as": "{feature_name}"
                }
            ]
        }
    ]
}

# featurise the mols
featurised_dataset = featurise_dataset(
    dataset,
    featurisation_metadata=featurisation_metadata,
    num_proc=4,
    batch_size=100,
)

print(featurised_dataset)
Dataset({
    features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
    num_rows: 1000
})

You can see that we now have three extra columns (see the quick inspection sketch after this list):

  • physicsml_atom_idxs: The indices of the atoms in the molecules

  • physicsml_atom_numbers: The atomic numbers mapped using the mapping dictionary

  • physicsml_coordinates: The coordinates of the atoms
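
As a quick sanity check (a minimal sketch; the column names are taken from the featurised output above), you can inspect these features for a single molecule:

row = featurised_dataset[0]

print(row["physicsml_atom_numbers"])        # mapped atomic numbers, values in 0..4
print(len(row["physicsml_coordinates"]))    # one (x, y, z) triplet per atom
print(row["physicsml_atom_idxs"])           # indices of the atoms used for featurisation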

Next, we need to split the dataset. For this, we use the simple shuffle_split (random split) with 80% training and 20% test. To split the dataset, we use the split_dataset function from molflux.datasets.

from molflux.datasets import split_dataset
from molflux.splits import load_from_dict as load_split_from_dict

shuffle_strategy = load_split_from_dict(
    {
        "name": "shuffle_split",
        "presets": {
            "train_fraction": 0.8,
            "validation_fraction": 0.0,
            "test_fraction": 0.2,
        }
    }
)

split_featurised_dataset = next(split_dataset(featurised_dataset, shuffle_strategy))

print(split_featurised_dataset)
DatasetDict({
    train: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
        num_rows: 800
    })
    validation: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
        num_rows: 0
    })
    test: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates'],
        num_rows: 200
    })
})

We can now turn to pre-training the model! To do so, we need to define the model config, the x_features, and the y_features. Once trained, we will save the model so that it can be used for transferring.

from molflux.modelzoo import load_from_dict as load_model_from_dict
from molflux.core import save_model

model = load_model_from_dict(
    {
        "name": "egnn_model",
        "config": {
            "x_features": [
                'physicsml_atom_idxs',
                'physicsml_atom_numbers',
                'physicsml_coordinates',
            ],
            "y_features": ['lumo'],
            "num_node_feats": 5,
            "num_layers": 2,
            "c_hidden": 12,
            "y_graph_scalars_loss_config": {
                "name": "MSELoss",
            },
            "optimizer": {
                "name": "AdamW",
                "config": {
                    "lr": 1e-3,
                }
            },
            "datamodule": {
                "y_graph_scalars": ['lumo'],
                "num_elements": 5,
                "cut_off": 5.0,
                "pre_batch": "in_memory",
                "train": {"batch_size": 64},
                "validation": {"batch_size": 128},
            },
            "trainer": {
                "max_epochs": 10,
                "accelerator": "cpu",
                "logger": False,
            }
        }
    }
)

model.train(
    train_data=split_featurised_dataset["train"]
)

save_model(model, "pre_trained_model", featurisation_metadata)
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/torch/jit/_trace.py:763: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  traced = torch._C._create_function_from_trace(
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/physicsml/lightning/pre_batching_in_memory.py:53: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:105: Total length of `dict` across ranks is zero. Please make sure this was your intention.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/torch/storage.py:414: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(io.BytesIO(b))
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/core/module.py:516: You called `self.log('train/total/y_graph_scalars', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/core/module.py:516: You called `self.log('train/total/loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
'pre_trained_model'

We now have a dummy pre-trained model.
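
Before moving on, you could sanity-check it by predicting on the held-out test set. This is a minimal sketch; model.predict returns a dictionary of prediction columns, so we grab the first one without assuming its exact name:

preds = model.predict(split_featurised_dataset["test"])

# first few 'lumo' predictions from the pre-trained model
lumo_preds = next(iter(preds.values()))
print(lumo_preds[:5])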

Transfer learning

Finally, we come to the transfer learning itself. First, we need to re-featurise the dataset to include the atomic self energies required for 'u0', and then split it for training (we ignore the fact that this is the same dataset used for pre-training; this is only for demonstration).

import logging
logging.disable(logging.CRITICAL)

from molflux.core import featurise_dataset
from molflux.datasets import split_dataset
from molflux.splits import load_from_dict as load_split_from_dict

featurisation_metadata = {
    "version": 1,
    "config": [
        {
            "column": "mol_bytes",
            "representations": [
                {
                    "name": "physicsml_features",
                    "config": {
                        "atomic_number_mapping": {
                            1: 0,
                            6: 1,
                            7: 2,
                            8: 3,
                            9: 4,
                        },
                        "atomic_energies": {
                            1: -0.6019805629746086,
                            6: -38.07749583990695,
                            7: -54.75225433326539,
                            8: -75.22521603087064,
                            9: -99.85134426752529
                        },
                        "backend": "rdkit",
                    },
                    "as": "{feature_name}"
                }
            ]
        }
    ]
}

# featurise the mols
featurised_dataset = featurise_dataset(
    dataset,
    featurisation_metadata=featurisation_metadata,
    num_proc=4,
    batch_size=100,
)

shuffle_strategy = load_split_from_dict(
    {
        "name": "shuffle_split",
        "presets": {
            "train_fraction": 0.8,
            "validation_fraction": 0.0,
            "test_fraction": 0.2,
        }
    }
)

split_featurised_dataset = next(split_dataset(featurised_dataset, shuffle_strategy))

print(split_featurised_dataset)
DatasetDict({
    train: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
        num_rows: 800
    })
    validation: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
        num_rows: 0
    })
    test: Dataset({
        features: ['mol_bytes', 'mol_id', 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u298', 'h298', 'g298', 'cv', 'u0_atom', 'u298_atom', 'h298_atom', 'g298_atom', 'physicsml_atom_idxs', 'physicsml_atom_numbers', 'physicsml_coordinates', 'physicsml_total_atomic_energy'],
        num_rows: 200
    })
})

Now we need to define the model config. This is exactly the same as the pre-trained model config, but with lumo replaced by u0, with the new physicsml_total_atomic_energy feature added to the x_features, and with the addition of the transfer learning config. For more information about the transfer learning config, see here.

model_config = {
    "name": "egnn_model",
    "config": {
        "x_features": [
            'physicsml_atom_idxs',
            'physicsml_atom_numbers',
            'physicsml_coordinates',
            'physicsml_total_atomic_energy',
        ],
        "y_features": ['u0'],
        "num_node_feats": 5,
        "num_layers": 2,
        "c_hidden": 12,
        "y_graph_scalars_loss_config": {
            "name": "MSELoss",
        },
        "optimizer": {
            "name": "AdamW",
            "config": {
                "lr": 1e-3,
            }
        },
        "datamodule": {
            "y_graph_scalars": ['u0'],
            "num_elements": 5,
            "cut_off": 5.0,
            "pre_batch": "in_memory",
            "train": {"batch_size": 64},
            "validation": {"batch_size": 128},
        },
        "trainer": {
            "max_epochs": 10,
            "accelerator": "cpu",
            "logger": False,
        },
        "transfer_learning": {
            "pre_trained_model_path": "pre_trained_model",
            "modules_to_match": {
                "egnn": "egnn",
            },
            "stages": [
                {
                    "freeze_modules": ["egnn"],
                    "datamodule": {
                        "train": {"batch_size": 128},
                    },
                    "optimizer": {
                        "config": {
                            "lr": 1e-2
                        }
                    }
                },
                {
                    "trainer": {
                        "max_epochs": 4,
                    },
                    "optimizer": {
                        "config": {
                            "lr": 1e-4
                        }
                    }
                }
            ]
        }
    }
}

As you can see, we first specify the path to the pre-trained model. The EGNN model contains two main submodules: the egnn backbone (message passing) and the pooling_head (for pooling and generating predictions). In this example, we choose to match only the backbone (since the lumo and u0 tasks are different and the pooling_head will not contain any useful learnt information).

Next, we specify a two-stage transfer learning procedure. In the first stage, we freeze the egnn backbone and train only the pooling_head. Additionally, we override some kwargs (such as the batch size and learning rate). Notice that you only need to specify the kwargs you want to override; everything else is taken from the main config. In the second stage, we train the whole model (no frozen modules) at a lower learning rate for fewer epochs.
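
As an illustration, a hypothetical extra stage that only lowers the learning rate would look like this (any kwarg not listed, such as the number of epochs, the batch sizes, or the frozen modules, falls back to the main config). We do not actually add it here:

# hypothetical extra stage: only the learning rate is overridden,
# everything else is inherited from the main model config
extra_stage = {
    "optimizer": {
        "config": {
            "lr": 1e-5,
        }
    }
}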

So let’s run the training!

import logging
logging.disable(logging.NOTSET)

from molflux.modelzoo import load_from_dict as load_model_from_dict
model = load_model_from_dict(model_config)

model.train(
    train_data=split_featurised_dataset["train"]
)
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/molflux/modelzoo/models/lightning/model.py:369: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
WARNING:molflux.modelzoo.models.lightning.model:Matched module 'egnn' in new module to 'egnn' in old module.
WARNING:molflux.modelzoo.models.lightning.model:Freezing module: egnn.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/physicsml/lightning/pre_batching_in_memory.py:53: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

  | Name         | Type        | Params | Mode 
-----------------------------------------------------
0 | egnn         | EGNN        | 3.5 K  | train
1 | pooling_head | PoolingHead | 481    | train
-----------------------------------------------------
481       Trainable params
3.5 K     Non-trainable params
4.0 K     Total params
0.016     Total estimated model params size (MB)
63        Modules in train mode
0         Modules in eval mode
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/utilities/data.py:105: Total length of `dict` across ranks is zero. Please make sure this was your intention.
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/torch/storage.py:414: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(io.BytesIO(b))
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/core/module.py:516: You called `self.log('train/total/y_graph_scalars', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
/home/runner/work/physicsml/physicsml/.cache/nox/docs_build-3-11/lib/python3.11/site-packages/lightning/pytorch/core/module.py:516: You called `self.log('train/total/loss', ..., logger=True)` but have no logger configured. You can enable one by doing `Trainer(logger=ALogger(...))`
`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
  | Name         | Type        | Params | Mode 
-----------------------------------------------------
0 | egnn         | EGNN        | 3.5 K  | train
1 | pooling_head | PoolingHead | 481    | train
-----------------------------------------------------
4.0 K     Trainable params
0         Non-trainable params
4.0 K     Total params
0.016     Total estimated model params size (MB)
63        Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=4` reached.

You can see the logs from the transfer learning above. First, the egnn module is matched. The first stage then freezes the egnn module (you can see the reduced number of trainable parameters) and trains for 10 epochs. The second stage then trains all the parameters for 4 epochs.
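
To see how the transferred model performs, you could compute regression metrics on the test split. This is a minimal sketch using molflux.metrics; as before, we grab the first prediction column without assuming its exact name:

from molflux.metrics import load_suite

# predict 'u0' on the held-out test set
preds = model.predict(split_featurised_dataset["test"])
u0_preds = next(iter(preds.values()))

# score the predictions against the reference 'u0' values
regression_suite = load_suite("regression")
scores = regression_suite.compute(
    references=split_featurised_dataset["test"]["u0"],
    predictions=u0_preds,
)
print(scores)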

And that’s it!