|
import gradio as gr |
|
import os |
|
from mtranslate import translate |
|
import requests |
|
|
|
# Hugging Face access token taken from the environment (None when unset).
HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")

# Endpoint of the remote IndoChat text-generation service.
indochat_api = "https://cahya-indonesian-whisperer.hf.space/api/text-generator/v1"

# Bearer token for the IndoChat endpoint; falls back to "" when unset.
indochat_api_auth_token = os.environ.get("INDOCHAT_API_AUTH_TOKEN", "")
|
|
|
def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    """Ask the IndoChat API for an answer and translate prompt + answer to English.

    The prompt and decoding settings are POSTed to ``indochat_api``; on success
    both the prompt and the generated answer are passed through ``translate``
    to produce the English view.

    Args:
        user_input: The user's question (Indonesian or English).
        decoding_method: Decoding strategy name, forwarded to the API as-is.
        num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha:
            Generation settings, forwarded to the API unchanged.

    Returns:
        A pair ``(original, translation)`` of HighlightedText values, each a
        list of ``(text, label)`` tuples. On any failure BOTH elements carry the
        error message, because the caller wires this function to two output
        components and a single return value would break the UI.
    """
    print(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha)

    max_length = 200
    request_timeout = 120  # seconds; without a timeout a stalled backend hangs the handler forever

    def _error_outputs(message):
        # One value per wired output component.
        output = [(message, None)]
        return output, output

    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    data = {
        "model_name": "indochat-tiny",
        "text": user_input,
        # Request a bit more than the prompt length, but never a min_length
        # above max_length (contradictory for prompts longer than 180 chars).
        "min_length": min(len(user_input) + 20, max_length),
        "max_length": max_length,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }

    try:
        r = requests.post(indochat_api, headers=headers, data=data, timeout=request_timeout)
    except requests.RequestException as err:
        return _error_outputs(f"Error: {err}")

    if r.status_code != 200:
        return _error_outputs("Error: " + r.text)

    try:
        answer = r.json()["generated_text"]
    except (ValueError, KeyError, TypeError) as err:
        # Non-JSON body, missing key, or a JSON value that is not an object.
        return _error_outputs(f"Error: unexpected API response ({err!r}): {r.text}")

    user_input_en = translate(user_input, "en", "id")
    answer_en = translate(answer, "en", "id")

    return [(f"{user_input}\n", None), (answer, "")], \
        [(f"{user_input_en}\n", None), (answer_en, "")]
|
|
|
|
|
# Page-level CSS handed to gr.Blocks. The two ids match the elem_id of the
# HighlightedText outputs ("answer_id" and "answer_en"): spans keep their
# newlines (white-space: pre-line), and spans with class "label" are hidden
# (presumably the per-segment category tags — confirm against the Gradio version).
css = """
#answer_id span {white-space: pre-line}
#answer_id span.label {display: none}
#answer_en span {white-space: pre-line}
#answer_en span.label {display: none}
"""
|
|
|
# Build the IndoChat UI: a settings column on the left, the generated answer
# and its English translation on the right.
# Uses the current component API (gr.Textbox/gr.Dropdown/gr.Slider with
# `value=`, `color_map=` in the HighlightedText constructor) instead of the
# deprecated `gr.inputs.*(default=...)` and `.style(...)`, which newer Gradio
# releases no longer provide.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## IndoChat

A proof of concept of a multilingual chatbot (in this case a bilingual one, English and Indonesian), fine-tuned with
a multilingual instruction dataset. The base model is GPT2-Medium (340M params), which was pretrained on 75GB
of Indonesian and English data, of which English makes up less than 1%.
""")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(placeholder="",
                                    label="Ask me something in Indonesian or English",
                                    value="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_method = gr.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                          value="Sampling", label="Decoding Method")
            num_beams = gr.Slider(label="Number of beams for beam search",
                                  value=1, minimum=1, maximum=10, step=1)
            top_k = gr.Slider(label="Top K",
                              value=30, maximum=50, minimum=1, step=1)
            top_p = gr.Slider(label="Top P", value=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.Slider(label="Temperature", value=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.Slider(label="Repetition Penalty", value=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.Slider(label="The penalty alpha for contrastive search",
                                      value=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # elem_ids are targeted by the page-level `css` passed to gr.Blocks.
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
                color_map={"": "blue", "-": "green"},
            )
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
                color_map={"": "blue", "-": "green"},
            )
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")

    # get_answer returns one value per output component (original, translation).
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

# No queue is configured; `launch(enable_queue=...)` is deprecated/removed in newer Gradio.
demo.launch()