|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
import json |
|
import os |
|
from pathlib import Path |
|
|
|
|
|
# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape ``(sz, sz)``.

    Entries on and below the main diagonal are ``0.0`` and entries above it
    are ``-inf``, so when the mask is added to attention scores, row ``i``
    can only attend to columns ``0..i``.

    Args:
        sz: Side length of the square mask.

    Returns:
        A float32 tensor of shape ``(sz, sz)`` created on the default device.
    """
    # Start fully blocked, then let triu zero out everything on or below
    # the main diagonal (diagonal=1 keeps only the strictly-upper part).
    fully_blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32)
    return torch.triu(fully_blocked, diagonal=1)
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``: even feature columns hold sines and
    odd feature columns hold cosines of the position scaled by a geometric
    series of frequencies.
    """

    def __init__(self, max_len, d_model, dropout=0.1):
        """Precompute the encoding table.

        Args:
            max_len: Number of positions to precompute.
            d_model: Feature size of the embeddings (even or odd).
            dropout: Dropout probability applied after adding the encoding.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # (max_len, ceil(d_model / 2)) table of angles shared by sin and cos.
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the angle table; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension is ``d_model``; ``x.size(1)`` should not exceed
                the ``max_len`` the table was built with.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
|
|
class TextGen(nn.Module):
    """Causal transformer language model built from ``nn.TransformerDecoder``.

    Token ids are embedded, combined with sinusoidal positional encodings,
    passed through a stack of decoder layers that use the sequence itself as
    ``memory`` (with the same causal mask on both attention paths), and
    finally projected to vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, num_layers, num_heads, sequence_length):
        """Create the sub-modules.

        Args:
            vocab_size: Number of distinct token ids.
            embed_dim: Embedding / model width.
            num_layers: Number of stacked decoder layers.
            num_heads: Attention heads per layer.
            sequence_length: Number of positions in the positional table.
        """
        super().__init__()
        # Attribute names and creation order are kept as-is: the names become
        # state_dict keys, and the order fixes how random initialisation is
        # consumed.
        self.pos_encoder = PositionalEncoding(max_len=sequence_length, d_model=embed_dim)
        self.emb = nn.Embedding(vocab_size, embed_dim)
        # Stored as an attribute, so this template layer is itself a
        # registered sub-module in addition to the layers inside the stack.
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer,
            num_layers=num_layers,
        )
        self.linear = nn.Linear(embed_dim, vocab_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer token ids; dimension 1 is treated as the sequence
                axis (presumably shape ``(batch, seq)`` -- confirm with
                callers).

        Returns:
            Logits whose last dimension is ``vocab_size``.
        """
        seq_len = x.size(1)
        causal_mask = generate_square_subsequent_mask(seq_len).to(x.device)

        hidden = self.pos_encoder(self.emb(x))
        # Decoder-only usage: the sequence attends to itself through both the
        # self-attention and the "memory" attention, each causally masked.
        hidden = self.decoder(
            hidden,
            memory=hidden,
            tgt_mask=causal_mask,
            memory_mask=causal_mask,
        )
        return self.linear(self.dropout(hidden))
|
|
|
def _read_json(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_model(model_dir="/home/user/app/"):
    """Load the model configuration, vocabularies and trained weights.

    Reads ``model_config.json``, ``word_to_int.json``, ``int_to_word.json``
    and ``model.pt`` from *model_dir*, builds a ``TextGen`` from the config,
    moves it to the module-level ``device``, loads the weights and switches
    the model to evaluation mode.

    Args:
        model_dir: Directory containing the four files above. Defaults to
            the application directory that was previously hard-coded.

    Returns:
        A tuple ``(model, word_to_int, int_to_word, sequence_length)`` where
        ``sequence_length`` is taken from the config.

    Raises:
        OSError: If one of the files cannot be opened.
        KeyError: If the config lacks one of the expected keys.
    """
    model_dir = Path(model_dir)

    # All JSON files are read as UTF-8 so the result does not depend on the
    # platform's default encoding.
    config = _read_json(model_dir / 'model_config.json')
    word_to_int = _read_json(model_dir / 'word_to_int.json')
    int_to_word = _read_json(model_dir / 'int_to_word.json')

    model = TextGen(
        vocab_size=config['vocab_size'],
        embed_dim=config['embed_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        sequence_length=config['sequence_length'],
    ).to(device)

    # NOTE(review): depending on the installed PyTorch version and its
    # defaults, torch.load may unpickle arbitrary objects. Only load
    # checkpoints from a trusted source; consider passing weights_only=True
    # once the minimum supported PyTorch version allows it.
    state_dict = torch.load(model_dir / 'model.pt', map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    return model, word_to_int, int_to_word, config['sequence_length']
|
|
|
@torch.no_grad()
def generate_text(model, prompt, word_to_int, int_to_word, sequence_length, max_length=100, temperature=1.0):
    """Sample a continuation of *prompt* one word at a time.

    The prompt is split on whitespace and each word is mapped to an id
    (unknown words map to id 0). At each step the model is run on at most
    the last *sequence_length* ids, the logits of the final position are
    divided by *temperature*, and the next id is sampled from the resulting
    softmax distribution. Generation stops after *max_length* new words or
    as soon as the word ``"."`` is produced.

    Args:
        model: Callable module returning logits indexed as
            ``[batch, position, vocab]`` for a batch of id sequences.
        prompt: Whitespace-separated seed text; must contain a word.
        word_to_int: Mapping from word to integer id.
        int_to_word: Mapping from the *string* form of an id to its word;
            ids missing from it are rendered as ``"<UNK>"``.
        sequence_length: Maximum number of ids fed to the model per step.
        max_length: Maximum number of words to generate; converted with
            ``int()`` so float values from UI sliders are accepted.
        temperature: Positive softmax temperature.

    Returns:
        The prompt words followed by the generated words, joined by spaces.

    Raises:
        ValueError: If the prompt has no words or *temperature* is not
            greater than zero.
    """
    words = prompt.split()
    if not words:
        # An empty sequence would reach the model as a (1, 0) tensor and fail
        # with an unrelated indexing error; report the real cause instead.
        raise ValueError("prompt must contain at least one word")
    if temperature <= 0:
        raise ValueError("temperature must be greater than 0")

    model.eval()

    # Put the input on whatever device the model's parameters live on, so the
    # function does not depend on module-level state staying in sync.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    token_ids = [word_to_int.get(w, 0) for w in words]
    current_seq = torch.tensor([token_ids], dtype=torch.long, device=target_device)

    for _ in range(int(max_length)):
        # Keep only the most recent tokens that fit in the context window.
        if current_seq.size(1) > sequence_length:
            current_seq = current_seq[:, -sequence_length:]

        output = model(current_seq)
        next_token_logits = output[:, -1, :] / temperature
        next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)

        current_seq = torch.cat([current_seq, next_token], dim=1)
        # Vocabulary keys are strings, hence the str() around the sampled id.
        next_word = int_to_word.get(str(next_token.item()), "<UNK>")
        words.append(next_word)

        if next_word == ".":
            break

    return " ".join(words)
|
|
|
|
|
# Load the model and vocabularies once at import time; the request handler
# below reads these module-level names on every call.
(model, word_to_int, int_to_word, sequence_length) = load_model()
|
|
|
def generate_kashmiri_text(prompt, max_length, temperature):
    """UI callback: generate text for *prompt* and never raise.

    Delegates to ``generate_text`` with the module-level model and
    vocabularies. Any failure is converted into an ``"Error: ..."`` string
    so the interface always has something to display.

    Args:
        prompt: Seed text entered by the user.
        max_length: Maximum number of words to generate; converted with
            ``int()`` because UI sliders may deliver a float.
        temperature: Sampling temperature forwarded unchanged.

    Returns:
        The generated text, or a string starting with ``"Error: "``.
    """
    # Reject empty input up front with a readable message rather than
    # surfacing whatever internal error an empty sequence would trigger.
    if not prompt or not prompt.strip():
        return "Error: please enter a prompt"

    try:
        return generate_text(
            model,
            prompt,
            word_to_int,
            int_to_word,
            sequence_length,
            max_length=int(max_length),
            temperature=temperature,
        )
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        return f"Error: {str(e)}"
|
|
|
|
|
# Sample prompt used both as the textbox placeholder and as the one example.
_SAMPLE_PROMPT = "دِتم مصمت۔یم بگُل غلام چھُ آں تس اکھ حمزہ گویی"

# Input widgets, in the order the callback receives their values.
prompt_box = gr.Textbox(label="Enter Kashmiri Text Prompt", placeholder=_SAMPLE_PROMPT)
length_slider = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="Maximum Length")
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")

output_box = gr.Textbox(label="Generated Text")

iface = gr.Interface(
    fn=generate_kashmiri_text,
    inputs=[prompt_box, length_slider, temperature_slider],
    outputs=output_box,
    title="Kashmiri Text Generation",
    description="Generate Kashmiri text using a transformer-based model. Enter your prompt in Kashmiri script.",
    examples=[[_SAMPLE_PROMPT, 100, 1.0]],
)

# Start serving the interface as soon as the module is executed.
iface.launch()