Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
from aksharamukha import transliterate | |
import torch | |
from dotenv import load_dotenv | |
import os | |
load_dotenv()
# Token for the Hugging Face Inference API; None when ACCESS_TOKEN is not set.
access_token = os.getenv('ACCESS_TOKEN')

# Place every local model on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# English -> Sinhala: NLLB-200 distilled checkpoint, wrapped in a translation pipeline.
NLLB_MODEL_ID = "facebook/nllb-200-distilled-600M"
trans_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL_ID).to(device)
eng_trans_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL_ID)
translator = pipeline('translation', model=trans_model, tokenizer=eng_trans_tokenizer, src_lang="eng_Latn", tgt_lang='sin_Sinh', max_length=400, device=device)

# Sinhala -> English: mT5 checkpoint used directly via generate() (slow tokenizer).
MT5_MODEL_ID = "thilina/mt5-sinhalese-english"
sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained(MT5_MODEL_ID).to(device)
si_trans_tokenizer = AutoTokenizer.from_pretrained(MT5_MODEL_ID, use_fast=False)

# Singlish -> Sinhala text2text model; device= puts it on the same device as the
# other models instead of leaving placement to the pipeline default.
singlish_pipe = pipeline("text2text-generation", model="Dhahlan2000/Simple_Translation-model-for-GPT-v14", device=device)
# Translation functions | |
def translate_Singlish_to_sinhala(text):
    """Run *text* through ``singlish_pipe`` with a Singlish-to-Sinhala prompt.

    Returns the ``generated_text`` field of the pipeline's first result.
    """
    prompt = f"translate Singlish to Sinhala: {text}"
    results = singlish_pipe(prompt, clean_up_tokenization_spaces=False)
    first_result = results[0]
    return first_result['generated_text']
def translate_english_to_sinhala(text):
    """Translate English *text* to Sinhala script with the NLLB pipeline, line by line.

    Each newline-separated line is translated on its own. Blank or
    whitespace-only lines are emitted as empty strings instead of being sent
    to the model, so the output keeps the input's line structure.

    Args:
        text: English input; may contain newlines.

    Returns:
        The translated lines joined with newlines, with every occurrence of the
        phrase "ප් රභූවරුන්" removed.
    """
    translated_parts = []
    for part in text.split("\n"):
        if not part.strip():
            # Nothing to translate; calling the model here would still make it
            # generate some output for an empty line.
            translated_parts.append("")
            continue
        result = translator(part, clean_up_tokenization_spaces=False)
        translated_parts.append(result[0]['translation_text'])
    # NOTE(review): this phrase is stripped from every reply, presumably a
    # recurring model artefact — confirm with the author.
    return "\n".join(translated_parts).replace("ප් රභූවරුන්", "")
