# translateEn2FR / app.py
# Author: amiguel
# Last commit: 26726c7 "Update app.py" (verified)
import streamlit as st
import torch
import pandas as pd
import PyPDF2
import pickle
import os
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig
from huggingface_hub import login, hf_hub_download
import time
from ch09util import subsequent_mask, create_model
# Run the model on the GPU whenever torch can see one; otherwise use the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Page-level Streamlit settings: browser tab title, favicon and layout.
st.set_page_config(page_title="Translator Agent", page_icon="🚀", layout="centered")

# Hugging Face Hub repository that holds the tokenizer, config, weights and
# vocabulary dictionaries.
MODEL_NAME = "amiguel/custom-en2fr-transformer-v1"

# The Hub access token is read from the environment (e.g. a Space secret).
# Without it the script reports the problem and halts this run.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    st.error("🔐 Hugging Face token not found in environment variables. Please set HF_TOKEN in Space secrets.")
    st.stop()
# Page heading.
st.title("🚀 English to French Translator 🚀")

# Avatar images for the two chat roles: the user and the assistant.
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Sidebar: optional PDF / XLSX document upload (the token input was removed).
with st.sidebar:
    st.header("Upload Documents 📂")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file to translate",
        type=["pdf", "xlsx"],
        label_visibility="collapsed",
    )

# The chat history lives in per-session state so it survives script reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
@st.cache_data
def process_file(uploaded_file):
    """Extract translatable text from an uploaded PDF or XLSX file.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``, or None.

    Returns:
        For a PDF, the text of every page joined by newlines. For an XLSX
        file, the sheet rendered via ``DataFrame.to_markdown()``. Returns ""
        when no file is given, when the MIME type is not one of the two
        supported types, or when extraction raises (an error is shown in the
        UI in that case).
    """
    if uploaded_file is None:
        return ""
    try:
        # Rewind in case the stream was already consumed by an earlier read.
        uploaded_file.seek(0)
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # Guard against a None result from extract_text() so join() cannot fail.
            return "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        if uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""
    # Unsupported MIME type: return "" explicitly rather than an implicit None.
    return ""
@st.cache_resource
def _load_model_and_resources_cached():
    """Download and build the translator, raising on any failure.

    Kept separate from ``load_model_and_resources`` so that errors propagate
    out of the cached function instead of being turned into a return value.
    A call that raises leaves nothing in the st.cache_resource cache, so a
    transient failure (network, bad token) can be retried on the next prompt.
    By contrast, a cached ``None`` would be served for the rest of the process.

    Returns:
        ``(model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict)``
    """
    login(token=HF_TOKEN)

    # Load tokenizer from the model repo.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

    class TransformerConfig(PretrainedConfig):
        """Hyperparameters of the custom encoder-decoder; defaults fill absent keys."""

        model_type = "custom_transformer"

        def __init__(self, src_vocab_size=11055, tgt_vocab_size=11239, d_model=256,
                     d_ff=1024, h=8, N=6, dropout=0.1, **kwargs):
            super().__init__(**kwargs)
            self.src_vocab_size = src_vocab_size
            self.tgt_vocab_size = tgt_vocab_size
            self.d_model = d_model
            self.d_ff = d_ff
            self.h = h
            self.N = N
            self.dropout = dropout

    class CustomTransformer(PreTrainedModel):
        """PreTrainedModel wrapper around the network built by ``create_model``."""

        config_class = TransformerConfig

        def __init__(self, config):
            super().__init__(config)
            self.model = create_model(
                config.src_vocab_size,
                config.tgt_vocab_size,
                N=config.N,
                d_model=config.d_model,
                d_ff=config.d_ff,
                h=config.h,
                dropout=config.dropout,
            )

        def forward(self, src, tgt, src_mask, tgt_mask, **kwargs):
            return self.model(src, tgt, src_mask, tgt_mask)

    # Inspect the raw config.json contents. A config built by from_pretrained()
    # always carries the vocab sizes (filled from the __init__ defaults), so
    # checking its to_dict() output could never detect missing keys.
    raw_config, _ = TransformerConfig.get_config_dict(MODEL_NAME, token=HF_TOKEN)
    if "src_vocab_size" not in raw_config or "tgt_vocab_size" not in raw_config:
        st.warning(
            f"Config at {MODEL_NAME}/config.json is missing 'src_vocab_size' or 'tgt_vocab_size'. "
            "Using defaults (11055, 11239). For accuracy, update the training script to save these values."
        )
    # Keys present in the file win; any missing key takes the TransformerConfig default.
    config = TransformerConfig.from_dict(raw_config)

    # Build the model, then load the safetensors weights (loaded onto the CPU).
    model = CustomTransformer(config)
    weights_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors", token=HF_TOKEN)
    from safetensors.torch import load_file
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)

    # The weights are real tensors at this point, so a plain .to() moves every
    # parameter and buffer. The previous CUDA path used to_empty(), which
    # allocates uninitialised storage and only restores state-dict entries.
    model = model.to(DEVICE)
    model.eval()

    # Load the vocabulary dictionaries from the model repo.
    dict_path = hf_hub_download(repo_id=MODEL_NAME, filename="dict.p", token=HF_TOKEN)
    # SECURITY NOTE(review): pickle.load can execute arbitrary code. It is only
    # acceptable while dict.p comes from a repository this app's owner controls.
    with open(dict_path, "rb") as fb:
        en_word_dict, en_idx_dict, fr_word_dict, fr_idx_dict = pickle.load(fb)
    return model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict


