Spaces:
Running
Running
abdiharyadi
committed on
Commit
•
35d002b
1
Parent(s):
ce0ea7d
feat: integrate StyleRewriting
Browse files
app.py
CHANGED
@@ -6,6 +6,8 @@ import os
|
|
6 |
import penman
|
7 |
import sys
|
8 |
import time
|
|
|
|
|
9 |
|
10 |
if not os.path.exists("amr-tst-indo"):
|
11 |
Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")
|
@@ -13,6 +15,7 @@ sys.path.append("./amr-tst-indo")
|
|
13 |
|
14 |
from text_to_amr import TextToAMR
|
15 |
from style_detector import StyleDetector
|
|
|
16 |
|
17 |
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
|
18 |
snapshot_download(
|
@@ -34,6 +37,25 @@ sd = StyleDetector(
|
|
34 |
model_path="./model-best.pt"
|
35 |
)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def run(text, source_style):
|
38 |
yield (
|
39 |
"(Memproses ...)",
|
@@ -46,9 +68,16 @@ def run(text, source_style):
|
|
46 |
start_time = time.time()
|
47 |
|
48 |
# source_amr, *_ = t2a([text])
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
source_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
|
53 |
yield (
|
54 |
source_amr_display,
|
@@ -85,7 +114,7 @@ def run(text, source_style):
|
|
85 |
"(Menunggu ...)",
|
86 |
)
|
87 |
|
88 |
-
target_amr =
|
89 |
target_amr_display = penman.encode(target_amr)
|
90 |
target_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
|
91 |
yield (
|
|
|
6 |
import penman
|
7 |
import sys
|
8 |
import time
|
9 |
+
import torch
|
10 |
+
from transformers import pipeline
|
11 |
|
12 |
if not os.path.exists("amr-tst-indo"):
|
13 |
Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")
|
|
|
15 |
|
16 |
from text_to_amr import TextToAMR
|
17 |
from style_detector import StyleDetector
|
18 |
+
from style_rewriting import StyleRewriting
|
19 |
|
20 |
amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
|
21 |
snapshot_download(
|
|
|
37 |
model_path="./model-best.pt"
|
38 |
)
|
39 |
|
40 |
+
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
clf_pipeline = pipeline(
|
42 |
+
"text-classification",
|
43 |
+
model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
|
44 |
+
device=device_type
|
45 |
+
)
|
46 |
+
gdown.download(
|
47 |
+
"https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
|
48 |
+
"./fasttext_skipgram_indo.bin"
|
49 |
+
)
|
50 |
+
sr = StyleRewriting(
|
51 |
+
clf_pipeline=clf_pipeline,
|
52 |
+
fasttext_model_path="./fasttext_skipgram_indo.bin",
|
53 |
+
position_aware_concatenation=False,
|
54 |
+
reset_sense_strategy=False,
|
55 |
+
max_score_strategy=True,
|
56 |
+
maximize_style_words_expansion=False
|
57 |
+
)
|
58 |
+
|
59 |
def run(text, source_style):
|
60 |
yield (
|
61 |
"(Memproses ...)",
|
|
|
68 |
start_time = time.time()
|
69 |
|
70 |
# source_amr, *_ = t2a([text])
|
71 |
+
source_amr = penman.decode("""
|
72 |
+
(z0 / dan
|
73 |
+
:op1 (z1 / bagus-01
|
74 |
+
:ARG1 (z2 / tempat)
|
75 |
+
:degree (z3 / sangat))
|
76 |
+
:op2 (z4 / bersih-01
|
77 |
+
:ARG1 z2))
|
78 |
+
""")
|
79 |
+
source_amr.metadata = {}
|
80 |
+
source_amr_display = penman.encode(source_amr)
|
81 |
source_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
|
82 |
yield (
|
83 |
source_amr_display,
|
|
|
114 |
"(Menunggu ...)",
|
115 |
)
|
116 |
|
117 |
+
target_amr = sr(text, source_amr, source_style, style_words)
|
118 |
target_amr_display = penman.encode(target_amr)
|
119 |
target_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
|
120 |
yield (
|