def translate_sinhala_to_english(text, *, max_length=400):
    """Translate Sinhala *text* to English with the mT5 model, line by line.

    Each newline-separated line is stripped and translated on its own. Blank
    lines become empty strings instead of being sent to the model.

    Args:
        text: Sinhala input; may contain newlines.
        max_length: upper bound on generated tokens per line. Defaults to 400,
            the same limit the English->Sinhala pipeline uses.

    Returns:
        The English translations joined with newlines.
    """
    translated_parts = []
    for part in text.split("\n"):
        line = part.strip()
        if not line:
            translated_parts.append("")
            continue
        # Inputs longer than 512 tokens are truncated by the tokenizer.
        inputs = si_trans_tokenizer(line, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        # Without an explicit limit, generate() uses the generation config's
        # max_length, which is transformers' model-agnostic default of 20 tokens
        # unless the checkpoint overrides it. That cuts longer sentences short.
        outputs = sin_trans_model.generate(**inputs, max_length=max_length)
        translated_parts.append(
            si_trans_tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        )
    return "\n".join(translated_parts)
def transliterate_from_sinhala(text):
    """Romanise Sinhala-script *text* with aksharamukha's Velthuis scheme.

    The Velthuis output is simplified: every '.', '*' and '"' character is
    deleted, then the result is lower-cased.
    """
    velthuis = transliterate.process('Sinhala', 'Velthuis', text)
    # A single translate() pass deletes all three characters at once.
    stripped = velthuis.translate(str.maketrans('', '', '.*"'))
    return stripped.lower()
def transliterate_to_sinhala(text):
    """Convert Velthuis-romanised *text* into Sinhala script via aksharamukha."""
    source_scheme, target_script = 'Velthuis', 'Sinhala'
    sinhala = transliterate.process(source_scheme, target_script, text)
    return sinhala
# Conversation model: Phi-3-mini-4k-instruct, queried remotely through the
# Hugging Face Inference API instead of being loaded locally.
API_URL = "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"
# Send the bearer token only when ACCESS_TOKEN is set. Formatting None into the
# f-string would send the literal header "Bearer None".
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
def query(payload, timeout=120):
    """POST *payload* as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "..."}``.
        timeout: seconds to wait for the HTTP response.

    Returns:
        The parsed JSON body. Bodies of non-2xx responses are decoded and
        returned too, so callers receive the API's error object rather than an
        exception.

    Raises:
        urllib.error.URLError: on connection failures or timeouts.
        ValueError: if the response body is not valid JSON.
    """
    # Function-scope stdlib imports: ``requests`` was used here but never
    # imported, and urllib avoids adding a dependency.
    import json
    import urllib.error
    import urllib.request

    request = urllib.request.Request(
        API_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={**headers, "Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            body = response.read()
    except urllib.error.HTTPError as err:
        # An HTTP error status still carries a (JSON) body; return it like a success.
        body = err.read()
    return json.loads(body.decode("utf-8"))
def _generated_text(ai_response):
    """Extract the generated string from a text-generation Inference API reply.

    Raises:
        RuntimeError: if the reply is an error object or has an unexpected shape.
    """
    # Failures come back as a JSON object with an "error" key.
    if isinstance(ai_response, dict) and "error" in ai_response:
        raise RuntimeError(f"Inference API error: {ai_response['error']}")
    # Successful text-generation replies are a list of {"generated_text": ...} objects.
    if isinstance(ai_response, list) and ai_response:
        ai_response = ai_response[0]
    if isinstance(ai_response, dict) and isinstance(ai_response.get("generated_text"), str):
        return ai_response["generated_text"]
    raise RuntimeError(f"Unexpected Inference API response: {ai_response!r}")


def ai_predicted(user_input):
    """Answer a chat message by round-tripping it through the English Phi-3 endpoint.

    Steps: translate_Singlish_to_sinhala -> transliterate_to_sinhala ->
    translate_sinhala_to_english -> query (Phi-3) ->
    translate_english_to_sinhala -> transliterate_from_sinhala.

    Args:
        user_input: the raw chat message.

    Returns:
        The romanised reply. Returns "Goodbye!" when the message is "exit",
        compared case-insensitively and ignoring surrounding whitespace.

    Raises:
        RuntimeError: if the Inference API returns an error or an unexpected payload.
    """
    if user_input.strip().lower() == 'exit':
        return "Goodbye!"
    user_input = translate_Singlish_to_sinhala(user_input)
    user_input = transliterate_to_sinhala(user_input)
    user_input = translate_sinhala_to_english(user_input)
    ai_response = query({
        "inputs": user_input,
    })
    # query() returns decoded JSON (list/dict), not a string, so extract the text first.
    generated = _generated_text(ai_response)
    # Keep only the text after the last "</s>" marker.
    # NOTE(review): confirm this model actually emits "</s>", and whether the API
    # echoes the prompt (see its "return_full_text" parameter).
    reply = generated.split("</s>")[-1]
    response = translate_english_to_sinhala(reply)
    return transliterate_from_sinhala(response)
# Gradio Interface | |
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback: yield a single reply to *message*.

    The signature matches what ``gr.ChatInterface`` passes: the message, the
    chat history, then one value per ``additional_inputs`` widget (system
    message, max new tokens, temperature, top-p).

    Only *message* is used. ``history``, ``system_message``, ``max_tokens``,
    ``temperature`` and ``top_p`` are accepted for interface compatibility but
    currently have no effect, because the reply comes solely from
    ``ai_predicted(message)``. Thread them through to ``query`` if they should
    influence generation.

    Yields:
        The bot's reply, exactly once.
    """
    yield ai_predicted(message)
# Chat UI: respond() receives the message, the history, and then one value per
# widget below, in the order listed.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
)

if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the local server.
    demo.launch(share=True)