File size: 2,678 Bytes
1485b15
 
d98d701
fa55745
1547ed2
fa55745
fd88619
 
d98d701
1547ed2
df95987
 
1547ed2
1485b15
c162771
d98d701
e517f23
d98d701
 
 
 
 
5d9e01e
 
 
242e83d
 
5d9e01e
242e83d
e517f23
5d9e01e
242e83d
 
 
d98d701
 
5d9e01e
242e83d
 
 
 
5d9e01e
d98d701
5d9e01e
e517f23
 
 
5d9e01e
242e83d
 
 
 
 
 
1547ed2
d98d701
e517f23
 
 
 
d98d701
1547ed2
 
 
 
 
 
 
d98d701
1547ed2
d98d701
 
 
1547ed2
d98d701
1547ed2
d98d701
1547ed2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations

from pathlib import Path

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

# from torch_geometric.data import Data

# Load registry.yaml (located next to this module) once, at import time.
_REGISTRY_PATH = Path(__file__).parent / "registry.yaml"
with open(_REGISTRY_PATH, encoding="utf-8") as f:
    # safe_load only builds plain YAML data types (dicts, lists, scalars).
    REGISTRY = yaml.safe_load(f)


class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """``nn.Module`` wrapper around an arbitrary model, with Hub mixin support.

    ``PyTorchModelHubMixin`` is mixed in to provide Hugging Face Hub
    serialization helpers (e.g. ``from_pretrained``). The ``tags`` class
    keyword is forwarded to that mixin.
    """

    def __init__(self, model: nn.Module) -> None:
        """Wrap *model*.

        Args:
            model: The network that performs the actual computation. Because
                it is an ``nn.Module`` assigned as an attribute, torch registers
                it as a submodule of this wrapper.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """Pass *x* straight to the wrapped model and return its output."""
        wrapped = self.model
        return wrapped(x)


class MLIPCalculator(MLIP, Calculator):
    """ASE calculator whose results come from an :class:`MLIP` model.

    Subclasses implement :meth:`forward`. It converts an ``Atoms`` object into
    model input and returns a dict of tensors keyed by property name
    (``"energy"``, ``"forces"``, ``"stress"``). :meth:`calculate` converts those
    tensors into the Python / NumPy values stored in ``self.results``.
    """

    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        """Initialise both the torch-module part and the ASE calculator part.

        Args:
            model: Network wrapped by :class:`MLIP`.
            restart: Forwarded to ``ase.calculators.calculator.Calculator``.
            atoms: Forwarded to ``Calculator``.
            directory: Forwarded to ``Calculator``.
            calculator_kwargs: Extra keyword arguments for ``Calculator.__init__``.
                ``None`` (the default) means "no extra arguments". A ``None``
                default is used instead of ``{}`` so that no dict object is
                shared between calls.
        """
        # Each base is initialised explicitly with its own arguments: MLIP
        # takes the model, Calculator takes the ASE options.
        MLIP.__init__(self, model=model)
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            **(calculator_kwargs or {}),
        )

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Run the model on *atoms* and populate ``self.results``.

        Every requested property must be present in the model output; a
        missing one raises ``KeyError``. Any other implemented property that
        the same forward pass returned is stored as well. ASE's
        ``get_property`` only recomputes properties absent from
        ``self.results``, so caching them avoids a second model evaluation
        on unchanged atoms.

        Args:
            atoms: Structure to evaluate.
            properties: Property names being requested.
            system_changes: Forwarded to ``Calculator.calculate``.
        """
        super().calculate(atoms, properties, system_changes)

        output = self.forward(atoms)

        self.results = {}
        for prop in ("energy", "forces", "stress"):
            if prop in properties:
                # Requested: must exist (KeyError otherwise, as before).
                tensor = output[prop]
            elif (
                prop in self.implemented_properties
                and prop in output
                and output[prop] is not None
            ):
                # Not requested, but available for free from this pass.
                tensor = output[prop]
            else:
                continue
            self.results[prop] = self._to_result(prop, tensor)

    @staticmethod
    def _to_result(prop: str, tensor: torch.Tensor):
        """Convert one model output tensor into the value stored in ``results``."""
        # Detach first so the CPU copy / NumPy view carries no autograd graph.
        tensor = tensor.detach()
        if prop == "energy":
            # Single-element tensor -> Python scalar.
            return tensor.squeeze().item()
        if prop == "forces":
            # Force arrays are (n_atoms, 3). squeeze() alone would collapse a
            # single-atom (1, 3) array to (3,); reshape keeps the atom axis
            # and still drops a leading batch axis of size 1.
            return tensor.cpu().reshape(-1, 3).numpy()
        # stress: drop size-1 axes, e.g. (1, 3, 3) -> (3, 3) or (1, 6) -> (6,).
        return tensor.squeeze().cpu().numpy()

    def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
        """Convert *x* to model input, run the model, and return its outputs.

        Must be implemented by subclasses. The returned dict is read by
        :meth:`calculate` under the keys ``"energy"``, ``"forces"`` and
        ``"stress"``.

        Example implementation:
        1. Use `ase.neighborlist.NeighborList` to get neighbor list
        2. Create `torch_geometric.data.Data` object and copy the data
        3. Pass the `Data` object to the model and return the output

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError