Deep learning for molecular discovery with a simple, sklearn-style interface.

`torch-molecule` is a package under active development that facilitates molecular discovery through deep learning, featuring a user-friendly, sklearn-style interface. It includes model checkpoints for efficient deployment and benchmarking across a range of molecular tasks. Currently, the package focuses on three main components:
- **Predictive Models**: GREA, SGIR, IRM, GIN/GCN with virtual nodes, DIR, and SMILES-based LSTM/Transformer models, with more to come.
- **Generative Models**: Graph DiT, GraphGA, DiGress, GDS, MolGPT, with more to come.
- **Representation Models**: MoAMa, AttrMasking, ContextPred, EdgePred, and many pretrained models from Hugging Face, with checkpoints and more models to come.

See the List of Supported Models section for details.
Note: This project is in active development, and features may change.
- Create a Conda environment:

  ```bash
  conda create --name torch_molecule python=3.11.7
  conda activate torch_molecule
  ```

- Install `torch_molecule` from GitHub. Clone the repository:

  ```bash
  git clone https://github.com/liugangcode/torch-molecule
  cd torch-molecule
  ```

  Install the requirements:

  ```bash
  pip install -r requirements.txt
  ```

  Make an editable install:

  ```bash
  pip install -e .
  ```

- Alternatively, install `torch_molecule` from PyPI (legacy):

  ```bash
  pip install -i https://test.pypi.org/simple/ torch-molecule
  ```
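If the installation succeeded, the package should import cleanly. A minimal sanity check (assuming the steps above completed without errors):

```python
# Minimal sanity check: both imports should succeed without errors
import torch_molecule
from torch_molecule import GREAMolecularPredictor

print("torch_molecule imported successfully")
```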
Some models require additional packages beyond the base requirements:

| Model | Required Packages |
|---|---|
| `HFPretrainedMolecularEncoder` | `transformers` |
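For example, to use `HFPretrainedMolecularEncoder`, install its required package first:

```bash
pip install transformers
```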
The following example demonstrates how to use the `GREAMolecularPredictor` class from `torch_molecule`. More use cases can be found in the `examples` and `tests` folders.
```python
from torch_molecule import GREAMolecularPredictor

# Toy data: SMILES strings with one regression label per molecule
X_train = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_train = [[0.5], [1.5]]
X_val = ['C1=CC=CC=C1', 'C1=CC=CC=C1']
y_val = [[0.5], [1.5]]

# Initialize a GREA model for single-task regression
grea_model = GREAMolecularPredictor(
    num_task=1,
    task_type="regression",
    model_name="GREA_multitask",
    evaluate_criterion='r2',
    evaluate_higher_better=True,
    verbose=True
)

# Fit the model, searching hyperparameters over 10 trials
grea_model.autofit(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    n_trials=10,
)
```
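After fitting, inference follows the same sklearn-style pattern. As the checkpoint example below shows, `predict` returns a dictionary containing a `'prediction'` array:

```python
# Predict on new SMILES strings; output['prediction'] has shape (n_samples, n_tasks)
output = grea_model.predict(['C1=CC=CC=C1'])
print(output['prediction'])
```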
`torch-molecule` provides checkpoint functions for saving trained models to, and loading them from, Hugging Face:
```python
from torch_molecule import GREAMolecularPredictor
from sklearn.metrics import mean_absolute_error

# This example assumes your data (X, y_train, X_val, y_val, X_test, y_test,
# smiles_list) and identifiers (task_name, model_dir) are already defined.

# Define the repository ID for Hugging Face
repo_id = "user/repo_id"

# Initialize the GREAMolecularPredictor model
model = GREAMolecularPredictor()

# Train the model using autofit
model.autofit(
    X_train=X.tolist(),   # list of SMILES strings for training
    y_train=y_train,      # numpy array [n_samples, n_tasks] of training labels
    X_val=X_val.tolist(), # list of SMILES strings for validation
    y_val=y_val,          # numpy array [n_samples, n_tasks] of validation labels
)

# Make predictions on the test set; output['prediction'] has shape (n_samples, n_tasks)
output = model.predict(X_test.tolist())

# Calculate the mean absolute error
mae = mean_absolute_error(y_test, output['prediction'])
metrics = {'MAE': mae}

# Save the trained model to Hugging Face
model.save_to_hf(
    repo_id=repo_id,
    task_id=f"{task_name}",
    metrics=metrics,
    commit_message=f"Upload GREA_{task_name} model with metrics: {metrics}",
    private=False
)

# Load a pretrained checkpoint from Hugging Face
model = GREAMolecularPredictor()
model.load_from_hf(repo_id=repo_id, local_cache=f"{model_dir}/GREA_{task_name}.pt")

# Set model parameters
model.set_params(verbose=True)

# Make predictions using the loaded model
predictions = model.predict(smiles_list)
```
**Predictive Models**

| Model | Reference |
|---|---|
| SGIR | Semi-Supervised Graph Imbalanced Regression. KDD 2023 |
| GREA | Graph Rationalization with Environment-based Augmentations. KDD 2022 |
| DIR | Discovering Invariant Rationales for Graph Neural Networks. ICLR 2022 |
| SSR | SizeShiftReg: a Regularization Method for Improving Size-Generalization in Graph Neural Networks. NeurIPS 2022 |
| IRM | Invariant Risk Minimization (2019) |
| RPGNN | Relational Pooling for Graph Representations. ICML 2019 |
| GNNs | Graph Convolutional Networks. ICLR 2017 and Graph Isomorphism Network. ICLR 2019 |
| Transformer (SMILES) | Transformer (Attention is All You Need. NeurIPS 2017) based on SMILES strings |
| LSTM (SMILES) | Long short-term memory (Neural Computation 1997) based on SMILES strings |
**Representation Models**

| Model | Reference |
|---|---|
| MoAMa | Motif-aware Attribute Masking for Molecular Graph Pre-training. LoG 2024 |
| AttrMasking | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| ContextPred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| EdgePred | Strategies for Pre-training Graph Neural Networks. ICLR 2020 |
| InfoGraph | InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization. ICLR 2020 |
| Supervised | Supervised pretraining |
| Pretrained | More than ten pretrained models from Hugging Face |
The structure of `torch_molecule` is as follows (generated with `tree -L 2 torch_molecule -I '__pycache__|*.pyc|*.pyo|.git|old*'`):

```
torch_molecule
├── base
│   ├── base.py
│   ├── encoder.py
│   ├── generator.py
│   ├── __init__.py
│   └── predictor.py
├── encoder
│   ├── attrmask
│   ├── constant.py
│   ├── contextpred
│   ├── edgepred
│   ├── moama
│   └── supervised
├── generator
│   ├── digress
│   ├── graph_dit
│   └── graphga
├── __init__.py
├── nn
│   ├── attention.py
│   ├── embedder.py
│   ├── gnn.py
│   ├── __init__.py
│   └── mlp.py
├── predictor
│   ├── dir
│   ├── gnn
│   ├── grea
│   ├── irm
│   ├── lstm
│   ├── rpgnn
│   ├── sgir
│   └── ssr
└── utils
    ├── checker.py
    ├── checkpoint.py
    ├── format.py
    ├── generic
    ├── graph
    ├── hf.py
    ├── __init__.py
    └── search.py
```
This project is under active development, and some features may change over time.
The project template was adapted from https://github.com/lwaekfjlk/python-project-template. We thank the authors for their contribution to the open-source community.