import streamlit as st
import torch
import pandas as pd
import PyPDF2
import pickle
import os
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
from huggingface_hub import login, hf_hub_download
import time
from ch09util import subsequent_mask, create_model
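# ch09util is assumed to ship alongside this app in the Space repo:
# create_model builds the encoder-decoder Transformer and subsequent_mask
# builds the causal decoder mask (a reference sketch appears after the
# generation function below).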
# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Set page configuration
st.set_page_config(
    page_title="Translator Agent",
    page_icon="🚀",
    layout="centered"
)
# Model repository name
MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"

# Retrieve the Hugging Face token from an environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    st.error("🔐 Hugging Face token not found in environment variables. Please set HF_TOKEN in Space secrets.")
    st.stop()
# Title with rocket emojis
st.title("🚀 English to French Translator 🚀")

# Configure avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Sidebar configuration (no token input here; the token comes from Space secrets)
with st.sidebar:
    st.header("Upload Documents 📂")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file to translate",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# File processing function: extract translatable text from an uploaded file
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # extract_text() can return None for pages without a text layer
            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
    return ""
# Custom model loading function
def load_model_and_resources():
    try:
        login(token=HF_TOKEN)

        # Load the tokenizer from the model repo
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN
        )
        # Define the Transformer configuration
        class TransformerConfig(PretrainedConfig):
            model_type = "custom_transformer"

            def __init__(self, src_vocab_size=11055, tgt_vocab_size=11239,
                         d_model=256, d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
                super().__init__(**kwargs)
                self.src_vocab_size = src_vocab_size
                self.tgt_vocab_size = tgt_vocab_size
                self.d_model = d_model
                self.d_ff = d_ff
                self.h = h
                self.N = N
                self.dropout = dropout
        # Define the Transformer model
        class CustomTransformer(PreTrainedModel):
            config_class = TransformerConfig

            def __init__(self, config):
                super().__init__(config)
                self.model = create_model(
                    config.src_vocab_size,
                    config.tgt_vocab_size,
                    N=config.N,
                    d_model=config.d_model,
                    d_ff=config.d_ff,
                    h=config.h,
                    dropout=config.dropout
                )

            def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
                return self.model(src, tgt, src_mask, tgt_mask)
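        # Subclassing PreTrainedModel lets the hand-built encoder-decoder reuse
        # Hugging Face plumbing (config handling, save/load conventions) while
        # create_model supplies the actual architecture.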
        # Load the config from the model repo, validating the vocab fields
        config_dict = TransformerConfig.from_pretrained(MODEL_NAME, token=HF_TOKEN).to_dict()
        if "src_vocab_size" not in config_dict or "tgt_vocab_size" not in config_dict:
            st.warning(
                f"Config at {MODEL_NAME}/config.json is missing 'src_vocab_size' or 'tgt_vocab_size'. "
                "Using defaults (11055, 11239). For accuracy, update the training script to save these values."
            )
            config = TransformerConfig()
        else:
            config = TransformerConfig(**config_dict)
        # Instantiate the model, then load the weights explicitly from the
        # safetensors checkpoint in the repo. The model is built normally (not
        # on the meta device), so a plain .to(DEVICE) is all that is needed to
        # move it to CPU or GPU after the weights are loaded.
        model = CustomTransformer(config)
        weights_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors", token=HF_TOKEN)
        from safetensors.torch import load_file
        state_dict = load_file(weights_path)
        model.load_state_dict(state_dict)
        model = model.to(DEVICE)
        model.eval()
        # Load the vocabulary dictionaries from the model repo
        dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=HF_TOKEN)
        with open(dict_path, "rb") as fb:
            en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
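        # dict.p is expected to hold a 4-tuple: *_word_dict maps token -> index
        # and *_idx_dict maps index -> token, for English and French respectively.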
        return model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
# Custom streaming generation function: greedy decoding, yielding one
# detokenized word at a time
def custom_streaming_generate(input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict):
    try:
        model.eval()
        PAD, UNK = 0, 1
        # Tokenize the source text and map tokens to indices, falling back to
        # UNK for out-of-vocabulary tokens
        tokenized_en = ["BOS"] + tokenizer.tokenize(input_text) + ["EOS"]
        enidx = [en_word_dict.get(i, UNK) for i in tokenized_en]
        src = torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)
        # Padding mask: True wherever the source token is not PAD
        src_mask = (src != PAD).unsqueeze(-2)
        # Encode the source once; decoding below is autoregressive
        memory = model.model.encode(src, src_mask)
        start_symbol = fr_word_dict["BOS"]
        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
        for _ in range(100):  # cap the translation at 100 generated tokens
            out = model.model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
            prob = model.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            # .item() converts the 0-dim tensor to a plain int; a tensor key
            # would never match the int keys in fr_idx_dict
            next_word = next_word.item()
            sym = fr_idx_dict.get(next_word, "UNK")
            if sym != "EOS":
                # "</w>" is the BPE end-of-word marker; turn it into a space,
                # then tuck punctuation back against the preceding word
                token = sym.replace("</w>", " ")
                for x in '''?:;.,'("-!&)%''':
                    token = token.replace(f" {x}", f"{x}")
                yield token
            else:
                break
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        # Yield a final empty token to ensure completion
        yield ""
    except Exception as e:
        raise RuntimeError(f"Generation error: {str(e)}") from e
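# For reference, subsequent_mask (imported from ch09util) is assumed to follow
# the standard "Annotated Transformer" form: a (1, size, size) boolean tensor
# that is True at or below the diagonal, so position i may only attend to
# positions <= i. A minimal sketch under that assumption:
#
#     def subsequent_mask(size):
#         attn_shape = (1, size, size)
#         upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
#         return upper == 0
#
# e.g. subsequent_mask(3)[0] ->
#     [[ True, False, False],
#      [ True,  True, False],
#      [ True,  True,  True]]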
# Display chat messages from history
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        # Fall back to the default avatar if rendering with a custom one fails
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Chat input handling
if prompt := st.chat_input("Enter text to translate into French..."):
    # Load the model and resources once, then cache them in session state
    if "model" not in st.session_state:
        model_data = load_model_and_resources()
        if model_data is None:
            st.error("Failed to load model. Please check the HF_TOKEN in Space secrets and try again.")
            st.stop()
        st.session_state.model, st.session_state.tokenizer, \
            st.session_state.en_word_dict, st.session_state.fr_word_dict, \
            st.session_state.en_idx_dict, st.session_state.fr_idx_dict = model_data

    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    en_word_dict = st.session_state.en_word_dict
    fr_word_dict = st.session_state.fr_word_dict
    fr_idx_dict = st.session_state.fr_idx_dict
    # Add the user message to the chat
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Translate the uploaded file if one is present, otherwise the typed prompt
    file_context = process_file(uploaded_file)
    input_text = file_context if file_context else prompt
    # Generate the translation with streaming output
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()

                # Placeholder that is rewritten in place as tokens stream in
                response_container = st.empty()
                full_response = ""

                # Stream translation tokens
                for token in custom_streaming_generate(
                    input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict
                ):
                    if token:  # only append non-empty tokens
                        full_response += token
                        response_container.markdown(full_response)
                # Calculate performance metrics
                end_time = time.time()
                input_tokens = len(tokenizer(input_text)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                elapsed = end_time - start_time
                speed = output_tokens / elapsed if elapsed > 0 else 0

                # Calculate costs (hypothetical pricing model)
                input_cost = (input_tokens / 1_000_000) * 5     # $5 per million input tokens
                output_cost = (output_tokens / 1_000_000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # approximate USD -> AOA (Angolan kwanza) rate
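                # Worked example with this pricing: 1,000 input + 2,000 output
                # tokens -> (1000/1e6)*5 + (2000/1e6)*15 = $0.005 + $0.030
                # = $0.035 USD, or about 40.6 AOA at the 1160 rate assumed above.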
                # Display metrics
                st.caption(
                    f"🤖 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )

            # Store the full response in the chat history
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Translation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")