To load this model from within the https://github.com/danbraunai/simple_stories_train repository, run:
from typing import Any
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT
class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    """Hub-loadable wrapper around ``Llama``.

    Mixing in ``PyTorchModelHubMixin`` gives this module the Hub helpers
    (``from_pretrained`` / ``push_to_hub``). The class keyword arguments
    (``repo_url``, ``language``, ``pipeline_tag``) are consumed by the mixin
    as model-card metadata.
    """

    def __init__(self, **config: Any) -> None:
        """Build the wrapped ``Llama`` model.

        Args:
            **config: Keyword arguments forwarded verbatim to ``LlamaConfig``.
        """
        super().__init__()
        self.llama = Llama(LlamaConfig(**config))

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        """Run the wrapped ``Llama`` on ``x``.

        Args:
            x: Model input; presumably a batch of token ids. TODO confirm
                the expected shape/dtype against ``Llama.forward``.
            *args, **kwargs: Passed straight through to ``Llama.forward``, so
                any optional arguments it accepts (e.g. training targets, if
                supported) are reachable through this wrapper.

        Returns:
            Whatever ``Llama.forward`` returns, unchanged.
        """
        return self.llama(x, *args, **kwargs)
# Architecture config of the published checkpoint. It is kept so that
# LlamaTransformer(**config) can build a fresh, untrained model of the same size.
config = MODEL_CONFIGS_DICT["d12"]
HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"

# `from_pretrained` is a classmethod that constructs and returns a new model
# with the Hub weights loaded. Call it on the class directly: a
# randomly-initialised instance built first would be thrown away unused.
model = LlamaTransformer.from_pretrained(HUB_REPO_NAME)
- Downloads last month: 15