Spaces:
Sleeping
Sleeping
File size: 4,474 Bytes
d4efbb2 63e7eb5 437fb7f 63e7eb5 d4efbb2 c6dbfda 63e7eb5 ce0ea7d 35d002b 63e7eb5 d4efbb2 63e7eb5 d4efbb2 35d002b 1556875 63e7eb5 437fb7f d4efbb2 35d002b 1556875 c6dbfda ce0ea7d 1556875 35d002b ce0ea7d 437fb7f 17757ae ce0ea7d 17757ae c6dbfda ce0ea7d c6dbfda 35d002b 991fc8d ce0ea7d c6dbfda 1556875 ce0ea7d c6dbfda cb5c6f6 c446280 c6dbfda cb5c6f6 c6dbfda cb5c6f6 c6dbfda 437fb7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gdown
from git import Repo
import gradio as gr
from huggingface_hub import snapshot_download
import os
import penman
import sys
import time
import torch
from transformers import pipeline
# The pipeline modules imported below live in a separate GitHub repository.
# Clone it into the working directory unless the target path already exists,
# then put it on the import path.
_PIPELINE_REPO_URL = "https://github.com/AbdiHaryadi/amr-tst-indo.git"
_PIPELINE_REPO_DIR = "amr-tst-indo"

_repo_present = os.path.exists(_PIPELINE_REPO_DIR)
if not _repo_present:
    Repo.clone_from(_PIPELINE_REPO_URL, _PIPELINE_REPO_DIR)

sys.path.append("./amr-tst-indo")
from text_to_amr import TextToAMR
from style_detector import StyleDetector
from style_rewriting import StyleRewriting
from amr_to_text import AMRToTextWithTaufiqMethod
# Text -> AMR parser.
# The model snapshot is pulled from the Hugging Face Hub into the cloned
# repository's model directory; files matching the log/checkpoint patterns
# are left out of the download.
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"

_parser_repo_id = f"abdiharyadi/{amr_parsing_model_name}"
_parser_local_dir = f"./amr-tst-indo/AMRBART-id/models/{amr_parsing_model_name}"
_parser_skipped_files = ["*log*", "*checkpoint*"]

snapshot_download(
    repo_id=_parser_repo_id,
    local_dir=_parser_local_dir,
    ignore_patterns=_parser_skipped_files,
)

t2a = TextToAMR(model_name=amr_parsing_model_name)
# Style detector (used by run() to extract triplets and style words).
# Its checkpoint is hosted on Google Drive. Download it only when the file is
# missing: an unconditional gdown.download() fetches the whole checkpoint
# again on every start, unlike the repository clone, which is already skipped
# when its directory exists.
_style_detector_weights = "./model-best.pt"
if not os.path.exists(_style_detector_weights):
    gdown.download(
        "https://drive.google.com/uc?id=1J_6PbYsQ6Kl4Qfs1wBVwd52_r9uTpIxx",
        _style_detector_weights,
    )

sd = StyleDetector(
    config_path="./amr-tst-indo/indonesian-aste-generative/resources/exp-v2/exp-m0.yaml",
    model_path=_style_detector_weights,
)
# Text-classification pipeline handed to StyleRewriting below.
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"

_classifier_checkpoint = "abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset"
clf_pipeline = pipeline(
    "text-classification",
    model=_classifier_checkpoint,
    device=device_type,
)
# Style rewriting stage (source AMR -> target AMR).
# The fastText binary is hosted on Google Drive. Download it only when the
# file is missing: an unconditional gdown.download() fetches the whole file
# again on every start, unlike the repository clone, which is already skipped
# when its directory exists.
_fasttext_model_path = "./fasttext_skipgram_indo.bin"
if not os.path.exists(_fasttext_model_path):
    gdown.download(
        "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
        _fasttext_model_path,
    )

