import gradio as gr
import torch
from model import GPT, GPTConfig
import tiktoken
# Load the model checkpoint (model.pth is assumed to hold a state_dict
# matching GPTConfig's defaults) and move it to the available device
device = "cuda" if torch.cuda.is_available() else "cpu"
config = GPTConfig()
model = GPT(config)
model.load_state_dict(torch.load("model.pth", map_location=device))
model.to(device)  # without this, CUDA inputs would hit a CPU-resident model
model.eval()
# Tokenizer (GPT-2 BPE via tiktoken)
enc = tiktoken.get_encoding("gpt2")
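# Sanity check (illustrative): the GPT-2 BPE tokenizer round-trips plain text,
# e.g. enc.decode(enc.encode("hello world")) == "hello world".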
# Autoregressive text generation: sample one token at a time from the
# model's next-token distribution and append it to the running sequence
def generate_text(prompt, max_length=100):
    tokens = enc.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(int(max_length)):  # Gradio sliders can pass floats
            # Crop the context to the model's block size (assuming GPTConfig
            # exposes block_size) so long generations don't overrun the
            # positional embedding table
            logits, _ = model(tokens[:, -config.block_size:])
            logits = logits[:, -1, :]  # distribution over the next token only
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return enc.decode(tokens.squeeze().tolist())
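# A minimal sketch of the same loop with temperature and top-k sampling, two
# common knobs for trading off diversity against coherence. The temperature
# and top_k defaults here are illustrative, not part of the original app.
def generate_text_topk(prompt, max_length=100, temperature=0.8, top_k=50):
    tokens = torch.tensor(enc.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        for _ in range(int(max_length)):
            logits, _ = model(tokens[:, -config.block_size:])
            # Scale logits by temperature, then keep only the top_k candidates
            logits = logits[:, -1, :] / temperature
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float("inf")
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            tokens = torch.cat([tokens, next_token], dim=1)
    return enc.decode(tokens.squeeze().tolist())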
# Gradio UI
iface = gr.Interface(
    fn=generate_text,
    inputs=["text", gr.Slider(50, 500, step=10, label="Max Length")],
    outputs="text",
    title="My GPT Model",
    description="Enter a prompt and generate text using my GPT model.",
)
iface.launch()
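# Note: iface.launch(share=True) would expose a temporary public URL when
# running outside Spaces; on Hugging Face Spaces, plain launch() is enough.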