eubinecto commited on
Commit
c1728bd
·
1 Parent(s): 59df933

[#9] idiomifier:m-1-3 is ready. main_deploy.py is updated accordingly

Browse files
explore/explore_bart_tokenizer_decode_idiom_special_tokens.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from idiomify.fetchers import fetch_tokenizer
2
+
3
+
4
+ def main():
5
+ tokenizer = fetch_tokenizer("t-1-1")
6
+ sent = "There will always be a <idiom> silver lining </idiom> even when things look pitch black"
7
+ ids = tokenizer(sent)['input_ids']
8
+ print(ids)
9
+ decoded = tokenizer.decode(ids)
10
+ print(decoded)
11
+
12
+
13
+ if __name__ == '__main__':
14
+ main()
idiomify/pipeline.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import List
2
  from transformers import BartTokenizer
3
  from idiomify.builders import SourcesBuilder
@@ -18,5 +19,9 @@ class Pipeline:
18
  decoder_start_token_id=self.model.hparams['bos_token_id'],
19
  max_length=max_length,
20
  ) # -> (N, L_t)
21
- tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
 
 
 
 
22
  return tgts
 
1
+ import re
2
  from typing import List
3
  from transformers import BartTokenizer
4
  from idiomify.builders import SourcesBuilder
 
19
  decoder_start_token_id=self.model.hparams['bos_token_id'],
20
  max_length=max_length,
21
  ) # -> (N, L_t)
22
+ tgts = self.builder.tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
23
+ tgts = [
24
+ re.sub(r"<s>|</s>", "", tgt)
25
+ for tgt in tgts
26
+ ]
27
  return tgts
main_deploy.py CHANGED
@@ -1,9 +1,9 @@
1
  """
2
  we deploy the pipeline via streamlit.
3
  """
 
4
  import streamlit as st
5
- from transformers import BartTokenizer
6
- from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms
7
  from idiomify.pipeline import Pipeline
8
 
9
 
@@ -11,7 +11,7 @@ from idiomify.pipeline import Pipeline
11
  def fetch_resources() -> tuple:
12
  config = fetch_config()['idiomifier']
13
  model = fetch_idiomifier(config['ver'])
14
- tokenizer = BartTokenizer.from_pretrained(config['bart'])
15
  idioms = fetch_idioms(config['idioms_ver'])
16
  return config, model, tokenizer, idioms
17
 
@@ -23,17 +23,20 @@ def main():
23
  pipeline = Pipeline(model, tokenizer)
24
  st.title("Idiomify Demo")
25
  text = st.text_area("Type sentences here",
26
- value="Just remember there will always be a hope even when things look black")
27
  with st.sidebar:
28
  st.subheader("Supported idioms")
 
29
  st.write(" / ".join(idioms))
30
 
31
  if st.button(label="Idiomify"):
32
  with st.spinner("Please wait..."):
33
  sents = [sent for sent in text.split(".") if sent]
34
- sents = pipeline(sents, max_length=200)
35
  # highlight the rule & honorifics that were applied
36
- st.write(". ".join(sents))
 
 
37
 
38
 
39
  if __name__ == '__main__':
 
1
  """
2
  we deploy the pipeline via streamlit.
3
  """
4
+ import re
5
  import streamlit as st
6
+ from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_idioms, fetch_tokenizer
 
7
  from idiomify.pipeline import Pipeline
8
 
9
 
 
11
  def fetch_resources() -> tuple:
12
  config = fetch_config()['idiomifier']
13
  model = fetch_idiomifier(config['ver'])
14
+ tokenizer = fetch_tokenizer(config['tokenizer_ver'])
15
  idioms = fetch_idioms(config['idioms_ver'])
16
  return config, model, tokenizer, idioms
17
 
 
23
  pipeline = Pipeline(model, tokenizer)
24
  st.title("Idiomify Demo")
25
  text = st.text_area("Type sentences here",
26
+ value="Just remember that there will always be a hope even when things look hopeless")
27
  with st.sidebar:
28
  st.subheader("Supported idioms")
29
+ idioms = [row["Idiom"] for _, row in idioms.iterrows()]
30
  st.write(" / ".join(idioms))
31
 
