abdiharyadi committed on
Commit
35d002b
•
1 Parent(s): ce0ea7d

feat: integrate StyleRewriting

Browse files
Files changed (1) hide show
  1. app.py +33 -4
app.py CHANGED
@@ -6,6 +6,8 @@ import os
6
  import penman
7
  import sys
8
  import time
 
 
9
 
10
  if not os.path.exists("amr-tst-indo"):
11
  Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")
@@ -13,6 +15,7 @@ sys.path.append("./amr-tst-indo")
13
 
14
  from text_to_amr import TextToAMR
15
  from style_detector import StyleDetector
 
16
 
17
  amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
18
  snapshot_download(
@@ -34,6 +37,25 @@ sd = StyleDetector(
34
  model_path="./model-best.pt"
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def run(text, source_style):
38
  yield (
39
  "(Memproses ...)",
@@ -46,9 +68,16 @@ def run(text, source_style):
46
  start_time = time.time()
47
 
48
  # source_amr, *_ = t2a([text])
49
- # source_amr.metadata = {}
50
- # source_amr_display = penman.encode(source_amr)
51
- source_amr_display = "(z0 / halo)"
 
 
 
 
 
 
 
52
  source_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
53
  yield (
54
  source_amr_display,
@@ -85,7 +114,7 @@ def run(text, source_style):
85
  "(Menunggu ...)",
86
  )
87
 
88
- target_amr = penman.decode("(z0 / dunia)")
89
  target_amr_display = penman.encode(target_amr)
90
  target_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
91
  yield (
 
6
  import penman
7
  import sys
8
  import time
9
+ import torch
10
+ from transformers import pipeline
11
 
12
  if not os.path.exists("amr-tst-indo"):
13
  Repo.clone_from("https://github.com/AbdiHaryadi/amr-tst-indo.git", "amr-tst-indo")
 
15
 
16
  from text_to_amr import TextToAMR
17
  from style_detector import StyleDetector
18
+ from style_rewriting import StyleRewriting
19
 
20
  amr_parsing_model_name = "mbart-en-id-smaller-indo-amr-parsing-translated-nafkhan"
21
  snapshot_download(
 
37
  model_path="./model-best.pt"
38
  )
39
 
40
+ device_type = "cuda" if torch.cuda.is_available() else "cpu"
41
+ clf_pipeline = pipeline(
42
+ "text-classification",
43
+ model="abdiharyadi/roberta-base-indonesian-522M-with-sa-william-dataset",
44
+ device=device_type
45
+ )
46
+ gdown.download(
47
+ "https://drive.google.com/uc?id=15KctCcsHgTFMUh_tWNBNUiCyX56fq6p-",
48
+ "./fasttext_skipgram_indo.bin"
49
+ )
50
+ sr = StyleRewriting(
51
+ clf_pipeline=clf_pipeline,
52
+ fasttext_model_path="./fasttext_skipgram_indo.bin",
53
+ position_aware_concatenation=False,
54
+ reset_sense_strategy=False,
55
+ max_score_strategy=True,
56
+ maximize_style_words_expansion=False
57
+ )
58
+
59
  def run(text, source_style):
60
  yield (
61
  "(Memproses ...)",
 
68
  start_time = time.time()
69
 
70
  # source_amr, *_ = t2a([text])
71
+ source_amr = penman.decode("""
72
+ (z0 / dan
73
+ :op1 (z1 / bagus-01
74
+ :ARG1 (z2 / tempat)
75
+ :degree (z3 / sangat))
76
+ :op2 (z4 / bersih-01
77
+ :ARG1 z2))
78
+ """)
79
+ source_amr.metadata = {}
80
+ source_amr_display = penman.encode(source_amr)
81
  source_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
82
  yield (
83
  source_amr_display,
 
114
  "(Menunggu ...)",
115
  )
116
 
117
+ target_amr = sr(text, source_amr, source_style, style_words)
118
  target_amr_display = penman.encode(target_amr)
119
  target_amr_display += f"\n\n({time.time() - start_time:.2f} s)"
120
  yield (