To load this model from within https://github.com/danbraunai/simple_stories_train, you can run:

```python
from typing import Any

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from simple_stories_train.models.llama import Llama, LlamaConfig
from simple_stories_train.models.model_configs import MODEL_CONFIGS_DICT

class LlamaTransformer(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://github.com/danbraunai/simple_stories_train",
    language=["en"],
    pipeline_tag="text-generation",
):
    """Thin wrapper around Llama that adds Hub (de)serialization via PyTorchModelHubMixin."""

    def __init__(self, **config: Any):
        super().__init__()
        self.llama = Llama(LlamaConfig(**config))

    def forward(self, x: torch.Tensor):
        return self.llama(x)

# Build the architecture from the "d12" config, then load the pretrained weights.
config = MODEL_CONFIGS_DICT["d12"]
model = LlamaTransformer(**config)

HUB_REPO_NAME = "lennart-finke/SimpleStories-125M"
# from_pretrained is a classmethod, so it can equivalently be called as
# LlamaTransformer.from_pretrained(HUB_REPO_NAME).
model = model.from_pretrained(HUB_REPO_NAME)
```
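
After loading, a quick sanity check is to run a dummy batch of token IDs through the model. The snippet below is a minimal sketch: the vocabulary size, sequence length, and the exact return value of `Llama.forward` are assumptions not specified by this card and should be checked against the simple_stories_train implementation.

```python
import torch

model.eval()
with torch.no_grad():
    # Hypothetical vocab bound and sequence length; check LlamaConfig for the real values.
    dummy_ids = torch.randint(0, 1000, (1, 16))
    out = model(dummy_ids)

# Depending on Llama.forward, `out` may be the logits tensor itself or a
# (logits, loss)-style tuple; unpack before inspecting or sampling.
logits = out[0] if isinstance(out, tuple) else out
print(logits.shape)  # expected: (1, 16, vocab_size)
```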
The checkpoint is stored as F32 safetensors and has 127M parameters.
