import json
import logging
import math
import os
from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Device configuration: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory holding model_config.json, the vocabularies, and the weights.
# Defaults to the Hugging Face Spaces app path; override with MODEL_DIR.
MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/home/user/app/"))

# NOTE(review): the original code passed an empty filename to torch.load,
# which resolves to the directory itself and always fails. "model.pth" is a
# placeholder — confirm the actual checkpoint filename shipped with the Space.
MODEL_WEIGHTS_FILE = "model.pth"


def generate_square_subsequent_mask(sz):
    """Return an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i and -inf otherwise, so each position can
    attend only to itself and earlier positions.
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout."""

    def __init__(self, max_len, d_model, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine column; slice so
        # the assignment shapes always match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading dim of 1 lets the table broadcast over the batch dimension.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings for the first x.size(1) positions to x and apply dropout.

        x is indexed as (batch, seq, d_model), matching batch_first=True in TextGen.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class TextGen(nn.Module):
    """Next-token language model built on nn.TransformerDecoder.

    The position-encoded embeddings are passed as both ``tgt`` and ``memory``.
    The causal mask is applied to both self- and cross-attention, so no
    position can see later tokens.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, sequence_length):
        super().__init__()
        self.pos_encoder = PositionalEncoding(max_len=sequence_length, d_model=embed_dim)
        self.emb = nn.Embedding(vocab_size, embed_dim)
        # Kept as an attribute: its parameters appear in the saved state_dict keys.
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        emb = self.emb(x)
        input_mask = generate_square_subsequent_mask(x.size(1)).to(x.device)
        x = self.pos_encoder(emb)
        x = self.decoder(x, memory=x, tgt_mask=input_mask, memory_mask=input_mask)
        x = self.dropout(x)
        return self.linear(x)


def load_model(model_dir=MODEL_DIR, weights_file=MODEL_WEIGHTS_FILE):
    """Load the config, vocabularies, and weights from ``model_dir``.

    Args:
        model_dir: Directory containing model_config.json, word_to_int.json,
            int_to_word.json, and the weights file.
        weights_file: Filename of the saved state_dict inside ``model_dir``.

    Returns:
        (model, word_to_int, int_to_word, sequence_length). The model is in
        eval mode on ``device``.

    Raises:
        FileNotFoundError: if any of the files is missing.
        KeyError: if model_config.json lacks a required hyper-parameter.
    """
    model_dir = Path(model_dir)

    with open(model_dir / 'model_config.json', 'r', encoding='utf-8') as f:
        config = json.load(f)

    with open(model_dir / 'word_to_int.json', 'r', encoding='utf-8') as f:
        word_to_int = json.load(f)
    with open(model_dir / 'int_to_word.json', 'r', encoding='utf-8') as f:
        int_to_word = json.load(f)

    model = TextGen(
        vocab_size=config['vocab_size'],
        embed_dim=config['embed_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        sequence_length=config['sequence_length'],
    ).to(device)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_dir / weights_file, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model, word_to_int, int_to_word, config['sequence_length']


@torch.no_grad()
def generate_text(model, prompt, word_to_int, int_to_word, sequence_length, max_length=100, temperature=1.0):
    """Extend ``prompt`` by sampling up to ``max_length`` words from ``model``.

    The prompt is split on whitespace. Each word is mapped through
    ``word_to_int``, and sampling stops early once the word "." is produced.

    Args:
        model: A TextGen instance.
        prompt: Seed text; must contain at least one word.
        word_to_int: Word -> token id mapping.
        int_to_word: Token id (as a string key, as loaded from JSON) -> word.
        sequence_length: Context window; older tokens are dropped beyond it.
        max_length: Maximum number of words to generate.
        temperature: Softmax temperature; must be > 0.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: if the prompt is empty or temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    words = prompt.split()
    if not words:
        # An empty sequence would otherwise fail deep inside the model with an
        # opaque indexing error.
        raise ValueError("Prompt must contain at least one word.")

    model.eval()
    # NOTE(review): unknown words map to id 0 — confirm that is the intended
    # unknown/padding id in word_to_int.json.
    token_ids = [word_to_int.get(w, 0) for w in words]
    current_seq = torch.tensor([token_ids], dtype=torch.long, device=device)

    for _ in range(int(max_length)):
        # The positional table has exactly sequence_length rows, so keep only
        # the most recent tokens.
        if current_seq.size(1) > sequence_length:
            current_seq = current_seq[:, -sequence_length:]
        output = model(current_seq)
        next_token_logits = output[:, -1, :] / temperature
        next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
        current_seq = torch.cat([current_seq, next_token], dim=1)
        next_word = int_to_word.get(str(next_token.item()), "")
        words.append(next_word)
        # NOTE(review): only an ASCII "." stops generation; confirm whether the
        # vocabulary also uses the Urdu full stop "۔" as a sentence terminator.
        if next_word == ".":
            break
    return " ".join(words)


# Load the model and vocabularies once at import time, for reuse by every request.
model, word_to_int, int_to_word, sequence_length = load_model()


def generate_kashmiri_text(prompt, max_length, temperature):
    """Gradio callback: generate text and report any failure as a text message."""
    try:
        return generate_text(
            model,
            prompt,
            word_to_int,
            int_to_word,
            sequence_length,
            # Slider values may arrive as floats; range() needs an int.
            max_length=int(max_length),
            temperature=float(temperature),
        )
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show a short message to the user.
        logger.exception("Text generation failed")
        return f"Error: {str(e)}"


# Create Gradio interface
iface = gr.Interface(
    fn=generate_kashmiri_text,
    inputs=[
        gr.Textbox(label="Enter Kashmiri Text Prompt", placeholder="دِتم مصمت۔یم بگُل غلام چھُ آں تس اکھ حمزہ گویی"),
        gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Maximum Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="Kashmiri Text Generation",
    description="Generate Kashmiri text using a transformer-based model. Enter your prompt in Kashmiri script.",
    examples=[["دِتم مصمت۔یم بگُل غلام چھُ آں تس اکھ حمزہ گویی", 100, 1.0]],
)

# Launch the interface
iface.launch()