|
import os |
|
import spaces |
|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
# Access token for the (possibly gated) KazLLM checkpoints on the Hub.
# Read from the Space's secret store; None if the secret is not configured.
HF_TOKEN = os.getenv("HF_TOKEN_KAZLLM")

# Registry of selectable checkpoints. Each entry carries the Hub repo ids,
# a suggested GPU duration (seconds) and the default generation settings
# that the UI sliders are reset to when the model is selected.
MODELS = {
    "V-1: LLama-3.1-KazLLM-8B": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "duration": 120,
        "defaults": {
            "max_length": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True
        }
    },
    "V-2: LLama-3.1-KazLLM-70B-AWQ4": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "duration": 180,
        "defaults": {
            "max_length": 150,
            "temperature": 0.8,
            "top_p": 0.95,
            "do_sample": True
        }
    }
}

# UI strings per interface language (Russian / Kazakh). The values are
# user-facing runtime data consumed by update_language(); keys must match
# the lookups done there.
LANGUAGES = {
    "Русский": {
        "title": "LLama-3.1 KazLLM с выбором модели и языка",
        "description": "Выберите модель, язык интерфейса, введите запрос и получите сгенерированный текст с использованием выбранной модели LLama-3.1 KazLLM.",
        "select_model": "Выберите модель",
        "enter_prompt": "Введите запрос",
        "max_length": "Максимальная длина текста",
        "temperature": "Креативность (Температура)",
        "top_p": "Top-p (ядро вероятности)",
        "do_sample": "Использовать выборку (Do Sample)",
        "generate_button": "Сгенерировать текст",
        "generated_text": "Сгенерированный текст",
        "language": "Выберите язык интерфейса"
    },
    "Қазақша": {
        "title": "LLama-3.1 KazLLM модель таңдауы және тілін қолдау",
        "description": "Модельді, интерфейс тілін таңдаңыз, сұрауыңызды енгізіңіз және таңдалған LLama-3.1 KazLLM моделін пайдаланып генерирленген мәтінді алыңыз.",
        "select_model": "Модельді таңдаңыз",
        "enter_prompt": "Сұрауыңызды енгізіңіз",
        "max_length": "Мәтіннің максималды ұзындығы",
        "temperature": "Шығармашылық (Температура)",
        "top_p": "Top-p (ықтималдық негізі)",
        "do_sample": "Үлгіні қолдану (Do Sample)",
        "generate_button": "Мәтінді генерациялау",
        "generated_text": "Генерацияланған мәтін",
        "language": "Интерфейс тілін таңдаңыз"
    }
}

# Process-wide caches keyed by MODELS key, filled lazily by
# load_model_and_tokenizer() so each checkpoint is downloaded at most once.
loaded_models = {}

loaded_tokenizers = {}
|
|
|
|
|
@spaces.GPU(duration=60)
def load_model_and_tokenizer(model_key):
    """Lazily load and cache the model/tokenizer pair for ``model_key``.

    Populates the module-level ``loaded_models`` and ``loaded_tokenizers``
    caches on first use; later calls for the same key return immediately.

    Args:
        model_key: A key of the ``MODELS`` registry.
    """
    # Guard on BOTH caches so a previously interrupted load is retried.
    if model_key in loaded_models and model_key in loaded_tokenizers:
        return

    model_info = MODELS[model_key]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = AutoModelForCausalLM.from_pretrained(
        model_info["model_name"],
        token=HF_TOKEN
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(
        model_info["tokenizer_name"],
        use_fast=True,
        token=HF_TOKEN
    )
    # Llama tokenizers ship without a pad token; reuse EOS so padded
    # batches in generate_text() work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Publish to the caches only after BOTH loads succeeded. The original
    # cached the model before loading the tokenizer, so a tokenizer
    # download failure left the cache half-populated and every subsequent
    # call raised KeyError in generate_text() instead of retrying.
    loaded_models[model_key] = model
    loaded_tokenizers[model_key] = tokenizer
|
|
|
|
|
@spaces.GPU(duration=120)
def generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """Generate text from ``prompt`` with the selected model.

    Args:
        model_choice: Key of the ``MODELS`` registry.
        prompt: User prompt text.
        max_length: Upper bound on total tokens (prompt + generated).
        temperature: Sampling temperature; used only when ``do_sample``.
        top_p: Nucleus-sampling mass; used only when ``do_sample``.
        do_sample: Sample stochastically when True, greedy-decode otherwise.

    Returns:
        The decoded output string (includes the prompt, since the full
        ``outputs[0]`` sequence is decoded).
    """
    load_model_and_tokenizer(model_choice)

    model = loaded_models[model_choice]
    tokenizer = loaded_tokenizers[model_choice]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_length": max_length,
        # Discourage degenerate repetition in the output.
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 2,
        "do_sample": do_sample,
    }

    # temperature and top_p only affect sampling. The original passed
    # temperature unconditionally, which makes transformers emit a
    # "temperature is ignored" warning under greedy decoding; gate both
    # on do_sample, consistently with how top_p was already handled.
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    outputs = model.generate(**generation_kwargs)

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return generated_text
|
|
|
|
|
def update_settings(model_choice):
    """Reset the four generation controls to the chosen model's defaults.

    Returns one ``gr.update`` per control, in the wiring order:
    max_length, temperature, top_p, do_sample.
    """
    defaults = MODELS[model_choice]["defaults"]
    control_keys = ("max_length", "temperature", "top_p", "do_sample")
    return tuple(gr.update(value=defaults[key]) for key in control_keys)