32
  if st.button(label="Idiomify"):
33
  with st.spinner("Please wait..."):
34
  sents = [sent for sent in text.split(".") if sent]
35
+ preds = pipeline(sents, max_length=200)
36
  # highlight the rule & honorifics that were applied
37
+ preds = [re.sub(r"<idiom>|</idiom>", "`", pred)
38
+ for pred in preds]
39
+ st.markdown(". ".join(preds))
40
 
41
 
42
  if __name__ == '__main__':
main_eval.py CHANGED
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
6
  from pytorch_lightning.loggers import WandbLogger
7
  from transformers import BartTokenizer
8
  from idiomify.datamodules import IdiomifyDataModule
9
- from idiomify.fetchers import fetch_config, fetch_idiomifier
10
  from idiomify.paths import ROOT_DIR
11
 
12
 
@@ -17,10 +17,10 @@ def main():
17
  args = parser.parse_args()
18
  config = fetch_config()['idiomifier']
19
  config.update(vars(args))
20
- tokenizer = BartTokenizer.from_pretrained(config['bart'])
21
  # prepare the datamodule
22
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
23
  model = fetch_idiomifier(config['ver'], run) # fetch a pre-trained model
 
24
  datamodule = IdiomifyDataModule(config, tokenizer, run)
25
  logger = WandbLogger(log_model=False)
26
  trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
 
6
  from pytorch_lightning.loggers import WandbLogger
7
  from transformers import BartTokenizer
8
  from idiomify.datamodules import IdiomifyDataModule
9
+ from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
10
  from idiomify.paths import ROOT_DIR
11
 
12
 
 
17
  args = parser.parse_args()
18
  config = fetch_config()['idiomifier']
19
  config.update(vars(args))
 
20
  # prepare the datamodule
21
  with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
22
  model = fetch_idiomifier(config['ver'], run) # fetch a pre-trained model
23
+ tokenizer = fetch_tokenizer(config['tokenizer_ver'], run)
24
  datamodule = IdiomifyDataModule(config, tokenizer, run)
25
  logger = WandbLogger(log_model=False)
26
  trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
main_infer.py CHANGED
@@ -3,25 +3,24 @@ This is for just a simple sanity check on the inference.
3
  """
4
  import argparse
5
  from idiomify.pipeline import Pipeline
6
- from idiomify.fetchers import fetch_config, fetch_idiomifier
7
  from transformers import BartTokenizer
8
 
9
 
10
  def main():
11
  parser = argparse.ArgumentParser()
12
  parser.add_argument("--sent", type=str,
13
- default="If there's any good to loosing my job,"
14
- " it's that I'll now be able to go to school full-time and finish my degree earlier.")
15
  args = parser.parse_args()
16
  config = fetch_config()['idiomifier']
17
  config.update(vars(args))
18
  model = fetch_idiomifier(config['ver'])
 
19
  model.eval() # this is crucial
20
- tokenizer = BartTokenizer.from_pretrained(config['bart'])
21
  pipeline = Pipeline(model, tokenizer)
22
  src = config['sent']
23
- tgt = pipeline(sents=[config['sent']])
24
- print(src, "\n->", tgt)
25
 
26
 
27
  if __name__ == '__main__':
 
3
  """
4
  import argparse
5
  from idiomify.pipeline import Pipeline
6
+ from idiomify.fetchers import fetch_config, fetch_idiomifier, fetch_tokenizer
7
  from transformers import BartTokenizer
8
 
9
 
10
  def main():
11
  parser = argparse.ArgumentParser()
12
  parser.add_argument("--sent", type=str,
13
+ default="Just remember that there will always be a hope even when things look hopeless")
 
14
  args = parser.parse_args()
15
  config = fetch_config()['idiomifier']
16
  config.update(vars(args))
17
  model = fetch_idiomifier(config['ver'])
18
+ tokenizer = fetch_tokenizer(config['tokenizer_ver'])
19
  model.eval() # this is crucial
 
20
  pipeline = Pipeline(model, tokenizer)
21
  src = config['sent']
22
+ tgts = pipeline(sents=[src])
23
+ print(src, "\n->", tgts[0])
24
 
25
 
26
  if __name__ == '__main__':