def load_model_and_resources():
    """Load the tokenizer, model and vocabulary dictionaries for translation.

    Successful loads are cached process-wide by ``_load_model_and_resources_cached``.
    Failures are reported in the UI and are not cached, so a later call retries.

    Returns:
        ``(model, tokenizer, en_word_dict, fr_word_dict, en_idx_dict, fr_idx_dict)``
        on success, or ``None`` if anything failed.
    """
    try:
        return _load_model_and_resources_cached()
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
def custom_streaming_generate(input_text, model, tokenizer, en_word_dict, fr_word_dict,
                              fr_idx_dict, max_len=100):
    """Greedily translate *input_text* into French, yielding text as it is decoded.

    Args:
        input_text: English source text.
        model: Wrapper whose ``.model`` exposes ``encode``, ``decode`` and
            ``generator`` (as called below).
        tokenizer: Tokenizer whose ``tokenize()`` output is looked up in
            ``en_word_dict``.
        en_word_dict: English token -> index map; unknown tokens map to 1 (UNK).
        fr_word_dict: French token -> index map; only "BOS" is read here.
        fr_idx_dict: French index -> token map used to decode predictions.
        max_len: Maximum number of decoding steps (default 100).

    Yields:
        Successive non-empty chunks of the translation, then a final "".

    Raises:
        Exception: wrapping (and chained to) any error raised during generation.
    """
    PAD, UNK = 0, 1
    # A space is never kept directly in front of these characters.
    no_space_before = '''?:;.,'("-!&)%'''
    try:
        model.eval()
        tokenized_en = ["BOS"] + tokenizer.tokenize(input_text) + ["EOS"]
        enidx = [en_word_dict.get(tok, UNK) for tok in tokenized_en]
        src = torch.tensor(enidx).long().to(DEVICE).unsqueeze(0)
        src_mask = (src != PAD).unsqueeze(-2)
        # Autograd is disabled per computation rather than around the whole
        # generator, so the no_grad state never leaks to the caller across yields.
        with torch.no_grad():
            memory = model.model.encode(src, src_mask)
        start_symbol = fr_word_dict["BOS"]
        ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
        # A word-final "</w>" becomes a trailing space. Hold it back until the
        # next piece is known so it can be dropped before punctuation; a
        # per-token replace never sees it, because it belongs to the previous token.
        pending_space = ""
        for _ in range(max_len):
            with torch.no_grad():
                out = model.model.decode(memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src.data))
                prob = model.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            # .item() yields a Python int. A 0-d tensor key hashes by id(), so
            # the dict lookup would always miss and fall back to "UNK".
            next_word = next_word[0].item()
            sym = fr_idx_dict.get(next_word, "UNK")
            if sym == "EOS":
                break
            piece = sym.replace("</w>", " ")
            if piece and piece[0] in no_space_before:
                pending_space = ""
            body = piece.rstrip(" ")
            chunk = pending_space + body
            pending_space = piece[len(body):]
            if chunk:
                yield chunk
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        # Yield a final empty token to ensure completion
        yield ""
    except Exception as e:
        raise Exception(f"Generation error: {str(e)}") from e
# Replay the stored conversation on every script rerun.
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        # Retry without the custom avatar. Catching Exception (not a bare
        # except) lets BaseException-only signals such as KeyboardInterrupt
        # and SystemExit propagate.
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Chat input handling: translate each submitted prompt (or the uploaded file).
if prompt := st.chat_input("Enter text to translate into French..."):
    # Load model and resources once per session; the loader returns None on failure.
    if "model" not in st.session_state:
        model_data = load_model_and_resources()
        if model_data is None:
            st.error("Failed to load model. Please check the HF_TOKEN in Space secrets and try again.")
            st.stop()
        (
            st.session_state.model,
            st.session_state.tokenizer,
            st.session_state.en_word_dict,
            st.session_state.fr_word_dict,
            st.session_state.en_idx_dict,
            st.session_state.fr_idx_dict,
        ) = model_data
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    en_word_dict = st.session_state.en_word_dict
    fr_word_dict = st.session_state.fr_word_dict
    fr_idx_dict = st.session_state.fr_idx_dict

    # Echo and record the user message.
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Text extracted from an uploaded file takes precedence over the prompt.
    file_context = process_file(uploaded_file)
    input_text = file_context if file_context else prompt

    # Generate translation with streaming.
    if model is not None and tokenizer is not None:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                # perf_counter() is monotonic, so the measured duration is not
                # skewed by wall-clock adjustments the way time.time() can be.
                start_time = time.perf_counter()
                # Placeholder that is overwritten as the translation grows.
                response_container = st.empty()
                full_response = ""
                for token in custom_streaming_generate(
                    input_text, model, tokenizer, en_word_dict, fr_word_dict, fr_idx_dict
                ):
                    if token:  # Only append non-empty tokens
                        full_response += token
                        response_container.markdown(full_response)

                # Performance metrics.
                elapsed = time.perf_counter() - start_time
                input_tokens = len(tokenizer(input_text)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / elapsed if elapsed > 0 else 0

                # Costs (hypothetical pricing model).
                input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)

                st.caption(
                    f"🤖 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )

            # Store the full response in chat history.
            st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Translation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")