MolFlux layer#

The molflux layer handles all the interaction with physicsml models. All functionality for loading, saving, training, and inference is handled through its standard API (this does not include interfaces for plugins such as OpenMM; see plugins for that).

For a complete overview of the molflux API, see its documentation. Here, we will discuss the relevant parts for the physicsml package.

The overall pipeline of any model build can be divided into five stages: accessing datasets, featurising datasets, splitting datasets, model building, and computing metrics.

Accessing datasets#

Datasets can be accessed via the molflux.datasets module. There are built-in datasets that come with molflux by default (see what is available here). You can also load preprocessed datasets from disk and remote storage (such as s3). If you would like to share your datasets with a wider audience, you can register them in the molflux.datasets module to make them available (see the molflux documentation on how to do that).

Featurising datasets#

To stay in line with the modular principles of molflux and physicsml, we separate the featurisation of datasets (e.g. extracting the atomic numbers and coordinates) into its own step. This not only reduces redundant computation but also streamlines adding and expanding the available features.

For more information about featurisation, see PhysicsML features.
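As a sketch of what featurisation metadata can look like (the column name, feature name, and config values below are illustrative assumptions, not a definitive schema; see PhysicsML features for the real options):

```python
# Illustrative featurisation metadata: featurise a "mol_bytes" column
# with physicsml features. All names and values below are assumptions.
featurisation_metadata = {
    "version": 1,
    "config": [
        {
            "column": "mol_bytes",
            "representations": [
                {
                    "name": "physicsml_features",
                    "config": {
                        "atomic_number_mapping": {1: 0, 6: 1, 7: 2, 8: 3},
                    },
                },
            ],
        },
    ],
}
```

Keeping featurisation declarative like this is what later allows the same metadata to be persisted alongside the model for inference.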

Splitting datasets#

The next stage is splitting the datasets for model evaluation and benchmarking. This is done via the molflux.splits module (see docs here). If you already have pre-specified splits for your datasets, then you do not have to worry about this.
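For illustration, a splitting strategy can be specified declaratively as a config dict (the strategy name and preset keys here are assumptions; see the molflux.splits docs for the available strategies):

```python
# Illustrative split config: a random 80/10/10 shuffle split.
# Strategy name and preset keys are assumptions, not a definitive schema.
split_config = {
    "name": "shuffle_split",
    "presets": {
        "train_fraction": 0.8,
        "validation_fraction": 0.1,
        "test_fraction": 0.1,
    },
}
```

Such a config can then be loaded into a splitting strategy and applied to the featurised dataset to produce train, validation, and test subsets.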

Model building#

The main part of the pipeline is model building. We divide this section into model loading, model training, and model saving.

Model loading#

Models are best loaded by specifying a model config and using the molflux.modelzoo.load_from_dict function. The model config has this generic form

from molflux.modelzoo import load_from_dict

model_config = {
    "name": <name of model>,
    "config": {
        "x_features": <list of x features generated from featurisation>,
        "y_features": <list of y features to train on>,
        ...,  # a bunch of model-specific kwargs
        "optimizer": ...,  # the optimizer config
        "scheduler": ...,  # the scheduler config
        "datamodule": ...,  # a bunch of lightning specific kwargs to control the datamodule
        "trainer": ...,  # a bunch of lightning specific kwargs to control the training
        "transfer_learning": ...,  # a bunch of kwargs to control transfer learning
    }
}

model = load_from_dict(model_config)

This loads the model object (a molflux model that can handle the training and inference routines).
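For concreteness, a filled-in version of the generic config above might look as follows. The model name, feature names, and kwarg values are illustrative assumptions only; consult the physicsml model catalogue for the real options:

```python
# An illustrative, filled-in model config. All names and values
# below are assumptions; check the physicsml docs for real options.
model_config = {
    "name": "egnn_model",
    "config": {
        "x_features": [
            "physicsml_atom_idxs",
            "physicsml_atom_numbers",
            "physicsml_coordinates",
        ],
        "y_features": ["energy"],
        "optimizer": {"name": "AdamW", "config": {"lr": 1e-3}},
        "datamodule": {"train": {"batch_size": 64}},
        "trainer": {"max_epochs": 50, "accelerator": "auto"},
    },
}
```

A dict like this would then be passed to load_from_dict as shown above.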

Model training#

Once the model is loaded, training is as simple as

model.train(
    train_data=training_dataset,
    validation_data=validation_dataset,
)

The validation_data is optional (used for early stopping and monitoring); models can in principle be trained on train_data alone.

Under the hood, this sets up a number of objects from Lightning to run the training. First, it instantiates the torch module (i.e. the actual model code), the datamodule, and the Trainer. Then, it passes the module and the datamodule to the Trainer to run the training. For more info, see Lightning layer.

Note

If a transfer learning config is specified, the training routine is a bit more involved. See transfer learning for more info.

Model saving#

Once the model is trained, we can simply save it by doing

from molflux.core import save_model

save_model(model, "path_to_model", featurisation_metadata)

It is important to persist the featurisation_metadata so that the model can be used for inference later on. Saving the model creates a directory with the model config, featurisation config, model artefacts (weights checkpoint), and a frozen requirements file to recreate the environment it was trained in. For more information, see the molflux docs.

Computing metrics#

Once the model is trained, we can compute some metrics. This is done by generating predictions on a held-out dataset and scoring them with the metrics functionality of molflux. See here for more info.