# Add your own model architecture

While `physicsml` ships with a large catalogue of available model architectures, you can also add your own state-of-the-art architecture and make it more widely available to all users of the package. In this guide, we provide an overview of how to do this.

There are three main components to adding a new model: the model config, the `molflux` model API handler, and the `torch` module code.
## The model config

The most basic part of adding a new model is defining the model config. The general model config looks like this:

```python
from typing import Dict, Literal, Optional

from pydantic.v1.dataclasses import dataclass

from physicsml.lightning.config import ConfigDict, PhysicsMLModelConfig


@dataclass(config=ConfigDict)
class MyModelConfig(PhysicsMLModelConfig):
    # all model kwargs
    ...
```
The config is a pydantic dataclass which takes care of validating the inputs. It also inherits from `PhysicsMLModelConfig`, which provides the shared kwargs (such as the datamodule config and the optimizer config). Here, you can specify whatever your model requires for training and inference. A good simple example to follow is the EGNN model config.
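To illustrate what the pydantic validation buys you, here is a minimal sketch of the same idea using only the standard library. The field names (`num_layers`, `hidden_dim`) are hypothetical, not part of physicsml, and a pydantic dataclass performs this kind of validation automatically rather than via an explicit `__post_init__`:

```python
from dataclasses import dataclass


@dataclass
class MyModelConfig:
    # hypothetical model kwargs, for illustration only
    num_layers: int = 4
    hidden_dim: int = 128

    def __post_init__(self) -> None:
        # pydantic dataclasses run checks like these automatically on construction
        if self.num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        if self.hidden_dim < 1:
            raise ValueError("hidden_dim must be >= 1")


config = MyModelConfig(num_layers=6)
print(config.hidden_dim)  # unspecified kwargs fall back to defaults: 128
```

Invalid kwargs (for example `num_layers=0`) are rejected at construction time, so downstream training code can assume a well-formed config.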
## The molflux model

The next part is the `molflux` model wrapper. This is responsible for handling all of the training, inference, loading, and saving functionality. If your model is a generic GNN, then this is a simple wrapper class:
```python
from typing import Any, Type

from molflux.modelzoo.info import ModelInfo

from physicsml.lightning.model import PhysicsMLModelBase
from physicsml.models.my_model.supervised.default_configs import MyModelConfig
from physicsml.models.my_model.supervised.my_model_module import MyModelModule


class MyModel(PhysicsMLModelBase[MyModelConfig]):
    def _info(self) -> ModelInfo:
        return ModelInfo(
            model_description="my model description",
            config_description="config description",
        )

    @property
    def _config_builder(self) -> Type[MyModelConfig]:
        return MyModelConfig

    def _instantiate_module(self) -> Any:
        return MyModelModule(
            model_config=self.model_config,
        )
```
The class inherits from the `PhysicsMLModelBase[MyModelConfig]` class. The config in the square brackets is an inherited generic for typing purposes.

You need to specify the `ModelInfo` (which has a description of the model and the config), the `_config_builder` which returns the config class (for instantiation internally), and the `_instantiate_module` which returns an initialised `lightning` module (more on that below). For a simple example of this, check out the EGNN model.

If your model requires a specialised dataset and dataloader, then you can override the `_datamodule_builder`, which returns the specific datamodule class for your model. For an example of this, see the ANI model.
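The square-bracket generic used by `PhysicsMLModelBase[MyModelConfig]` is the standard Python generics mechanism. A minimal sketch of the pattern, with stand-in names (`ModelBase`, `MyConfig`) that are illustrative and not the physicsml implementation:

```python
from typing import Generic, Type, TypeVar

ConfigT = TypeVar("ConfigT")


class ModelBase(Generic[ConfigT]):
    """Stand-in for PhysicsMLModelBase: parametrised by a config type."""

    @property
    def _config_builder(self) -> Type[ConfigT]:
        raise NotImplementedError


class MyConfig:
    pass


class MyModel(ModelBase[MyConfig]):
    @property
    def _config_builder(self) -> Type[MyConfig]:
        # a type checker verifies this matches the generic parameter above
        return MyConfig


model = MyModel()
print(model._config_builder)  # the concrete config class, MyConfig
```

Parametrising the base class this way lets static type checkers confirm that `_config_builder` and the stored `model_config` agree with the config type named in the subclass declaration.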
Note

Uncertainty models need to inherit from the `PhysicsMLUncertaintyModelBase` from `physicsml.lightning.model_uncertainty`. This class handles the additional API functionality for uncertainty models. For an example, see the MeanVarEGNNModel.
## Make your model discoverable

To make your model discoverable in `physicsml` and `molflux`, you need to register it as a plugin. You can do this in the `pyproject.toml` file of your repo: under `[project.entry-points.'molflux.modelzoo.plugins.physicsml']`, add a plugin for your model class as follows

```toml
[project.entry-points.'molflux.modelzoo.plugins.physicsml']
name_of_model = 'path.to.module.file:YourModelName'
```
Note

You can also do this in the `setup.cfg` file of your repo, under `[options.entry_points]`. Add a plugin for your model class as follows

```ini
[options.entry_points]
molflux.modelzoo.plugins.physicsml =
    name_of_model = path.to.module.file:YourModelName
```
This entry point allows `molflux.modelzoo` to hook into your model and automatically register it in the catalogue.
## The lightning module

The `lightning` module is where the core part of the model code lives. The general class looks like

```python
from typing import Any, Dict, Optional

import torch

from physicsml.lightning.module import PhysicsMLModuleBase
from physicsml.models.my_model.supervised.default_configs import MyModelConfig


class MyModelModule(PhysicsMLModuleBase):
    model_config: MyModelConfig

    def __init__(
        self,
        model_config: MyModelConfig,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config=model_config)
        # configure your module code here

    def forward(
        self,
        data: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        # model operations on the data
        return output

    def compute_loss(self, input: Any, target: Any) -> torch.Tensor:
        # compute a loss
        return loss
```
If your model is a generic GNN, then all you have to do is specify your module code (pure `torch` code), the `forward` pass, and the `compute_loss` method. The `forward` pass must return a dictionary which includes all the expected outputs of the model (such as `y_graph_scalars`, `y_node_vector`, etc.). The `compute_loss` method must take in the input data batch and the output of the model and return a loss. For an example of this, see the EGNN Module.

Notice that all the lightning boilerplate is handled by the inherited `PhysicsMLModuleBase`.

If your model requires specialised training steps (or if you'd like more control over those), then you can directly override them as in the ANI Module.
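The contract between `forward` and `compute_loss` can be sketched without any physicsml or torch machinery. The following uses plain Python floats in place of tensors; the dictionary key `y_graph_scalars` follows the convention above, but the stand-in model and the MSE loss are purely illustrative:

```python
from typing import Any, Dict, List


def forward(data: Dict[str, Any]) -> Dict[str, List[float]]:
    # stand-in model: "predict" twice each per-graph input feature
    return {"y_graph_scalars": [2.0 * x for x in data["feats"]]}


def compute_loss(input: Dict[str, Any], output: Dict[str, Any]) -> float:
    # mean squared error between model predictions and reference labels
    preds = output["y_graph_scalars"]
    refs = input["y_graph_scalars"]
    return sum((p - r) ** 2 for p, r in zip(preds, refs)) / len(preds)


batch = {"feats": [1.0, 2.0], "y_graph_scalars": [2.0, 4.0]}
out = forward(batch)
loss = compute_loss(batch, out)
print(loss)  # 0.0: the stand-in model is exact on this batch
```

The real module works the same way, except the values are `torch.Tensor`s so the returned loss is differentiable and lightning can backpropagate through it.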