import os

import monai.networks.nets
import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel

from vista3d_config import VISTA3DConfig


class VISTA3DModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the MONAI VISTA3D network.

    The wrapped network is stored on ``self.network`` and is built with
    ``monai.networks.nets.vista3d132`` from the values carried by the config.
    """

    config_class = VISTA3DConfig

    def __init__(self, config: VISTA3DConfig) -> None:
        """Build the underlying VISTA3D network from ``config``.

        Args:
            config: Configuration object. Its ``model_type`` must be
                ``"VISTA3D"``; ``encoder_embed_dim`` and ``input_channels``
                are forwarded to ``vista3d132``.

        Raises:
            ValueError: If ``config.model_type`` is not ``"VISTA3D"``.
        """
        super().__init__(config)
        # Fail fast: without this check an unsupported model_type would leave
        # the instance without a ``network`` attribute, and the problem would
        # only surface later as an AttributeError inside ``forward``.
        if config.model_type != "VISTA3D":
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; "
                "expected 'VISTA3D'."
            )
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    # NOTE: the parameter is named ``input`` (shadowing the builtin) on
    # purpose; renaming it would break callers that pass it by keyword.
    def forward(self, input):
        """Run the wrapped network on ``input`` and return its output unchanged.

        Only the single positional argument is forwarded to the network.
        Presumably ``input`` is an image tensor -- confirm against the MONAI
        VISTA3D ``forward`` signature.
        """
        return self.network(input)


def register_my_model() -> None:
    """Register VISTA3D so that it can be instantiated through ``AutoModel``.

    Registers ``VISTA3DConfig`` under the ``"VISTA3D"`` model type with
    ``AutoConfig`` and maps that config class to ``VISTA3DModel`` in
    ``AutoModel``.
    """
    AutoConfig.register("VISTA3D", VISTA3DConfig)
    AutoModel.register(VISTA3DConfig, VISTA3DModel)


if __name__ == "__main__":
    # Convert a raw checkpoint into a Hugging Face "pretrained" directory that
    # sits next to this script.
    FILE_PATH = os.path.dirname(__file__)
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models/model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")

    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)

    # map_location="cpu" lets the conversion run on machines without a GPU
    # even when the checkpoint was saved from CUDA tensors; load_state_dict
    # copies the values into the freshly built model's parameters.
    # SECURITY: torch.load unpickles the file -- only run this on a checkpoint
    # from a trusted source.
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu")
    hugging_face_model.network.load_state_dict(state_dict)
    hugging_face_model.save_pretrained(MODEL_PATH)