Chitti-v1 / app.py
Dhahlan2000's picture
Update app.py
5b1ccca verified
raw
history blame
4.54 kB
import json
import os
import urllib.error
import urllib.request

import gradio as gr
import torch
from aksharamukha import transliterate
from dotenv import load_dotenv
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
load_dotenv()
# Hugging Face API token from the environment (or a .env file); None when unset.
access_token = os.getenv('ACCESS_TOKEN')
# Set up device: every model below is placed on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load translation models and tokenizers
# English -> Sinhala (NLLB-200), wrapped in a translation pipeline.
trans_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M").to(device)
eng_trans_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
translator = pipeline(
    'translation',
    model=trans_model,
    tokenizer=eng_trans_tokenizer,
    src_lang="eng_Latn",
    tgt_lang='sin_Sinh',
    max_length=400,
    device=device,
)
# Sinhala -> English (mT5 fine-tune); used directly via tokenizer + generate().
sin_trans_model = AutoModelForSeq2SeqLM.from_pretrained("thilina/mt5-sinhalese-english").to(device)
si_trans_tokenizer = AutoTokenizer.from_pretrained("thilina/mt5-sinhalese-english", use_fast=False)
# Singlish -> Sinhala text2text model. Pass the same device as the models above;
# without an explicit device the pipeline defaults to the CPU even when CUDA
# is available.
singlish_pipe = pipeline(
    "text2text-generation",
    model="Dhahlan2000/Simple_Translation-model-for-GPT-v14",
    device=device,
)
# Translation functions
def translate_Singlish_to_sinhala(text):
    """Run *text* through the Singlish translation model and return its output.

    The model is prompted with a fixed task prefix, and the first candidate's
    ``generated_text`` is returned as-is. NOTE(review): ai_predicted() passes
    this output to transliterate_to_sinhala() (Velthuis -> Sinhala script), so
    the model presumably emits a Velthuis romanization rather than Sinhala
    script. Confirm this.
    """
    prompt = f"translate Singlish to Sinhala: {text}"
    candidates = singlish_pipe(prompt, clean_up_tokenization_spaces=False)
    best = candidates[0]
    return best['generated_text']
def translate_english_to_sinhala(text):
    """Translate English *text* to Sinhala line by line with the NLLB pipeline.

    Each newline-separated line is translated on its own, so the line
    structure of the input is preserved in the output. Blank (whitespace-only)
    lines are emitted as empty lines rather than sent to the model, because
    generating from empty input is meaningless and can yield hallucinated text.

    Returns the translated lines joined with "\\n".
    """
    translated_parts = []
    for part in text.split("\n"):
        if not part.strip():
            # Nothing to translate: keep the line break without invoking the model.
            translated_parts.append("")
            continue
        result = translator(part, clean_up_tokenization_spaces=False)
        translated_parts.append(result[0]['translation_text'])
    # Strip a specific Sinhala phrase from the output. NOTE(review): presumably
    # a recurring spurious phrase produced by the model. Confirm this.
    return "\n".join(translated_parts).replace("ප් රභූවරුන්", "")
