# Omarrran's picture
# Update app.py
# f9e31a0 verified
import json
import logging
import math
import os
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
# Device configuration: prefer a CUDA GPU when PyTorch can see one, otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def generate_square_subsequent_mask(sz):
    """Return an additive causal attention mask of shape (sz, sz).

    Entry [i, j] is 0.0 when j <= i and -inf when j > i. Added to attention
    scores, it prevents each position from attending to later positions.
    """
    blocked = torch.full((sz, sz), float('-inf'))
    # Keep -inf strictly above the main diagonal; zero the diagonal and everything below it.
    return torch.triu(blocked, diagonal=1)
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Args:
        max_len: number of positions in the precomputed table. In ``forward``,
            ``x.size(1)`` must not exceed this, otherwise the addition fails to broadcast.
        d_model: embedding width; both even and odd widths are supported.
        dropout: dropout probability applied after adding the encodings.
    """

    def __init__(self, max_len, d_model, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so trim div_term to match.
        # For even d_model this slice is a no-op; for odd d_model it avoids a shape error.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading axis of size 1 so the table broadcasts across the batch in forward().
        pe = pe.unsqueeze(0)
        # Registered as a buffer under the name 'pe' so it is saved in / loaded from the state_dict.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        NOTE(review): x is presumably (batch, seq, d_model); axis 1 is treated as
        the sequence axis and the last axis must equal d_model to broadcast.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class TextGen(nn.Module):
    """Causal language model over word ids built from nn.TransformerDecoder layers.

    Token ids are embedded and given sinusoidal position encodings. They then pass
    through a decoder stack that uses the same sequence as both target and memory,
    under one causal mask. A final linear head maps every position to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, sequence_length):
        super().__init__()
        # Submodules are created in the same order and under the same attribute names
        # as before, so state_dict keys (and seeded initialisation) stay unchanged.
        self.pos_encoder = PositionalEncoding(max_len=sequence_length, d_model=embed_dim)
        self.emb = nn.Embedding(vocab_size, embed_dim)
        # NOTE(review): per the PyTorch docs nn.TransformerDecoder stacks copies of this
        # layer, so this instance is presumably not used in forward(). It is still a
        # registered submodule, so its weights appear in the state_dict that load_model
        # restores. Do not remove it without re-saving the checkpoint.
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return vocabulary logits for every position of ``x``.

        NOTE(review): x is presumably a (batch, seq) LongTensor of token ids
        (the decoder layer is batch_first and x.size(1) is used as the length) — confirm.
        """
        embedded = self.emb(x)
        causal_mask = generate_square_subsequent_mask(x.size(1)).to(x.device)
        hidden = self.pos_encoder(embedded)
        # Cross-attention reads the decoder's own input as memory. The same causal
        # mask is applied there so no position can see later positions.
        hidden = self.decoder(hidden, memory=hidden, tgt_mask=causal_mask, memory_mask=causal_mask)
        return self.linear(self.dropout(hidden))
def load_model(model_dir="/home/user/app/"):
    """Load the trained TextGen model and its vocabularies from ``model_dir``.

    Files read from ``model_dir``:
      * ``model_config.json`` — keys vocab_size, embed_dim, num_layers, num_heads,
        sequence_length (used to build TextGen)
      * ``word_to_int.json`` / ``int_to_word.json`` — vocabulary maps
      * ``model.pt`` — a state_dict loaded into TextGen

    Args:
        model_dir: directory containing the files above. The default is the path
            used on Hugging Face Spaces, so existing zero-argument callers are unaffected.

    Returns:
        ``(model, word_to_int, int_to_word, sequence_length)``, with the model moved
        to ``device`` and switched to eval mode.
    """
    model_dir = Path(model_dir)

    with open(model_dir / 'model_config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)

    with open(model_dir / 'word_to_int.json', 'r', encoding='utf-8') as f:
        word_to_int = json.load(f)
    with open(model_dir / 'int_to_word.json', 'r', encoding='utf-8') as f:
        int_to_word = json.load(f)

    model = TextGen(
        vocab_size=config['vocab_size'],
        embed_dim=config['embed_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        sequence_length=config['sequence_length'],
    ).to(device)

    # The checkpoint is a plain state_dict. weights_only=True limits unpickling to
    # tensors and containers, so a tampered model.pt cannot run arbitrary code.
    state_dict = torch.load(model_dir / 'model.pt', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model, word_to_int, int_to_word, config['sequence_length']
@torch.no_grad()
def generate_text(model, prompt, word_to_int, int_to_word, sequence_length, max_length=100, temperature=1.0):
    """Extend a whitespace-tokenised prompt one word at a time by sampling from ``model``.

    Args:
        model: an ``nn.Module``. It is called on a (1, seq) LongTensor of token ids,
            and only its logits at the last position are used.
        prompt: text whose whitespace-separated words seed generation.
        word_to_int: word -> token id. Words not in it are encoded as id 0
            (NOTE(review): whether id 0 is a dedicated unknown/pad id is not visible here).
        int_to_word: token id as a *string* key (the form JSON produces) -> word.
            Ids missing from it are emitted as ``"<UNK>"``.
        sequence_length: context window; only the most recent ``sequence_length``
            tokens are fed to the model.
        max_length: maximum number of words to generate (coerced to ``int`` so
            float-valued UI inputs work).
        temperature: softmax temperature for sampling. Values ``<= 0`` select the
            most likely token greedily instead of dividing by zero.

    Returns:
        The prompt words followed by the generated words, joined by single spaces.
        Generation stops early once the word ``"."`` is produced.

    Raises:
        ValueError: if ``prompt`` contains no words.
    """
    model.eval()
    words = prompt.split()
    if not words:
        raise ValueError("prompt must contain at least one word")

    # Put the input on the model's own device. Fall back to the module-level device
    # only for a model that has no parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    token_ids = [word_to_int.get(w, 0) for w in words]
    current_seq = torch.tensor(token_ids, dtype=torch.long, device=target_device).unsqueeze(0)

    for _ in range(int(max_length)):
        # Drop the oldest tokens so the model never sees more than sequence_length of them.
        if current_seq.size(1) > sequence_length:
            current_seq = current_seq[:, -sequence_length:]
        logits = model(current_seq)[:, -1, :]
        if temperature <= 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        current_seq = torch.cat([current_seq, next_token], dim=1)
        next_word = int_to_word.get(str(next_token.item()), "<UNK>")
        words.append(next_word)
        if next_word == ".":
            break
    return " ".join(words)
# Load the model and vocabularies once, when the module runs. The Gradio callback
# below reads these module-level names instead of reloading files on every request.
(model, word_to_int, int_to_word, sequence_length) = load_model()
def generate_kashmiri_text(prompt, max_length, temperature):
    """Gradio callback: generate text from ``prompt`` using the globally loaded model.

    Any exception is converted into an ``"Error: ..."`` string so the message
    appears in the UI's output box. The full traceback is also logged, so
    maintainers can diagnose the failure.
    """
    try:
        return generate_text(
            model,
            prompt,
            word_to_int,
            int_to_word,
            sequence_length,
            max_length=max_length,
            temperature=temperature,
        )
    except Exception as e:  # UI boundary: report the failure to the user instead of crashing
        logging.getLogger(__name__).exception("Kashmiri text generation failed")
        return f"Error: {str(e)}"
# Create Gradio interface: a prompt box plus length and temperature sliders feed
# generate_kashmiri_text, and the result is shown in a single output text box.
_prompt_box = gr.Textbox(label="Enter Kashmiri Text Prompt", placeholder="دِتم مصمت۔یم بگُل غلام چھُ آں تس اکھ حمزہ گویی")
_length_slider = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Maximum Length")
_temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")

iface = gr.Interface(
    fn=generate_kashmiri_text,
    inputs=[_prompt_box, _length_slider, _temperature_slider],
    outputs=gr.Textbox(label="Generated Text"),
    title="Kashmiri Text Generation",
    description="Generate Kashmiri text using a transformer-based model. Enter your prompt in Kashmiri script.",
    examples=[["دِتم مصمت۔یم بگُل غلام چھُ آں تس اکھ حمزہ گویی", 100, 1.0]],
)

# Start the Gradio app.
iface.launch()