# Leap0 / app.py
import gradio as gr
import torch
import json
from transformers import GPT2Tokenizer
from safetensors.torch import load_file
import torch.nn as nn
from torch.nn import functional as F
# Define the GPTConfig class with filtering
class GPTConfig:
    def __init__(self, n_embd, n_head, n_layer, vocab_size):
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.vocab_size = vocab_size

    @classmethod
    def from_dict(cls, config_dict):
        # Keep only the keys the constructor expects; drop everything else
        expected_keys = {'n_embd', 'n_head', 'n_layer', 'vocab_size'}
        filtered_dict = {key: value for key, value in config_dict.items() if key in expected_keys}
        return cls(**filtered_dict)

    def __repr__(self):
        return (f"GPTConfig(n_embd={self.n_embd}, n_head={self.n_head}, "
                f"n_layer={self.n_layer}, vocab_size={self.vocab_size})")
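
# Illustrative example of the key filtering (sizes are made up, not the real
# checkpoint's): extra fields in a config.json are dropped rather than raising.
#   cfg = GPTConfig.from_dict({"n_embd": 256, "n_head": 8, "n_layer": 6,
#                              "vocab_size": 50257, "activation": "gelu"})
#   # "activation" is ignored; cfg holds only the four expected fields.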
# Define the GPT class
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Token embedding layer
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Transformer decoder stack
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.n_embd,
            nhead=config.n_head,
            dim_feedforward=config.n_embd,
            dropout=0.1,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=config.n_layer)
        # Language model head projecting back to the vocabulary
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, input_ids):
        # Embed the input tokens: (batch, seq) -> (batch, seq, n_embd)
        input_embeddings = self.embedding(input_ids)
        # TransformerDecoder expects (seq, batch, n_embd) by default
        input_embeddings = input_embeddings.transpose(0, 1)
        # The embeddings serve as both target and memory (self-attention only)
        transformer_output = self.transformer(input_embeddings, input_embeddings)
        # Transpose back to (batch, seq, n_embd)
        transformer_output = transformer_output.transpose(0, 1)
        # Project to vocabulary logits
        logits = self.lm_head(transformer_output)
        return logits
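
    # Caveat: this architecture adds no positional encodings and applies no
    # causal mask. During sampling only the last position's logits are read,
    # and that position can attend only to tokens before it, so generation
    # stays well-defined despite the missing mask.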
    def generate(self, input_ids, max_new_tokens, temperature, top_k):
        """Autoregressive sampling with temperature scaling and top-k filtering."""
        output_ids = input_ids
        for _ in range(max_new_tokens):
            # Run the full sequence through the model and keep the last
            # position's logits (conditioning on only the last token would
            # throw away all earlier context)
            logits = self.forward(output_ids)[:, -1, :]  # (batch, vocab_size)
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            # Restrict sampling to the top_k most likely tokens
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            next_token = torch.multinomial(top_k_probs, num_samples=1)
            # Map the sampled index back to a vocabulary id
            next_token = top_k_indices.gather(-1, next_token)
            output_ids = torch.cat([output_ids, next_token], dim=1)
        return output_ids
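
# Minimal smoke test of the sampling loop (illustrative sizes, untrained
# weights, so the output tokens are random):
#   _cfg = GPTConfig(n_embd=64, n_head=4, n_layer=2, vocab_size=100)
#   _ids = torch.randint(0, 100, (1, 5))
#   print(GPT(_cfg).generate(_ids, max_new_tokens=3, temperature=1.0, top_k=10).shape)
#   # -> torch.Size([1, 8])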
# Initialize global variables
model = None
tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Load the Leap0 model and tokenizer."""
    global model, tokenizer
    try:
        # Paths to config and model files (relative to the working directory)
        config_path = "config.json"
        model_path = "model.safetensors"

        print(f"Loading configuration from {config_path}...")
        with open(config_path, "r") as f:
            config_dict = json.load(f)

        print("Configuration loaded. Creating model config...")
        config = GPTConfig.from_dict(config_dict)
        print(f"Model config created: {config}")

        print(f"Loading model weights from {model_path}...")
        tensors = load_file(model_path)

        print("Instantiating model...")
        model = GPT(config)

        print("Loading weights into model...")
        # strict=False tolerates missing or unexpected keys in the checkpoint
        model.load_state_dict(tensors, strict=False)
        model.to(device)
        model.eval()

        print("Loading tokenizer...")
        # Reuse the GPT-2 tokenizer for encoding and decoding
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        print("Model and tokenizer loaded successfully")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
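
# Note: config.json and model.safetensors are expected to sit next to app.py
# (e.g., at the root of the Space); if either is missing, load_model() logs
# the error and re-raises it.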
def generate_text(prompt, max_length=50, temperature=0.7, top_k=40):
    """Generate text based on the provided prompt."""
    if model is None or tokenizer is None:
        load_model()
    # Tokenize the input text
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Generate text (cast slider values, which Gradio delivers as floats, to int)
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_k=int(top_k),
        )
    # Decode the output
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text
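
# Example call, mirroring the interface defaults:
#   print(generate_text("once upon a time in the village of",
#                       max_length=30, temperature=0.7, top_k=40))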
# Create the Gradio interface
def create_interface():
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# Leap0 Language Model")
        gr.Markdown("A GPT-style model trained on the TinyStories dataset")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="once upon a time in the village of",
                    lines=3
                )
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=50,
                        step=1,
                        label="Max New Tokens"
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature"
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=40,
                        step=1,
                        label="Top K"
                    )
                generate_btn = gr.Button("Generate Text")
            with gr.Column():
                output = gr.Textbox(
                    label="Generated Output",
                    lines=10,
                    placeholder="Your generated text will appear here..."
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_k],
            outputs=output
        )
    return demo
# Load the model when the script is run
load_model()
# Create and launch the interface
demo = create_interface()
if __name__ == "__main__":
    demo.launch()