Spaces:
Running
Running
File size: 2,678 Bytes
1485b15 d98d701 fa55745 1547ed2 fa55745 fd88619 d98d701 1547ed2 df95987 1547ed2 1485b15 c162771 d98d701 e517f23 d98d701 5d9e01e 242e83d 5d9e01e 242e83d e517f23 5d9e01e 242e83d d98d701 5d9e01e 242e83d 5d9e01e d98d701 5d9e01e e517f23 5d9e01e 242e83d 1547ed2 d98d701 e517f23 d98d701 1547ed2 d98d701 1547ed2 d98d701 1547ed2 d98d701 1547ed2 d98d701 1547ed2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from __future__ import annotations
from pathlib import Path
import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
# from torch_geometric.data import Data
# Model registry shipped next to this module, parsed once at import time
# with the safe YAML loader (no arbitrary object construction).
REGISTRY = yaml.safe_load(
    (Path(__file__).parent / "registry.yaml").read_text(encoding="utf-8")
)
class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """Thin ``nn.Module`` wrapper around an arbitrary network.

    ``PyTorchModelHubMixin`` is mixed in so the wrapper picks up the
    Hugging Face Hub helpers (e.g. ``from_pretrained``); the ``tags`` class
    keyword is handed to that mixin.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so
        # its parameters are part of this wrapper's state.
        self.model = model

    def forward(self, x):
        """Pass ``x`` straight through to the wrapped model."""
        wrapped = self.model
        return wrapped(x)
class MLIPCalculator(MLIP, Calculator):
    """ASE calculator backed by an :class:`MLIP` model.

    Subclasses implement :meth:`forward`, which turns an :class:`ase.Atoms`
    object into model input, runs the model and returns a mapping holding
    the ``"energy"``, ``"forces"`` and/or ``"stress"`` tensors that
    :meth:`calculate` then converts into ``self.results``.
    """

    # Annotation only; no value is assigned at class level here.
    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Initialise both the MLIP (torch / Hub) side and the ASE side.

        Args:
            model: Network wrapped by :class:`MLIP`.
            restart: Forwarded to ``Calculator.__init__``.
            atoms: Forwarded to ``Calculator.__init__``.
            directory: Forwarded to ``Calculator.__init__``.
            calculator_kwargs: Extra keyword arguments forwarded to
                ``Calculator.__init__``. ``None`` (the default) means none.
        """
        # A literal ``{}`` default would be one dict shared by every call;
        # build a fresh one instead. Behaviour matches the old default.
        if calculator_kwargs is None:
            calculator_kwargs = {}
        # Each base is initialised explicitly instead of through a single
        # cooperative super() chain.
        MLIP.__init__(self, model=model)
        Calculator.__init__(
            self, restart=restart, atoms=atoms, directory=directory, **calculator_kwargs
        )

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate energies, forces and/or stress for the given Atoms object.

        Only the entries named in ``properties`` are written to
        ``self.results``:

        - ``"energy"``: a Python scalar.
        - ``"forces"``: a NumPy array with one row of 3 components per atom.
        - ``"stress"``: the squeezed NumPy array the model returned.

        Raises:
            KeyError: if ``forward`` does not return a requested property.
        """
        # Let the ASE base class do its bookkeeping for the new atoms.
        super().calculate(atoms, properties, system_changes)
        output = self.forward(atoms)
        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            # squeeze() would turn a single atom's (1, 3) forces into (3,);
            # reshape keeps one row per atom while still dropping any
            # leading batch dimension. Detach before copying to CPU so the
            # copy is not tracked by autograd.
            self.results["forces"] = (
                output["forces"].detach().cpu().reshape(-1, 3).numpy()
            )
        if "stress" in properties:
            self.results["stress"] = (
                output["stress"].squeeze().detach().cpu().numpy()
            )

    def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
        """Implement data conversion, graph creation, and the model forward pass.

        Must be overridden by subclasses. Example implementation:

        1. Use ``ase.neighborlist.NeighborList`` to get the neighbor list.
        2. Create a ``torch_geometric.data.Data`` object and copy the data in.
        3. Pass the ``Data`` object to the model and return its output.

        Returns:
            Mapping with keys among ``"energy"``, ``"forces"``, ``"stress"``.
        """
        raise NotImplementedError
|