sr = StyleRewriting(
    clf_pipeline=clf_pipeline,
    fasttext_model_path=_fasttext_model_path,
    position_aware_concatenation=False,
    reset_sense_strategy=False,
    max_score_strategy=True,
    maximize_style_words_expansion=False,
)
# AMR -> text generator.
# Only files matching "*checkpoint-3*" are fetched from the Hub snapshot, and
# the generator is loaded from that checkpoint directory.
amr_gen_model_name = "taufiq-indo-amr-generation-gold-uncased"
model_path = f"./{amr_gen_model_name}"

_generator_repo_id = f"abdiharyadi/{amr_gen_model_name}"
snapshot_download(
    repo_id=_generator_repo_id,
    local_dir=model_path,
    allow_patterns=["*checkpoint-3*"],
)

_generator_checkpoint_dir = os.path.join(model_path, "checkpoint-3")
a2t = AMRToTextWithTaufiqMethod(
    model_path=_generator_checkpoint_dir,
    lowercase=True,
)
def run(text, source_style):
    """Run the style-transfer pipeline, yielding UI updates after each stage.

    This is a generator: every yield is a 5-tuple of strings, one per output
    panel, in the order (source AMR, triplets, style words, target AMR,
    result). Panels whose stage has not finished show a placeholder; a
    finished panel shows its content followed by the time elapsed since the
    pipeline started.

    Args:
        text: Input sentence to rewrite.
        source_style: Style label of the input, forwarded to the rewriter.

    Yields:
        The current contents of the five output panels.
    """
    processing, waiting = "(Memproses ...)", "(Menunggu ...)"
    panels = [processing] + [waiting] * 4

    def stamp():
        # Footer with the cumulative seconds since the pipeline started.
        return f"\n\n({time.time() - start_time:.2f} s)"

    def publish(index, content):
        # Fill panel `index` and mark the following panel, if any, as running.
        panels[index] = content
        if index + 1 < len(panels):
            panels[index + 1] = processing
        return tuple(panels)

    # Show the placeholders before any work starts.
    yield tuple(panels)

    start_time = time.time()

    # Stage 1: parse the text into an AMR graph.
    source_amr, *_others = t2a([text])
    source_amr.metadata = {}  # cleared before encoding the graph for display
    yield publish(0, penman.encode(source_amr) + stamp())

    # Stage 2: extract triplets from the text.
    triplets = sd.get_triplets(text)
    triplet_lines = [f"({t[0]}, {t[1]}, {t[2]})" for t in triplets]
    yield publish(1, "\n".join(triplet_lines) + stamp())

    # Stage 3: derive the style words from the triplets.
    style_words = sd.get_style_words_from_triplets(triplets)
    yield publish(2, ", ".join(style_words) + stamp())

    # Stage 4: rewrite the source graph into the target graph.
    target_amr = sr(text, source_amr, source_style, style_words)
    yield publish(3, penman.encode(target_amr) + stamp())

    # Stage 5: generate text from the target graph.
    result, *_others = a2t([target_amr])
    yield publish(4, result + stamp())
# Gradio UI: a text box and a source-style selector feed run(), which fills
# five text panels in the order of the tuples it yields.

# Radio options are (display label, value passed to run) pairs.
# NOTE(review): the values look like the classifier's label names - confirm.
_style_choices = [
    ("Positif (Positive)", "LABEL_1"),
    ("Negatif (Negative)", "LABEL_0"),
]

_input_widgets = [
    gr.Textbox(label="Teks (Text)"),
    gr.Radio(
        label="Gaya sumber (Source style)",
        choices=_style_choices,
        value="LABEL_1",
    ),
]

_output_labels = (
    "Graf AMR sumber (Source AMR graph)",
    "Triplet (Triplets)",
    "Kata bergaya (Style words)",
    "Graf AMR target (Target AMR graph)",
    "Hasil (Result)",
)
_output_widgets = [gr.Textbox(label=panel_label) for panel_label in _output_labels]

demo = gr.Interface(fn=run, inputs=_input_widgets, outputs=_output_widgets)
demo.launch()
|