import json

import gradio as gr
import torch
import torch.nn as nn
from torch.nn import functional as F
from safetensors.torch import load_file
from transformers import GPT2Tokenizer
# Define the GPTConfig class with filtering
class GPTConfig:
    def __init__(self, n_embd, n_head, n_layer, vocab_size):
        self.n_embd = n_embd
        self.n_head = n_head
        self.n_layer = n_layer
        self.vocab_size = vocab_size

    @classmethod
    def from_dict(cls, config_dict):
        # Keep only the keys this config understands; a Hugging Face-style
        # config.json usually carries many extra fields.
        expected_keys = {'n_embd', 'n_head', 'n_layer', 'vocab_size'}
        filtered_dict = {key: value for key, value in config_dict.items() if key in expected_keys}
        return cls(**filtered_dict)
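# Illustrative use of the key filtering (values here are hypothetical):
#   GPTConfig.from_dict({"n_embd": 256, "n_head": 8, "n_layer": 6,
#                        "vocab_size": 50257, "model_type": "gpt2"})
# builds the config from the four known fields and silently drops
# "model_type", so a full Hugging Face-style config.json can be passed
# in unchanged.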
# Define the GPT class
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Token embedding layer
        self.embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Transformer decoder stack
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.n_embd,
            nhead=config.n_head,
            dim_feedforward=config.n_embd,
            dropout=0.1,
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=config.n_layer)
        # Language model head projecting back to vocabulary logits
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, input_ids):
        # Embed the input tokens: (batch, seq) -> (batch, seq, n_embd)
        input_embeddings = self.embedding(input_ids)
        # TransformerDecoder expects (seq, batch, n_embd) by default
        input_embeddings = input_embeddings.transpose(0, 1)
        # Additive causal mask (-inf above the diagonal) so no position can
        # attend to future tokens during autoregressive decoding
        seq_len = input_embeddings.size(0)
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1,
        )
        # Self-attention only: the same sequence serves as target and memory,
        # so both attention paths get the causal mask
        transformer_output = self.transformer(
            input_embeddings,
            input_embeddings,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        # Transpose back to (batch, seq, n_embd)
        transformer_output = transformer_output.transpose(0, 1)
        # Project to vocabulary logits
        logits = self.lm_head(transformer_output)
        return logits
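    # Shape walk-through for forward (batch B, sequence length T):
    #   input_ids:          (B, T)              integer token ids
    #   input_embeddings:   (T, B, n_embd)      after embedding + transpose
    #   transformer_output: (B, T, n_embd)      after decoding + transpose back
    #   logits:             (B, T, vocab_size)  one distribution per position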
    def generate(self, input_ids, max_new_tokens, temperature, top_k):
        # Autoregressive sampling with temperature scaling and top-k filtering
        output_ids = input_ids
        for _ in range(max_new_tokens):
            # Condition on the full sequence so far; keep only the
            # last position's logits, shape (batch, vocab_size)
            logits = self.forward(output_ids)[:, -1, :]
            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)
            # Restrict sampling to the k most likely tokens
            top_k_probs, top_k_indices = torch.topk(probs, k=top_k, dim=-1)
            # Sample within the top-k candidates, then map back to vocabulary ids
            sampled = torch.multinomial(top_k_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, sampled)
            output_ids = torch.cat([output_ids, next_token], dim=1)
        return output_ids
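    # Sketch of one sampling step (numbers are made up): with temperature=0.5
    # the logits are doubled before the softmax, sharpening the distribution;
    # top_k=40 then discards everything but the 40 most likely tokens before
    # multinomial draws one of them. torch.multinomial treats its input as
    # unnormalized weights, so the truncated top-k probabilities can be
    # sampled from directly without renormalizing.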
# Initialize global variables
model = None
tokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model():
    """Load the Leap0 model and tokenizer."""
    global model, tokenizer
    try:
        # Paths to config and model files
        config_path = "config.json"
        model_path = "model.safetensors"
        print(f"Loading configuration from {config_path}...")
        # Load the configuration
        with open(config_path, "r") as f:
            config_dict = json.load(f)
        print("Configuration loaded. Creating model config...")
        config = GPTConfig.from_dict(config_dict)
        print(f"Model config created: {config}")
        print(f"Loading model weights from {model_path}...")
        # Load the model weights
        tensors = load_file(model_path)
        print("Instantiating model...")
        # Instantiate the model with the loaded config
        model = GPT(config)
        print("Loading weights into model...")
        # strict=False tolerates checkpoint keys that do not match this architecture
        model.load_state_dict(tensors, strict=False)
        model.to(device)
        model.eval()
        print("Loading tokenizer...")
        # Load the tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        print("Model and tokenizer loaded successfully")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise
def generate_text(prompt, max_length=50, temperature=0.7, top_k=40):
    """Generate text based on the provided prompt."""
    if model is None or tokenizer is None:
        load_model()
    # Tokenize the input text
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Generate text
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
        )
    # Decode the output
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text
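# Example call (prompt and settings are arbitrary):
#   generate_text("once upon a time in the village of", max_length=30)
# returns the prompt followed by up to 30 sampled continuation tokens,
# decoded back to text with GPT-2's byte-pair tokenizer.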
# Create the Gradio interface
def create_interface():
    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        gr.Markdown("# Leap0 Language Model")
        gr.Markdown("A GPT-2 based model trained on the Tiny Stories dataset")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter your prompt",
                    placeholder="once upon a time in the village of",
                    lines=3,
                )
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=1,
                        maximum=200,
                        value=50,
                        step=1,
                        label="Max Length",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                    )
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=40,
                        step=1,
                        label="Top K",
                    )
                generate_btn = gr.Button("Generate Text")
            with gr.Column():
                output = gr.Textbox(
                    label="Generated Output",
                    lines=10,
                    placeholder="Your generated text will appear here...",
                )
        generate_btn.click(
            fn=generate_text,
            inputs=[prompt, max_length, temperature, top_k],
            outputs=output,
        )
    return demo
# Load the model when the script is run
load_model()

# Create and launch the interface
demo = create_interface()

if __name__ == "__main__":
    demo.launch()