|
|
|
|
|
def update_language(selected_language):
    """Relabel/retext every widget for the chosen interface language.

    Returns updates in the exact order the language_dropdown.change
    outputs are wired: title, description, model dropdown, prompt box,
    the three sliders, the checkbox, the button, and the output box.
    """
    strings = LANGUAGES[selected_language]

    updates = [gr.update(value=strings[key]) for key in ("title", "description")]
    relabeled = (
        "select_model",
        "enter_prompt",
        "max_length",
        "temperature",
        "top_p",
        "do_sample",
    )
    updates.extend(gr.update(label=strings[key]) for key in relabeled)
    updates.append(gr.update(value=strings["generate_button"]))
    updates.append(gr.update(label=strings["generated_text"]))
    return tuple(updates)
|
|
|
|
|
@spaces.GPU(duration=120)
def wrapped_generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """Click-handler wrapper: forward all UI values to generate_text unchanged."""
    call_args = (model_choice, prompt, max_length, temperature, top_p, do_sample)
    return generate_text(*call_args)
|
|
|
|
|
# --- UI construction -------------------------------------------------------
# Widgets default to Russian labels and the 70B model's generation defaults;
# the .change handlers below keep labels/values in sync afterwards.
with gr.Blocks() as iface:
    with gr.Row():
        # Interface-language selector (Russian / Kazakh).
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            value="Русский",
            label=LANGUAGES["Русский"]["language"]
        )

    title = gr.Markdown(LANGUAGES["Русский"]["title"])
    description = gr.Markdown(LANGUAGES["Русский"]["description"])

    with gr.Row():
        # Checkpoint selector; defaults to the quantized 70B variant.
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="V-2: LLama-3.1-KazLLM-70B-AWQ4",
            label=LANGUAGES["Русский"]["select_model"]
        )

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=4,
            placeholder="Введите ваш запрос здесь...",
            label=LANGUAGES["Русский"]["enter_prompt"]
        )

    # Generation controls, seeded with the default model's defaults.
    with gr.Row():
        max_length_slider = gr.Slider(
            minimum=50,
            maximum=1000,
            step=10,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["max_length"],
            label=LANGUAGES["Русский"]["max_length"]
        )
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["temperature"],
            label=LANGUAGES["Русский"]["temperature"]
        )

    with gr.Row():
        top_p_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["top_p"],
            label=LANGUAGES["Русский"]["top_p"]
        )
        do_sample_checkbox = gr.Checkbox(
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["do_sample"],
            label=LANGUAGES["Русский"]["do_sample"]
        )

    generate_button = gr.Button(LANGUAGES["Русский"]["generate_button"])

    output_text = gr.Textbox(
        label=LANGUAGES["Русский"]["generated_text"],
        lines=10
    )

    # Switching the model resets the four generation controls to its defaults.
    model_dropdown.change(
        fn=update_settings,
        inputs=[model_dropdown],
        outputs=[max_length_slider, temperature_slider, top_p_slider, do_sample_checkbox]
    )

    # Switching the language rewrites labels/texts; the output order here
    # must match the tuple order returned by update_language().
    language_dropdown.change(
        fn=update_language,
        inputs=[language_dropdown],
        outputs=[title, description, model_dropdown, prompt_input, max_length_slider, temperature_slider, top_p_slider,
                 do_sample_checkbox, generate_button, output_text]
    )

    # top_p only matters when sampling, so hide its slider otherwise.
    do_sample_checkbox.change(
        fn=lambda do_sample: gr.update(visible=do_sample),
        inputs=[do_sample_checkbox],
        outputs=[top_p_slider]
    )

    generate_button.click(
        fn=wrapped_generate_text,
        inputs=[
            model_dropdown,
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox
        ],
        outputs=output_text
    )

if __name__ == "__main__":
    iface.launch()
|
|