from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gradio as gr
import re

# Config
MAX_LENGTH = 1024
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using:", device)

tokenizer = AutoTokenizer.from_pretrained("MarkelFe/PoliticalSpeech2", padding_side='left')
model = AutoModelForCausalLM.from_pretrained("MarkelFe/PoliticalSpeech2").to(device)

# Safeguard for batched generation: if the tokenizer has no pad token, reuse the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def return_conf(max_tokens, conf, ngram, beams, top_k, top_p):
    """Map the selected decoding strategy to `model.generate` keyword arguments."""
    if conf == "Ezer":  # "Ezer" = none: plain greedy decoding
        options = {"max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Beam Search":
        options = {"no_repeat_ngram_size": ngram, "num_beams": beams,
                   "max_new_tokens": max_tokens, "do_sample": False}
    elif conf == "Top K":
        # top_k only takes effect when sampling is enabled
        options = {"top_k": top_k, "max_new_tokens": max_tokens, "do_sample": True}
    elif conf == "Top P":
        # top_p (nucleus sampling) likewise requires do_sample=True
        options = {"top_p": top_p, "max_new_tokens": max_tokens, "do_sample": True}
    else:
        # Fallback: greedy decoding for any unexpected strategy value
        options = {"max_new_tokens": max_tokens, "do_sample": False}
    return options


def sortu_testua(alderdia, testua, max_tokens, conf, ngram, beams, top_k, top_p):
    """Generate a speech continuation for a single political party."""
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    prompt = f"[{alderdia}] {testua}"
    tokens = tokenizer(prompt, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'],
                                attention_mask=tokens['attention_mask'],
                                **options)[0]
    text = tokenizer.decode(generation)
    # Strip the "[party] " prefix from the decoded output
    return re.split(r"\[(.*?)\] ", text)[-1]


def sortu_testu_guztiak(testua, max_tokens, conf, ngram, beams, top_k, top_p):
    """Generate speech continuations for all six parties in a single batch."""
    options = return_conf(max_tokens, conf, ngram, beams, top_k, top_p)
    prompts = [f"[\"EAJ\"] {testua}", f"[\"EH Bildu\"] {testua}", f"[\"PP\"] {testua}",
               f"[\"PSE-EE\"] {testua}", f"[\"EP\"] {testua}", f"[\"UPyD\"] {testua}"]
    tokens = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
    generation = model.generate(inputs=tokens['input_ids'],
                                attention_mask=tokens['attention_mask'],
                                **options)
    texts = tokenizer.batch_decode(generation)
    texts = [re.split(r"\[(.*?)\] ", text)[-1] for text in texts]
    return (texts[0], texts[1], texts[2], texts[3], texts[4], texts[5])


with gr.Blocks() as demo:
    # Tab 1: generate text for a single selected party ("Alderdi bakarra" = single party)
    with gr.Tab("Alderdi bakarra"):
        with gr.Row():
            with gr.Column(scale=4, min_width=400):
                alderdia = gr.Dropdown(["EAJ", "EH Bildu", "PP", "PSE-EE", "EP", "UPyD"],
                                       label="Alderdi politikoa")
                testua = gr.Textbox(label="Testua")
                greet_btn = gr.Button("Sortu testua")
                gr.Markdown("""Aldatu konfigurazioa""")
                new_token = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera",
                                      info="Zenbat token berri sortuko diren.")
                confi = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search",
                                 label="Estrategia",
                                 info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko")
                ngram = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua",
                                  info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                beams = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua",
                                  info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa",
                                  info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa",
                                  info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago")
            with gr.Column(scale=3, min_width=200):
                output = gr.Textbox(label="Output")

    # Tab 2: generate text for every party at once ("Alderdi guztiak" = all parties)
    with gr.Tab("Alderdi guztiak"):
        with gr.Row():
            with gr.Column(scale=4, min_width=400):
                testua2 = gr.Textbox(label="Testua")
                greet_btn2 = gr.Button("Sortu testuak")
                gr.Markdown("""Aldatu konfigurazioa""")
                new_token2 = gr.Slider(minimum=1, maximum=MAX_LENGTH, value=30, label="Luzera",
                                       info="Zenbat token berri sortuko diren.")
                confi2 = gr.Radio(["Ezer", "Beam Search", "Top K", "Top P"], value="Beam Search",
                                  label="Estrategia",
                                  info="Aukeratu ze estrategia erabiliko den erantzunak hobetzeko")
                ngram2 = gr.Slider(minimum=1, maximum=50, value=4, step=1, label="ngram kopurua",
                                   info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                beams2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Beam kopurua",
                                   info="Bakarrik kontuan hartuko da \"Beam Search\" aukeratuta badago")
                top_k2 = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="K-balioa",
                                   info="Bakarrik kontuan hartuko da \"Top K\" aukeratuta badago")
                top_p2 = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="P-balioa",
                                   info="Bakarrik kontuan hartuko da \"Top P\" aukeratuta badago")
            with gr.Column(scale=3, min_width=200):
                outputEAJ = gr.Textbox(label="EAJ")
                outputBildu = gr.Textbox(label="EH Bildu")
                outputPP = gr.Textbox(label="PP")
                outputPSE = gr.Textbox(label="PSE-EE")
                outputEP = gr.Textbox(label="EP")
                outputUPyD = gr.Textbox(label="UPyD")

    # Wire the buttons to the generation functions
    greet_btn.click(fn=sortu_testua,
                    inputs=[alderdia, testua, new_token, confi, ngram, beams, top_k, top_p],
                    outputs=output, api_name="sortu_testua")
    greet_btn2.click(fn=sortu_testu_guztiak,
                     inputs=[testua2, new_token2, confi2, ngram2, beams2, top_k2, top_p2],
                     outputs=[outputEAJ, outputBildu, outputPP, outputPSE, outputEP, outputUPyD],
                     api_name="sortu_testu_guztiak")

demo.launch()