def translate_sinhala_to_english(text, max_new_tokens=400):
    """Translate Sinhala-script *text* to English line by line with the mT5 model.

    Args:
        text: Input text. Each newline-separated line is translated
            independently (after stripping surrounding whitespace), and input
            lines longer than 512 tokens are truncated by the tokenizer.
        max_new_tokens: Upper bound on generated tokens per line. Without an
            explicit limit, generate() falls back to the model's generation
            config, whose library default is max_length=20. That silently cuts
            off longer translations. The default of 400 mirrors the English ->
            Sinhala translator's max_length.

    Returns:
        The translated lines joined with "\\n". Blank input lines come back as
        empty lines without invoking the model.
    """
    translated_parts = []
    for part in text.split("\n"):
        line = part.strip()
        if not line:
            # Empty input gives the model nothing to translate; keep the line break.
            translated_parts.append("")
            continue
        inputs = si_trans_tokenizer(
            line, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(device)
        outputs = sin_trans_model.generate(**inputs, max_new_tokens=max_new_tokens)
        translated_parts.append(
            si_trans_tokenizer.decode(
                outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
        )
    return "\n".join(translated_parts)
def transliterate_from_sinhala(text):
    """Romanize Sinhala-script *text* to lowercase Velthuis via Aksharamukha.

    After transliteration every '.', '*' and '"' character is deleted (this
    also removes ordinary full stops from the text) and the result is
    lowercased.
    """
    velthuis = transliterate.process('Sinhala', 'Velthuis', text)
    # One pass deleting all three characters; same result as chained replace().
    without_marks = velthuis.translate(str.maketrans('', '', '.*"'))
    return without_marks.lower()
def transliterate_to_sinhala(text):
    """Convert Velthuis-romanized *text* into Sinhala script via Aksharamukha."""
    source_scheme, target_script = 'Velthuis', 'Sinhala'
    sinhala = transliterate.process(source_scheme, target_script, text)
    return sinhala
# Placeholder for conversation model loading and pipeline setup
# pipe1 = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
# interface = gr.Interface.load("huggingface/microsoft/Phi-3-mini-4k-instruct")
API_URL = "https://api-inference.huggingface.co/models/microsoft/Phi-3-mini-4k-instruct"
# Only send an Authorization header when a token is configured. Otherwise the
# f-string would send the literal "Bearer None" as an (invalid) credential.
headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
def query(payload, timeout=60):
    """POST *payload* as JSON to ``API_URL`` and return the decoded JSON reply.

    The original implementation called ``requests``, which this file never
    imports, so every call failed with NameError. This version uses only the
    standard library.

    Args:
        payload: JSON-serializable request body (e.g. ``{"inputs": "..."}``).
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The parsed JSON response. On an HTTP error status, the JSON error body
        the server sent is returned, matching what ``requests``'
        ``response.json()`` would have produced. If that body is not JSON, a
        dict ``{"error": "HTTP <code>: <body or reason>"}`` is returned.

    Raises:
        urllib.error.URLError: on network failures (DNS, refused connection,
            timeout).
        json.JSONDecodeError: if a successful response body is not JSON.
    """
    body = json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        API_URL,
        data=body,
        # urllib would otherwise label the body as form-encoded.
        headers={**headers, "Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as err:
        raw = err.read().decode("utf-8", errors="replace")
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {"error": f"HTTP {err.code}: {raw or err.reason}"}
# def conversation_predict(text):
# return interface([text])[0]
def ai_predicted(user_input):
if user_input.lower() == 'exit':
return "Goodbye!"
user_input = translate_Singlish_to_sinhala(user_input)
user_input = transliterate_to_sinhala(user_input)
user_input = translate_sinhala_to_english(user_input)
ai_response = query({
"inputs": user_input,
})
# ai_response = conversation_predict(user_input)
ai_response_lines = ai_response.split("</s>")
response = translate_english_to_sinhala(ai_response_lines[-1])
response = transliterate_from_sinhala(response)
return response
# Gradio Interface
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback: yield the bot's reply to *message*.

    Only *message* is used. *history*, *system_message*, *max_tokens*,
    *temperature* and *top_p* are accepted because ChatInterface passes them
    (the last four come from ``additional_inputs``), but ai_predicted() takes
    none of them.

    The previous body also built an unused ``messages`` list from *history*.
    That dead code was removed; indexing ``val[0]``/``val[1]`` would also raise
    if Gradio supplied history as role/content dicts.
    """
    # TODO: forward history/system_message/sampling settings once ai_predicted()
    # can accept them.
    yield ai_predicted(message)
# Extra controls rendered beneath the chat box. ChatInterface passes their
# current values to respond() after (message, history), in this order.
_chat_controls = [
    gr.Textbox(label="System message", value="You are a friendly Chatbot."),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.1,
        maximum=1.0,
        step=0.05,
        value=0.95,
    ),
]
demo = gr.ChatInterface(respond, additional_inputs=_chat_controls)

if __name__ == "__main__":
    # Start the app; share=True asks Gradio for a public share link.
    launch_options = {"share": True}
    demo.launch(**launch_options)