|
import gradio as gr |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from model import GPT, GPTConfig |
|
import tiktoken |
|
|
|
|
|
# Hugging Face Hub repository that hosts the trained checkpoint.
repo_id = "fridayfringe/my-gpt"

# Fetch model.pth from the Hub (hf_hub_download returns a local file path).
model_path = hf_hub_download(repo_id=repo_id, filename="model.pth")

# Build the model with the default configuration and load the weights onto CPU.
config = GPTConfig()
model = GPT(config)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint downloaded from the network cannot run arbitrary code on
# load. The file is consumed by load_state_dict, i.e. it is a state_dict.
# NOTE(review): requires torch >= 1.13 (the parameter did not exist before).
state_dict = torch.load(
    model_path, map_location=torch.device("cpu"), weights_only=True
)
model.load_state_dict(state_dict)
# Switch to evaluation mode for inference (affects layers such as dropout).
model.eval()

# GPT-2 byte-pair encoding.
# NOTE(review): presumably the tokenizer the model was trained with; confirm.
enc = tiktoken.get_encoding("gpt2")
|
|
|
|
|
def generate_text(prompt, max_length=100):
    """Greedily continue ``prompt`` with tokens predicted by the loaded model.

    Args:
        prompt: Text to continue; encoded with the module-level ``enc`` tokenizer.
        max_length: Number of new tokens to generate. Values from a Gradio
            slider may arrive as floats, so this is coerced to ``int``.

    Returns:
        The decoded prompt followed by the generated continuation. An empty
        prompt is returned unchanged, because the model needs at least one
        token of context to predict from.
    """
    token_ids = enc.encode(prompt)
    if not token_ids:
        # A (1, 0) tensor has no last position to predict from.
        return prompt

    idx = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)  # (1, T)

    # NOTE(review): assumes GPTConfig exposes `block_size` (maximum context
    # length, as in nanoGPT-style models); if absent, no cropping is done.
    block_size = getattr(config, "block_size", None)

    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed at most block_size tokens so positions stay within the
            # range the model's position embeddings cover.
            context = idx if block_size is None else idx[:, -block_size:]
            logits, _loss = model(context)
            # Greedy decoding: take the highest-scoring token at the last position.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1, 1)
            idx = torch.cat((idx, next_token), dim=1)

    return enc.decode(idx[0].tolist())
|
|
|
|
|
# Gradio UI: a prompt textbox plus a slider passed to generate_text as
# max_length (the number of new tokens to generate).
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Enter your prompt", placeholder="Type something..."),
        gr.Slider(10, 200, value=100, step=10, label="Max Length"),
    ],
    outputs="text",
    title="Custom GPT Chatbot",
    description="A lightweight GPT model trained on Spotify lyrics, deployed using Hugging Face Spaces.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. for tests or tooling) does not block on a running server.
if __name__ == "__main__":
    iface.launch()
|
|