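"""Gradio demo for IndoChat, a bilingual (English/Indonesian) chatbot.

The app sends the user's question to a remote text-generation API and shows
the generated answer alongside an English translation from mtranslate.
"""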
import gradio as gr
import os
from mtranslate import translate
import requests

# Tokens and the API endpoint come from the environment; INDOCHAT_API_AUTH_TOKEN
# authenticates requests to the remote text-generator service.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
indochat_api = 'https://cahya-indonesian-whisperer.hf.space/api/text-generator/v1'
indochat_api_auth_token = os.getenv("INDOCHAT_API_AUTH_TOKEN", "")


def get_answer(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha):
    print(user_input, decoding_method, num_beams, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
    headers = {'Authorization': 'Bearer ' + indochat_api_auth_token}
    # Form fields expected by the text-generator endpoint.
    data = {
        "model_name": "indochat-tiny",
        "text": user_input,
        "min_length": len(user_input) + 20,
        "max_length": 200,
        "decoding_method": decoding_method,
        "num_beams": num_beams,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "seed": -1,
        "repetition_penalty": repetition_penalty,
        "penalty_alpha": penalty_alpha
    }
    r = requests.post(indochat_api, headers=headers, data=data)
    if r.status_code == 200:
        result = r.json()
        answer = result["generated_text"]
        # mtranslate.translate(text, to_language, from_language): translate the
        # Indonesian question and answer into English for the second pane.
        user_input_en = translate(user_input, "en", "id")
        answer_en = translate(answer, "en", "id")
        # Each output is a list of (text, label) tuples for gr.HighlightedText.
        return [(f"{user_input}\n", None), (answer, "")], \
               [(f"{user_input_en}\n", None), (answer_en, "")]
    else:
        # Surface the API error in both output panes; returning a bare string
        # would not match the two outputs declared in the click handler.
        error = [("Error: " + r.text, None)]
        return error, error
css = """
#answer_id span {white-space: pre-line}
#answer_id span.label {display: none}
#answer_en span {white-space: pre-line}
#answer_en span.label {display: none}
"""

with gr.Blocks(css=css) as demo:
    with gr.Row():
        gr.Markdown("""## IndoChat
        A proof of concept of a multilingual chatbot (in this case bilingual: English and Indonesian),
        fine-tuned on a multilingual instruction dataset. The base model is GPT2-Medium (340M parameters),
        pretrained on 75GB of Indonesian and English text, of which English makes up less than 1%.
        """)
    with gr.Row():
        with gr.Column():
            # Default question in Indonesian: "How do you raise children so they don't lie?"
            user_input = gr.inputs.Textbox(placeholder="",
                                           label="Ask me something in Indonesian or English",
                                           default="Bagaimana cara mendidik anak supaya tidak berbohong?")
            decoding_method = gr.inputs.Dropdown(["Beam Search", "Sampling", "Contrastive Search"],
                                                 default="Sampling", label="Decoding Method")
            num_beams = gr.inputs.Slider(label="Number of beams for beam search",
                                         default=1, minimum=1, maximum=10, step=1)
            top_k = gr.inputs.Slider(label="Top K",
                                     default=30, maximum=50, minimum=1, step=1)
            top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
            temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
            repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
            penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
                                             default=0.5, step=0.05, minimum=0.05, maximum=1.0)
            with gr.Row():
                button_generate_story = gr.Button("Submit")
        with gr.Column():
            # HighlightedText renders (text, label) pairs; elem_id hooks into
            # the CSS above so line breaks survive and labels stay hidden.
            generated_answer = gr.HighlightedText(
                elem_id="answer_id",
                label="Generated Text",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
            generated_answer_en = gr.HighlightedText(
                elem_id="answer_en",
                label="Translation",
                combine_adjacent=True,
            ).style(color_map={"": "blue", "-": "green"})
    with gr.Row():
        gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cahya_indochat)")
    button_generate_story.click(get_answer,
                                inputs=[user_input, decoding_method, num_beams, top_k, top_p, temperature,
                                        repetition_penalty, penalty_alpha],
                                outputs=[generated_answer, generated_answer_en])

demo.launch(enable_queue=False)
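
# To run this demo locally (a sketch; the filename is assumed):
#   export INDOCHAT_API_AUTH_TOKEN=<token for the remote API>
#   python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.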