File size: 2,699 Bytes
fa55745
d98d701
fa55745
1547ed2
fa55745
fd88619
 
d98d701
1547ed2
 
 
fa55745
 
1547ed2
d98d701
 
 
 
 
 
 
 
 
 
242e83d
 
 
5f464b3
242e83d
 
 
 
 
 
 
 
d98d701
242e83d
5f464b3
 
242e83d
 
d98d701
 
242e83d
 
 
 
 
d98d701
242e83d
 
 
 
 
 
 
1547ed2
d98d701
a140ef6
d98d701
1547ed2
 
 
 
 
 
 
d98d701
1547ed2
d98d701
 
 
1547ed2
d98d701
1547ed2
d98d701
1547ed2
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
from pathlib import Path

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch_geometric.data import Data

# Model registry loaded once at import time from ``registry.yaml`` located in
# the same directory as this module. The encoding is explicit so the file
# decodes identically regardless of the platform's default locale.
# NOTE(review): if the registry only holds plain data, ``yaml.safe_load`` would
# be the stricter choice; FullLoader is kept here to preserve current behavior.
with open(Path(__file__).with_name("registry.yaml"), encoding="utf-8") as f:
    REGISTRY = yaml.load(f, Loader=yaml.FullLoader)


class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    """Base class for machine-learned interatomic potentials (MLIPs).

    Combines ``torch.nn.Module`` with Hugging Face's ``PyTorchModelHubMixin``
    so subclasses get the mixin's Hub helpers (e.g. ``from_pretrained``). The
    ``tags`` class keyword is consumed by the mixin when the class is defined.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Pass every argument along the MRO unchanged (``nn.Module`` is next).
        # NOTE(review): the Hub mixin inspects ``__init__`` signatures, so keep
        # this signature stable.
        super().__init__(*args, **kwargs)


class ModuleMLIP(MLIP):
    """Adapt an existing ``nn.Module`` to the :class:`MLIP` interface.

    The wrapped module is registered as the submodule ``model`` and
    :meth:`forward` delegates to it without modification.
    """

    def __init__(self, model: nn.Module, *args, **kwargs) -> None:
        """Register ``model`` as the ``model`` submodule.

        Args:
            model: Module that performs the actual computation.
            *args: Forwarded to :class:`MLIP`.
            **kwargs: Forwarded to :class:`MLIP`.
        """
        super().__init__(*args, **kwargs)
        # Registering as a submodule makes its parameters/buffers part of this
        # module (state_dict, .to(device), etc.).
        self.add_module("model", model)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged."""
        return self.model(x)


class MLIPCalculator(Calculator):
    """ASE calculator that evaluates a machine-learned interatomic potential.

    This is an abstract base: subclasses implement :meth:`forward`, which turns
    an ``ase.Atoms`` object into model input, runs the model and returns a dict
    of tensors. :meth:`calculate` converts those tensors into the Python/NumPy
    values ASE keeps in ``self.results``.
    """

    # Calculator name; only declared here, presumably assigned by subclasses.
    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        **kwargs,
    ):
        """Initialise the underlying ASE ``Calculator``.

        Args:
            restart: Forwarded to ``Calculator.__init__``.
            atoms: Optional ``Atoms`` object to attach this calculator to.
            directory: Working directory, forwarded to ``Calculator.__init__``.
            **kwargs: Any other ``Calculator`` keyword arguments.
        """
        super().__init__(restart=restart, atoms=atoms, directory=directory, **kwargs)

    def calculate(
        self,
        atoms: "Atoms | None" = None,
        properties: "list[str] | None" = None,
        system_changes: list = all_changes,
    ):
        """Calculate energies and forces for the given Atoms object.

        Runs :meth:`forward` once and stores every supported property the model
        returned (``energy`` as a float, ``forces`` as an ``(natoms, 3)`` array,
        ``stress`` as an array) in ``self.results``.

        Args:
            atoms: Structure to evaluate. If ``None``, the calculator's current
                ``self.atoms`` is used (ASE convention).
            properties: Properties requested by ASE; defaults to ``["energy"]``.
            system_changes: Changes since the last calculation, supplied by ASE.
        """
        if properties is None:
            properties = ["energy"]
        # The ASE base stores a copy of ``atoms`` as ``self.atoms`` when one is
        # given; evaluating ``self.atoms`` also honors the ``atoms=None`` case.
        super().calculate(atoms, properties, system_changes)

        output = self.forward(self.atoms)

        # Store every property the model produced, not only the requested ones:
        # a single forward pass yields them all, so caching them lets ASE serve
        # follow-up get_forces()/get_stress() calls without re-running the model.
        # A requested property the model did not return stays absent, and ASE's
        # get_property() then raises PropertyNotImplementedError.
        self.results = {}

        energy = output.get("energy")
        if energy is not None:
            self.results["energy"] = energy.squeeze().item()

        forces = output.get("forces")
        if forces is not None:
            # reshape instead of squeeze: squeezing a single-atom (1, 3) tensor
            # would collapse it to (3,), but ASE expects (natoms, 3).
            self.results["forces"] = forces.detach().cpu().reshape(-1, 3).numpy()

        stress = output.get("stress")
        if stress is not None:
            self.results["stress"] = stress.squeeze().cpu().detach().numpy()

    def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
        """Implement data conversion, graph creation, and model forward pass.

        Must be overridden by subclasses.

        Example implementation:
        1. Use `ase.neighborlist.NeighborList` to get neighbor list
        2. Create `torch_geometric.data.Data` object and copy the data
        3. Pass the `Data` object to the model and return the output

        Args:
            x: Structure to evaluate.

        Returns:
            Mapping from property name to tensor. :meth:`calculate` reads the
            keys ``"energy"`` (a single-element tensor), ``"forces"``
            (``natoms * 3`` elements) and ``"stress"``; a missing or ``None``
            entry is treated as "not computed".

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError