Add your own model architecture
While physicsml ships with a large catalogue of model architectures, you can also add your own
state-of-the-art architecture and make it available to all users of the package. This guide gives
an overview of how to do this.
There are three main components to adding a new model: the model config, the molflux model API handler, and the
Lightning (torch) module code.
The model config
The most basic part of adding a new model is defining the model config. The general model config looks like this:
from typing import Dict, Literal, Optional

from pydantic.v1.dataclasses import dataclass

from physicsml.lightning.config import ConfigDict, PhysicsMLModelConfig


@dataclass(config=ConfigDict)
class MyModelConfig(PhysicsMLModelConfig):
    # all model kwargs, typed with Dict, Literal, Optional, etc. as needed
    ...
The config is a pydantic dataclass, which takes care of validating the inputs. It inherits from PhysicsMLModelConfig,
which provides the shared kwargs (such as the datamodule config and the optimizer config). Here, you can specify
whatever your model requires for training and inference. A good simple example to follow is the EGNN model config.
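As a rough, framework-free analogue of this pattern (shared kwargs on a base config, model-specific kwargs on a subclass, validation at construction time), the sketch below uses standard-library dataclasses. SharedConfig is a hypothetical stand-in for PhysicsMLModelConfig, and its fields, like num_layers and c_hidden on the subclass, are illustrative only and are not physicsml's actual kwargs. A pydantic dataclass validates types automatically; here a manual __post_init__ check plays that role.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class SharedConfig:
    # hypothetical stand-in for the shared kwargs a base config provides
    optimizer_name: str = "adam"
    optimizer_kwargs: Dict[str, float] = field(default_factory=lambda: {"lr": 1e-3})


@dataclass
class MyModelConfig(SharedConfig):
    # hypothetical model-specific kwargs
    num_layers: int = 4
    c_hidden: int = 128

    def __post_init__(self) -> None:
        # manual validation; a pydantic dataclass would check types for you
        if self.num_layers < 1:
            raise ValueError("num_layers must be at least 1")
        if self.c_hidden < 1:
            raise ValueError("c_hidden must be at least 1")
```

MyModelConfig(num_layers=2) keeps the inherited defaults and overrides only the model-specific field, while MyModelConfig(num_layers=0) raises a ValueError.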
The molflux model
The next part is the molflux model wrapper. This is responsible for all the training, inference, loading,
and saving functionality. If your model is a generic GNN, then this is a simple wrapper class:
from typing import Any, Type

from molflux.modelzoo.info import ModelInfo

from physicsml.lightning.model import PhysicsMLModelBase
from physicsml.models.my_model.supervised.default_configs import MyModelConfig
from physicsml.models.my_model.supervised.my_model_module import MyModelModule


class MyModel(PhysicsMLModelBase[MyModelConfig]):
    def _info(self) -> ModelInfo:
        # descriptions of the model and of its config
        return ModelInfo(
            model_description="my model description",
            config_description="config description",
        )

    @property
    def _config_builder(self) -> Type[MyModelConfig]:
        # the config class itself (not an instance); used internally to build the config
        return MyModelConfig

    def _instantiate_module(self) -> Any:
        # build the Lightning module from the validated config
        return MyModelModule(
            model_config=self.model_config,
        )
The class inherits from PhysicsMLModelBase[MyModelConfig]. The config class in the square brackets is a generic type
parameter, used for typing purposes.
You need to specify the ModelInfo (which describes the model and its config), the _config_builder property, which
returns the config class (used internally to instantiate the config), and the _instantiate_module method, which returns
an initialised Lightning module (more on that below). For a simple example, see the EGNN model.
If your model requires a specialised dataset and dataloader, you can override _datamodule_builder to return a
datamodule class specific to your model. For an example, see the ANI model.
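To see why only these hooks need filling in, the sketch below mimics the pattern in plain Python: a generic base class, parametrised by a config type, calls the subclass's _config_builder, _instantiate_module and _datamodule_builder hooks. Everything here (ModelBase, DefaultDataModule, MyDataModule, setup, and the num_layers field) is a hypothetical stand-in, not physicsml's implementation; the real base classes do much more, and their hook signatures may differ.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, Tuple, Type, TypeVar

# type variable for the config class given in square brackets
ConfigT = TypeVar("ConfigT")


class DefaultDataModule:
    """Hypothetical default datamodule."""

    def __init__(self, model_config: Any) -> None:
        self.model_config = model_config


class ModelBase(ABC, Generic[ConfigT]):
    """Simplified stand-in for a wrapper base class (not physicsml's code)."""

    def __init__(self, **config_kwargs: Any) -> None:
        # build the config from whichever class the subclass declares
        self.model_config: ConfigT = self._config_builder(**config_kwargs)

    @property
    @abstractmethod
    def _config_builder(self) -> Type[ConfigT]:
        """Return the config class itself (not an instance)."""

    @abstractmethod
    def _instantiate_module(self) -> Any:
        """Return a module built from self.model_config."""

    @property
    def _datamodule_builder(self) -> Type[DefaultDataModule]:
        # default hook; override only if the model needs special data handling
        return DefaultDataModule

    def setup(self) -> Tuple[Any, DefaultDataModule]:
        # the base class drives the hooks; subclasses only fill them in
        module = self._instantiate_module()
        datamodule = self._datamodule_builder(self.model_config)
        return module, datamodule


@dataclass
class MyModelConfig:
    num_layers: int = 4  # hypothetical model kwarg


class MyModelModule:
    def __init__(self, model_config: MyModelConfig) -> None:
        self.model_config = model_config


class MyDataModule(DefaultDataModule):
    """Hypothetical datamodule specialised for MyModel."""


class MyModel(ModelBase[MyModelConfig]):
    @property
    def _config_builder(self) -> Type[MyModelConfig]:
        return MyModelConfig

    def _instantiate_module(self) -> MyModelModule:
        return MyModelModule(model_config=self.model_config)

    @property
    def _datamodule_builder(self) -> Type[DefaultDataModule]:
        # override: this model needs its own datamodule
        return MyDataModule
```

Because MyModel subclasses ModelBase[MyModelConfig], a type checker knows that self.model_config is a MyModelConfig; the square brackets do not enforce anything at runtime.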
Note
Uncertainty models need to inherit from PhysicsMLUncertaintyModelBase (in physicsml.lightning.model_uncertainty).
This class handles the additional API functionality for uncertainty models. For an example, see the
MeanVarEGNNModel.
Make your model discoverable
To make your model discoverable in physicsml and molflux, you need to register it as a plugin. You can do this in
the pyproject.toml file of your repo by adding an entry for your model class under
[project.entry-points.'molflux.modelzoo.plugins.physicsml'] as follows:
[project.entry-points.'molflux.modelzoo.plugins.physicsml']
name_of_model = 'path.to.module.file:YourModelName'
Note
Alternatively, you can do this in the setup.cfg file of your repo by adding an entry for your model class under
[options.entry_points] as follows:
[options.entry_points]
molflux.modelzoo.plugins.physicsml =
    name_of_model = path.to.module.file:YourModelName
This entry point allows molflux.modelzoo to hook into your model and automatically register it in the catalogue.
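Under the hood, an entry point is a name in a group that points to a module:attribute string, which a consumer imports lazily. The sketch below uses only the standard library's importlib.metadata to list the names registered in a group, and shows with a hypothetical helper, load_target, how such a string can be resolved to an object. It illustrates the mechanism only; it is not how molflux itself performs the lookup.

```python
import importlib
from functools import reduce
from importlib.metadata import entry_points
from typing import Any, List

PLUGIN_GROUP = "molflux.modelzoo.plugins.physicsml"


def registered_names(group: str = PLUGIN_GROUP) -> List[str]:
    """Return the sorted names of all entry points installed in `group`."""
    eps = entry_points()
    if hasattr(eps, "select"):  # Python 3.10+ selection API
        selected = eps.select(group=group)
    else:  # Python 3.8/3.9 return a dict keyed by group
        selected = eps.get(group, [])
    return sorted(ep.name for ep in selected)


def load_target(value: str) -> Any:
    """Resolve a 'path.to.module:Attr' string (hypothetical helper)."""
    module_name, _, attr_path = value.partition(":")
    obj = importlib.import_module(module_name)
    # walk dotted attributes such as 'Outer.Inner'
    return reduce(getattr, attr_path.split("."), obj) if attr_path else obj
```

After installing your package (for example with pip install -e .), the name you registered should appear in registered_names().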
The lightning module
The lightning module is where the core of the model code lives. The general class looks like this:
from typing import Any, Dict

import torch

from physicsml.lightning.module import PhysicsMLModuleBase
from physicsml.models.my_model.supervised.default_configs import MyModelConfig


class MyModelModule(PhysicsMLModuleBase):
    model_config: MyModelConfig

    def __init__(
        self,
        model_config: MyModelConfig,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config=model_config)
        # configure your module code (torch layers) here

    def forward(
        self,
        data: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # model operations on the data; fill in the expected output keys
        # (e.g. "y_graph_scalars", "y_node_vector")
        output: Dict[str, torch.Tensor] = {}
        return output

    def compute_loss(self, input: Any, target: Any) -> torch.Tensor:
        # compute a loss from the data batch and the model output
        raise NotImplementedError("return your loss tensor here")
If your model is a generic GNN, then all you have to do is specify your module code (pure torch code), the forward
pass, and the compute_loss method. The forward pass must return a dictionary that includes all the expected outputs
of the model (such as y_graph_scalars, y_node_vector, etc.). The compute_loss method must take in the input
data batch and the output of the model and return a loss. For an example, see the EGNN Module.
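The contract between forward and compute_loss can be illustrated without torch. In the sketch below, plain lists of floats stand in for tensors. The mapping TARGET_KEYS, the batch keys graph_scalars and node_scalars, the output key y_node_scalars, and the choice of summing a per-output mean squared error are all hypothetical, and physicsml's own losses may differ; only y_graph_scalars is an output name taken from the text above.

```python
from typing import Dict, List

# hypothetical mapping from model-output keys to the batch keys holding targets
TARGET_KEYS: Dict[str, str] = {
    "y_graph_scalars": "graph_scalars",
    "y_node_scalars": "node_scalars",
}


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error between two equal-length lists."""
    if len(pred) != len(target):
        raise ValueError("prediction and target lengths differ")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def compute_loss(batch: Dict[str, List[float]], output: Dict[str, List[float]]) -> float:
    """Sum the MSE over every model output that has a matching target in the batch."""
    loss = 0.0
    for out_key, target_key in TARGET_KEYS.items():
        if out_key in output and target_key in batch:
            loss += mse(output[out_key], batch[target_key])
    return loss
```

Outputs without a matching target are simply skipped, so a model that predicts only y_graph_scalars still gets a well-defined loss.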
Notice that all the Lightning boilerplate is handled by the inherited PhysicsMLModuleBase.
If your model requires specialised training steps (or if you’d like more control over those), then you can directly override them as in the ANI Module.
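The division of labour between a base class and your subclass, and what overriding a training step gives you, can be sketched in plain Python. ModuleBase here is a hypothetical, heavily simplified stand-in for PhysicsMLModuleBase (the real class builds on Lightning); only the hook name and signature training_step(batch, batch_idx) follow Lightning's convention. The override skips empty batches and records each loss, purely to illustrate the extra control.

```python
from typing import Any, Dict, List


class ModuleBase:
    """Hypothetical stand-in for a base class that supplies the generic training step."""

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

    def compute_loss(self, batch: Dict[str, Any], output: Dict[str, Any]) -> float:
        raise NotImplementedError

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> float:
        # generic boilerplate: run the model, then score its output
        output = self.forward(batch)
        return self.compute_loss(batch, output)


class MyModule(ModuleBase):
    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # toy "model": doubles each input value
        return {"y_graph_scalars": [2.0 * x for x in data["x"]]}

    def compute_loss(self, batch: Dict[str, Any], output: Dict[str, Any]) -> float:
        # summed squared error against the hypothetical "y" targets in the batch
        pred = output["y_graph_scalars"]
        return sum((p - y) ** 2 for p, y in zip(pred, batch["y"]))


class MyModuleCustomStep(MyModule):
    def __init__(self) -> None:
        self.step_losses: List[float] = []

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> float:
        # custom control: skip empty batches and record each loss
        if not batch["x"]:
            return 0.0
        loss = super().training_step(batch, batch_idx)
        self.step_losses.append(loss)
        return loss
```

Calling super().training_step keeps the generic behaviour while letting the subclass add logic around it; a full replacement would simply omit that call.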