Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer | |
from huggingface_hub import hf_hub_download | |
from model import LlamaForCausalLM # Import your custom model class | |
# ---------------------------------------------------------------------------
# One-time setup: tokenizer, model architecture, and trained weights.
# Everything below runs at import time and populates the module-level
# `tokenizer`, `model`, and `device` names.
# ---------------------------------------------------------------------------

# All tensors and the model live on this device; the checkpoint is mapped
# onto it when loaded.
device = "cpu"

# Tokenizer. If it ships without a padding token, reuse EOS when present and
# fall back to the literal "[PAD]" string otherwise.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token or "[PAD]"

# Custom Llama-style model built with the reduced (135M) configuration.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with, otherwise load_state_dict below will fail.
model = LlamaForCausalLM(
    vocab_size=tokenizer.vocab_size,
    dim=576,
    num_layers=30,
    hidden_dim=1536,
    num_heads=9,
)

# Trained weights are fetched from the Hub repository with hf_hub_download.
# (An earlier variant read the same file from its direct URL via
# torch.hub.load_state_dict_from_url; that path is no longer used.)
model_id = "satyanayak/custom-smallmv2135"
checkpoint_path = hf_hub_download(
    repo_id=model_id,
    filename="model-dict-step-5500.pt",
)

# NOTE(review): torch.load unpickles the file, so only point this at a
# repository you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# The checkpoint is a dict; the weights sit under "model_state_dict".
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()  # inference mode
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Continue ``prompt`` using temperature-scaled top-k sampling.

    Tokens are sampled one at a time from the module-level ``model`` and
    appended to the running sequence until ``max_length`` new tokens have
    been produced or the tokenizer's EOS token is drawn (the EOS token itself
    is not appended).

    Args:
        prompt: Text to condition the generation on.
        max_length: Maximum number of new tokens to sample. Coerced to
            ``int`` because UI widgets may deliver numeric values as floats.
        temperature: Positive divisor applied to the logits before sampling;
            smaller values make the distribution sharper.
        top_k: Number of highest-scoring candidate tokens to sample from.
            Coerced to ``int`` and clamped to ``[1, vocabulary size]``.

    Returns:
        The decoded prompt plus continuation with special tokens stripped,
        or an empty string when the prompt encodes to zero tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Dividing logits by zero yields inf/NaN and a negative value would
        # invert the ranking; fail loudly instead of sampling garbage.
        raise ValueError("temperature must be > 0")

    # range() and torch.topk both require real integers.
    max_new_tokens = int(max_length)
    requested_k = int(top_k)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    if input_ids.numel() == 0:
        # Nothing to condition on: taking the last position of an empty
        # sequence below would fail, so there is nothing to generate.
        return ""

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The model output is indexed as (batch, sequence, vocab); only
            # the logits at the final position decide the next token.
            outputs = model(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature

            # Keep k within what topk can return for this vocabulary.
            k = max(1, min(requested_k, next_token_logits.size(-1)))
            top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
            probs = torch.softmax(top_k_logits, dim=-1)

            # multinomial yields a position inside the top-k list; gather
            # maps it back to the actual vocabulary id, shape (batch, 1).
            sampled_pos = torch.multinomial(probs, num_samples=1)
            next_token = top_k_indices.gather(-1, sampled_pos)

            # .item() requires a single element, i.e. a batch of one prompt.
            if next_token.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)
# Gradio interface.
# The input components are passed to generate_text positionally, so their
# order must match its signature: (prompt, max_length, temperature, top_k).
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Input Prompt", lines=3),
        # Max Length and Top-k are counts. An explicit step=1 keeps the
        # slider on whole numbers; when no step is given Gradio picks one
        # from the slider range, which may be fractional.
        gr.Slider(50, 200, value=100, step=1, label="Max Length"),
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
        gr.Slider(10, 100, value=50, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="🦙 Custom SmolLLM Demo",
    description="A 135M parameter language model trained on smollm-corpus",
)

# Start the web UI only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()