"""Gradio demo for the Indonesian AMR-based text style transfer pipeline.

The UI streams the output of each pipeline stage as it finishes:
source AMR graph (TextToAMR), triplets and style words (StyleDetector),
rewritten target AMR graph (StyleRewriting), and the generated text
(AMRToTextWithTaufiqMethod).
"""
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys
import time
import torch
from transformers import pipeline

# The pipeline code lives in a separate repository; clone it only once.
if not os.path.exists("amr-tst-indo"):
    Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")

sys.path.append("./amr-tst-indo")

# These modules come from the cloned repository, so they can only be imported
# after it has been appended to sys.path above.
from text_to_amr import TextToAMR  # noqa: E402
from style_detector import StyleDetector  # noqa: E402
from style_rewriting import StyleRewriting  # noqa: E402
from amr_to_text import AMRToTextWithTaufiqMethod  # noqa: E402


def _download_once(url, output_path):
    """Download ``url`` with gdown to ``output_path`` unless that path already exists.

    Mirrors the clone guard above so that restarting the app does not fetch
    the same model files again.

    Returns:
        ``output_path``.
    """
    if not os.path.exists(output_path):
        gdown.download(url, output_path)
    return output_path


# --- Text -> AMR parser -----------------------------------------------------
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
snapshot_download(
    repo_id=f"abdiharyadi/{amr_parsing_model_name}",
    local_dir=f"./amr-tst-indo/AMRBART-id/models/{amr_parsing_model_name}",
    ignore_patterns=[
        "*log*",
        "*checkpoint*",
    ],
)
t2a = TextToAMR(model_name=amr_parsing_model_name)

# --- Style detector (triplets + style words) --------------------------------
_download_once(
    "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
    "./model-best.pt",
)
sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path="./model-best.pt",
)

# --- Style rewriting (classifier + fastText embeddings) ---------------------
device_type = "cuda" if torch.cuda.is_available() else "cpu"
clf_pipeline = pipeline(
    "text-classification",
    model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
    device=device_type,
)

_download_once(
    "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
    "./fasttext_skipgram_indo.bin",
)
# The boolean flags pick a StyleRewriting configuration; their exact meaning is
# defined in the amr-tst-indo repository (not visible here).
sr = StyleRewriting(
    clf_pipeline=clf_pipeline,
    fasttext_model_path="./fasttext_skipgram_indo.bin",
    position_aware_concatenation=False,
    reset_sense_strategy=False,
    max_score_strategy=True,
    maximize_style_words_expansion=False,
)

# --- AMR -> text generator --------------------------------------------------
amr_gen_model_name = "taufiq-indo-amr-generation-gold-uncased"
model_path = f"./{amr_gen_model_name}"
snapshot_download(
    repo_id=f"abdiharyadi/{amr_gen_model_name}",
    local_dir=model_path,
    allow_patterns=[
        "*checkpoint-3*",
    ],
)
a2t = AMRToTextWithTaufiqMethod(
    model_path=os.path.join(model_path, "checkpoint-3"),
    lowercase=True,
)

# Status placeholders shown in output boxes that have not produced a result yet.
_PROCESSING = "(Memproses ...)"
_WAITING = "(Menunggu ...)"
_NUM_OUTPUTS = 5  # source AMR, triplets, style words, target AMR, result


def _progress(*done):
    """Build one UI update: finished outputs, then the running step, then waiting steps.

    With fewer than ``_NUM_OUTPUTS`` finished outputs, the next slot shows
    ``_PROCESSING`` and every later slot shows ``_WAITING``. Once all outputs
    are finished, ``done`` is returned unchanged.
    """
    remaining = _NUM_OUTPUTS - len(done)
    if remaining <= 0:
        return done
    return done + (_PROCESSING,) + (_WAITING,) * (remaining - 1)


def _with_elapsed(display, start_time):
    """Append the seconds elapsed since ``start_time`` (a perf_counter value)."""
    return f"{display}\n\n({time.perf_counter() - start_time:.2f} s)"


def run(text, source_style):
    """Run the style-transfer pipeline on ``text``, streaming progress to the UI.

    Args:
        text: Input text from the textbox.
        source_style: Style label of the input, as selected in the radio
            control ("LABEL_1" = positive, "LABEL_0" = negative).

    Yields:
        5-tuples of strings for (source AMR, triplets, style words, target AMR,
        result). Each finished stage shows its output followed by the time
        elapsed since the pipeline started. The running stage shows
        "(Memproses ...)" and the stages after it show "(Menunggu ...)".

    Raises:
        gr.Error: If ``text`` is empty or contains only whitespace.
    """
    if not text or not text.strip():
        raise gr.Error("Teks tidak boleh kosong (Text must not be empty).")

    yield _progress()
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    start_time = time.perf_counter()

    source_amr, *_ = t2a([text])
    # Clear the graph metadata so the display shows only the encoded graph.
    source_amr.metadata = {}
    source_amr_display = _with_elapsed(penman.encode(source_amr), start_time)
    yield _progress(source_amr_display)

    triplets = sd.get_triplets(text)
    triplets_display = _with_elapsed(
        "\n".join(f"({x[0]}, {x[1]}, {x[2]})" for x in triplets),
        start_time,
    )
    yield _progress(source_amr_display, triplets_display)

    style_words = sd.get_style_words_from_triplets(triplets)
    style_words_display = _with_elapsed(", ".join(style_words), start_time)
    yield _progress(source_amr_display, triplets_display, style_words_display)

    target_amr = sr(text, source_amr, source_style, style_words)
    target_amr_display = _with_elapsed(penman.encode(target_amr), start_time)
    yield _progress(
        source_amr_display,
        triplets_display,
        style_words_display,
        target_amr_display,
    )

    result, *_ = a2t([target_amr])
    result = _with_elapsed(result, start_time)
    yield _progress(
        source_amr_display,
        triplets_display,
        style_words_display,
        target_amr_display,
        result,
    )


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_textbox = gr.Textbox(label="Teks (Text)")
            style_choices = gr.Radio(
                label="Gaya sumber (Source style)",
                choices=[
                    ("Positif (Positive)", "LABEL_1"),
                    ("Negatif (Negative)", "LABEL_0"),
                ],
                value="LABEL_1",
            )
            submit_btn = gr.Button("Submit")

        with gr.Column():
            with gr.Row():
                src_amr_graph_output = gr.Textbox(
                    label="Graf AMR sumber (Source AMR graph)",
                    min_width=320,
                )
                triplets_output = gr.Textbox(
                    label="Triplet (Triplets)",
                    min_width=320,
                )
            with gr.Row():
                style_words_output = gr.Textbox(
                    label="Kata bergaya (Style words)",
                    min_width=320,
                )
                tgt_amr_graph_output = gr.Textbox(
                    label="Graf AMR target (Target AMR graph)",
                    min_width=320,
                )
            result_output = gr.Textbox(label="Hasil (Result)")

        with gr.Column():
            # Kept flush-left: indented lines would render as a Markdown code block.
            gr.Markdown("""
# Pengakuan

Demo ini disiapkan untuk Program Penelitian dan Pengabdian Masyarakat STEI ITB 2024.

**Tim Peneliti**:
- Masayu Leylia Khodra (masayu@staff.stei.itb.ac.id)
- M. Abdi Haryadi. H (abdiharyadi.ah@gmail.com)
""")

    # Output order must match the tuple order yielded by run().
    submit_btn.click(
        run,
        [input_textbox, style_choices],
        [
            src_amr_graph_output,
            triplets_output,
            style_words_output,
            tgt_amr_graph_output,
            result_output,
        ],
    )

demo.launch()