[#1] checkpoint before amending builders.py
- explore/explore_bart.py +16 -0
- main_upload_idiom2context.py → explore/explore_bart_for_conditional_generation.py +3 -4
- explore/explore_fetch_epie.py +1 -1
- explore/explore_fetch_epie_counts.py +0 -1
- explore/explore_fetch_idiom2def.py +0 -15
- explore/explore_fetch_idioms.py +1 -1
- explore/explore_fetch_literal2idiom.py +10 -0
- explore/explore_fetch_pie.py +14 -0
- idiomify/builders.py +26 -26
- idiomify/fetchers.py +66 -47
- idiomify/models.py +13 -110
- idiomify/paths.py +4 -4
- idiomify/urls.py +5 -0
- main_infer.py +37 -36
- main_upload_idioms.py +32 -4
- main_upload_literal2idiom.py +46 -0
- main_upload_tokenizer.py +0 -13
explore/explore_bart.py
ADDED
@@ -0,0 +1,16 @@
+from transformers import BartTokenizer, BartModel
+
+
+def main():
+
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    model = BartModel.from_pretrained('facebook/bart-large')
+
+    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+    outputs = model(**inputs)
+    H_all = outputs.last_hidden_state  # noqa
+    print(H_all.shape)  # (1, 8, 1024)
+
+
+if __name__ == '__main__':
+    main()
main_upload_idiom2context.py → explore/explore_bart_for_conditional_generation.py
RENAMED
@@ -1,6 +1,5 @@
-
-
-"""
+
+from transformers import BartTokenizer, BartForConditionalGeneration
 
 
 def main():
@@ -8,4 +7,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()
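Note: the body of the renamed exploration script is barely touched by this hunk; only the BartForConditionalGeneration import is new. For orientation, a minimal sketch of what such an exploration typically looks like (the prompt, max_length and num_beams below are illustrative and not taken from the repository; without fine-tuning, facebook/bart-large will mostly reconstruct its input):

from transformers import BartTokenizer, BartForConditionalGeneration


def main():
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    # encode a literal sentence and let BART generate a sequence for it
    inputs = tokenizer("He decided to not say it directly", return_tensors="pt")
    out_ids = model.generate(**inputs, max_length=32, num_beams=4)
    print(tokenizer.decode(out_ids[0], skip_special_tokens=True))


if __name__ == '__main__':
    main()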
explore/explore_fetch_epie.py
CHANGED
@@ -11,7 +11,7 @@ def main():
 
     # so, what do you want? you want to build an idiom-masked language modeling?
     for idiom, context, tag in epie:
-        print(context)
+        print(idiom, context)
 
     for idx, idiom in enumerate(idioms):
         print(idx, idiom)
explore/explore_fetch_epie_counts.py
CHANGED
@@ -1,4 +1,3 @@
-
 from idiomify.fetchers import fetch_epie
 
 
explore/explore_fetch_idiom2def.py
DELETED
@@ -1,15 +0,0 @@
-from idiomify.fetchers import fetch_idiom2def
-
-
-def main():
-    idiom2def = fetch_idiom2def("c")
-    for idiom, definition in idiom2def:
-        print(idiom, definition)
-
-    df = fetch_idiom2def("d")
-    for idiom, definition in idiom2def:
-        print(idiom, definition)
-
-
-if __name__ == '__main__':
-    main()
explore/explore_fetch_idioms.py
CHANGED
@@ -2,7 +2,7 @@ from idiomify.fetchers import fetch_idioms
 
 
 def main():
-    print(fetch_idioms("
+    print(fetch_idioms("pie_v0"))
 
 
 if __name__ == '__main__':
explore/explore_fetch_literal2idiom.py
ADDED
@@ -0,0 +1,10 @@
+from idiomify.fetchers import fetch_literal2idiom
+
+
+def main():
+    for src, tgt in fetch_literal2idiom("pie_v0"):
+        print(src, "->", tgt)
+
+
+if __name__ == '__main__':
+    main()
explore/explore_fetch_pie.py
ADDED
@@ -0,0 +1,14 @@
+
+from idiomify.fetchers import fetch_pie
+
+
+def main():
+    for idx, row in enumerate(fetch_pie()):
+        print(idx, row)
+        # the first 105 = V0.
+        if idx == 105:
+            break
+
+
+if __name__ == '__main__':
+    main()
idiomify/builders.py
CHANGED
@@ -19,6 +19,16 @@ class TensorBuilder:
 class Idiom2SubwordsBuilder(TensorBuilder):
 
     def __call__(self, idioms: List[str], k: int) -> torch.Tensor:
+        """
+        1. The function takes in a list of idioms, and a maximum length of the input sequence.
+        2. It then splits the idioms into subwords, and pads the sequence to the maximum length.
+        3. It masks the padding tokens, and returns the input ids.
+        :param idioms: a list of idioms
+        :type idioms: List[str]
+        :param k: the maximum length of the idioms
+        :type k: int
+        :return: the input_ids of the idioms, with the pad tokens replaced by the mask token
+        """
         mask_id = self.tokenizer.mask_token_id
         pad_id = self.tokenizer.pad_token_id
         # temporarily disable single-token status of the idioms
@@ -31,38 +41,20 @@ class Idiom2SubwordsBuilder(TensorBuilder):
                                    max_length=k,  # set to k
                                    return_tensors="pt")
         input_ids = encodings['input_ids']
-        input_ids[input_ids == pad_id] = mask_id
+        input_ids[input_ids == pad_id] = mask_id
         return input_ids
 
 
-class Idiom2DefBuilder(TensorBuilder):
-
-    def __call__(self, idiom2def: List[Tuple[str, str]], k: int) -> torch.Tensor:
-        defs = [definition for _, definition in idiom2def]
-        lefts = [" ".join(["[MASK]"] * k)] * len(defs)
-        encodings = self.tokenizer(text=lefts,
-                                   text_pair=defs,
-                                   return_tensors="pt",
-                                   add_special_tokens=True,
-                                   truncation=True,
-                                   padding=True,
-                                   verbose=True)
-        input_ids: torch.Tensor = encodings['input_ids']
-        cls_id: int = self.tokenizer.cls_token_id
-        sep_id: int = self.tokenizer.sep_token_id
-        mask_id: int = self.tokenizer.mask_token_id
-        wisdom_mask = torch.where(input_ids == mask_id, 1, 0)
-        desc_mask = torch.where(((input_ids != cls_id) & (input_ids != sep_id) & (input_ids != mask_id)), 1, 0)
-        return torch.stack([input_ids,
-                            encodings['token_type_ids'],
-                            encodings['attention_mask'],
-                            wisdom_mask,
-                            desc_mask], dim=1)
-
-
 class Idiom2ContextBuilder(TensorBuilder):
 
     def __call__(self, idiom2context: List[Tuple[str, str]]):
+        """
+        Given a list of tuples of idiom and context,
+        it returns a tensor of shape (batch_size, 3, max_seq_len).
+        :param idiom2context: a list of tuples of idiom and context
+        :type idiom2context: List[Tuple[str, str]]
+        :return: the input_ids, token_type_ids, and attention_mask for each context
+        """
         contexts = [context for _, context in idiom2context]
         encodings = self.tokenizer(text=contexts,
                                    return_tensors="pt",
@@ -78,6 +70,14 @@ class Idiom2ContextBuilder(TensorBuilder):
 class TargetsBuilder(TensorBuilder):
 
     def __call__(self, idiom2sent: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
+        """
+        Given a list of (idiom, sentence) pairs and the list of all idioms, return the index of each pair's idiom in that list.
+        :param idiom2sent: a list of tuples, where each tuple is an idiom and its corresponding sentence
+        :type idiom2sent: List[Tuple[str, str]]
+        :param idioms: a list of idioms
+        :type idioms: List[str]
+        :return: a tensor of indices of the idioms in the list of idioms
+        """
         return torch.LongTensor([
             idioms.index(idiom)
             for idiom, _ in idiom2sent
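Note: as a sanity check of what the new Idiom2SubwordsBuilder docstring describes (idioms tokenized into at most k subword slots, with pad tokens swapped for mask tokens), here is a minimal standalone sketch using a plain BERT tokenizer; the idioms and k are made up, and the builder's step of temporarily disabling the idioms' single-token status is skipped:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
idioms = ["beat around the bush", "call it a day"]  # illustrative only
k = 8
encodings = tokenizer(idioms,
                      add_special_tokens=False,
                      padding='max_length',
                      truncation=True,
                      max_length=k,  # set to k
                      return_tensors="pt")
input_ids = encodings['input_ids']  # (|W|, K)
# shorter idioms are padded; replace those pad positions with [MASK] ids
input_ids[input_ids == tokenizer.pad_token_id] = tokenizer.mask_token_id
print(input_ids.shape)  # torch.Size([2, 8])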
idiomify/fetchers.py
CHANGED
@@ -1,73 +1,91 @@
 import csv
+from os import path
 import yaml
 import wandb
 import requests
 from typing import Tuple, List
-
 from wandb.sdk.wandb_run import Run
-
+from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
+from idiomify.builders import Idiom2SubwordsBuilder
 from idiomify.models import Alpha, RD
-from idiomify.paths import
+from idiomify.paths import CONFIG_YAML, idioms_dir, alpha_dir, literal2idiom
 from idiomify.urls import (
     EPIE_IMMUTABLE_IDIOMS_URL,
     EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL,
     EPIE_IMMUTABLE_IDIOMS_TAGS_URL,
     EPIE_MUTABLE_IDIOMS_URL,
     EPIE_MUTABLE_IDIOMS_CONTEXTS_URL,
-    EPIE_MUTABLE_IDIOMS_TAGS_URL
+    EPIE_MUTABLE_IDIOMS_TAGS_URL,
+    PIE_URL
 )
-from idiomify.builders import Idiom2SubwordsBuilder
-from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
 
 
 # sources for dataset
-def fetch_epie() -> List[Tuple[str, str, str]]:
-
-
-
-
-
-
+def fetch_epie(ver: str) -> List[Tuple[str, str, str]]:
+    """
+    It fetches the EPIE idioms, contexts, and tags from the web
+    :param ver: str
+    :type ver: str
+    :return: A list of tuples. Each tuple contains three strings: an idiom, a context, and a tag.
+    """
+    if ver == "immutable":
+        idioms_url = EPIE_IMMUTABLE_IDIOMS_URL
+        contexts_url = EPIE_IMMUTABLE_IDIOMS_CONTEXTS_URL
+        tags_url = EPIE_IMMUTABLE_IDIOMS_TAGS_URL
+    elif ver == "mutable":
+        idioms_url = EPIE_MUTABLE_IDIOMS_URL
+        contexts_url = EPIE_MUTABLE_IDIOMS_CONTEXTS_URL
+        tags_url = EPIE_MUTABLE_IDIOMS_TAGS_URL
+    else:
+        raise ValueError
+    idioms = requests.get(idioms_url).text
+    contexts = requests.get(contexts_url).text
+    tags = requests.get(tags_url).text
     return list(zip(idioms.strip().split("\n"),
                     contexts.strip().split("\n"),
                     tags.strip().split("\n")))
 
 
-
-
+def fetch_pie() -> list:
+    text = requests.get(PIE_URL).text
+    lines = (line for line in text.split("\n") if line)
+    reader = csv.reader(lines)
+    next(reader)  # skip the header
+    return [
+        row
+        for row in reader
+    ]
+
+
+# --- from wandb --- #
+def fetch_idioms(ver: str, run: Run = None) -> List[str]:
     """
-
+    why do you need this? -> you need this to have access to the idiom embeddings.
     """
+    # if run object is given, we track the lineage of the data.
+    # if not, we get the dataset via wandb Api.
     if run:
-
-
-
-
-
-
-
-        artifact.download(root=str(artifact_path))
-        tsv_path = artifact_path / "all.tsv"
-        with open(tsv_path, 'r') as fh:
-            reader = csv.reader(fh, delimiter="\t")
-            return [
-                (row[0], row[1])
-                for row in reader
-            ]
+        artifact = run.use_artifact("idioms", type="dataset", aliases=ver)
+    else:
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/idioms:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=idioms_dir(ver))
+    txt_path = path.join(artifact_dir, "all.txt")
+    with open(txt_path, 'r') as fh:
+        return [line.strip() for line in fh]
 
 
-def
-
-
-
-
+def fetch_literal2idiom(ver: str, run: Run = None) -> List[Tuple[str, str]]:
+    # if run object is given, we track the lineage of the data.
+    # if not, we get the dataset via wandb Api.
+    if run:
+        artifact = run.use_artifact("literal2idiom", type="dataset", aliases=ver)
+    else:
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/literal2idiom:{ver}", type="dataset")
+    artifact_dir = artifact.download(root=literal2idiom(ver))
+    tsv_path = path.join(artifact_dir, "all.tsv")
     with open(tsv_path, 'r') as fh:
         reader = csv.reader(fh, delimiter="\t")
-
-        return [
-            row[0]
-            for row in reader
-        ]
+        return [(row[0], row[1]) for row in reader]
 
 
 def fetch_rd(model: str, ver: str) -> RD:
@@ -80,12 +98,13 @@ def fetch_rd(model: str, ver: str) -> RD:
     idioms = fetch_idioms(config['idioms_ver'])
     tokenizer = BertTokenizer.from_pretrained(config['bert'])
    idiom2subwords = Idiom2SubwordsBuilder(tokenizer)(idioms, config['k'])
-    if model == Alpha.name():
-
-    elif model == Gamma.name():
-
-    else:
-
+    # if model == Alpha.name():
+    #     rd = Alpha.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    # elif model == Gamma.name():
+    #     rd = Gamma.load_from_checkpoint(str(ckpt_path), mlm=mlm, idiom2subwords=idiom2subwords)
+    # else:
+    #     raise ValueError
+    rd = ...
     return rd
 
 
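Note: the column meanings of the PIE csv are not documented in this commit; the mapping below is inferred from how the two new upload scripts index each row (row[0] as the idiom, row[2] as the idiomatic sentence, row[3] as its literal paraphrase). A small sketch of that relationship:

from idiomify.fetchers import fetch_pie

rows = fetch_pie()
# what main_upload_idioms.py uploads for pie_v0
idioms = set(row[0] for row in rows[:106])
# what main_upload_literal2idiom.py uploads for pie_v0: (src, tgt) = (literal, idiomatic)
literal2idiom = [(row[3], row[2]) for row in rows[:106]]
print(len(idioms), len(literal2idiom))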
idiomify/models.py
CHANGED
@@ -8,14 +8,12 @@ import pytorch_lightning as pl
 from transformers import BertForMaskedLM
 
 
-class
+class Idiomifier(pl.LightningModule):
     """
     @eubinecto
     The superclass of all the reverse-dictionaries. This class houses any methods that are required by
     whatever reverse-dictionaries we define.
     """
-
-    # --- boilerplate; the loaders are defined in datamodules, so we don't define them here
     # passing them to avoid warnings --- #
     def train_dataloader(self):
         pass
@@ -35,119 +33,24 @@ class RD(pl.LightningModule):
         :param idiom2subwords: (|W|, K)
         :return: (N, K, |V|); (num samples, k, the size of the vocabulary of subwords)
         """
-
-        # -- hyper params --- #
-        # should be saved to self.hparams
-        # https://github.com/PyTorchLightning/pytorch-lightning/issues/4390#issue-730493746
-        self.save_hyperparameters(ignore=["mlm", "idiom2subwords"])
-        # -- the only neural network we need -- #
-        self.mlm = mlm
-        # --- to be used for getting H_k --- #
-        self.wisdom_mask: Optional[torch.Tensor] = None  # (N, L)
-        # --- to be used for getting H_desc --- #
-        self.desc_mask: Optional[torch.Tensor] = None  # (N, L)
-        # -- constant tensors -- #
-        self.register_buffer("idiom2subwords", idiom2subwords)  # (|W|, K)
+        pass
 
     def forward(self, X: torch.Tensor) -> torch.Tensor:
         """
-
-
-        :return: (N, L, H)
-        """
-        input_ids = X[:, 0]  # (N, 4, L) -> (N, L)
-        token_type_ids = X[:, 1]  # (N, 4, L) -> (N, L)
-        attention_mask = X[:, 2]  # (N, 4, L) -> (N, L)
-        self.wisdom_mask = X[:, 3]  # (N, 4, L) -> (N, L)
-        self.desc_mask = X[:, 4]  # (N, 4, L) -> (N, L)
-        H_all = self.mlm.bert.forward(input_ids, attention_mask, token_type_ids)[0]  # (N, 3, L) -> (N, L, H)
-        return H_all
-
-    def H_k(self, H_all: torch.Tensor) -> torch.Tensor:
-        """
-        You may want to override this. (e.g. RDGamma - the k's could be anywhere)
-        :param H_all (N, L, H)
-        :return H_k (N, K, H)
-        """
-        N, _, H = H_all.size()
-        # refer to: wisdomify/examples/explore_masked_select.py
-        wisdom_mask = self.wisdom_mask.unsqueeze(2).expand(H_all.shape)  # (N, L) -> (N, L, 1) -> (N, L, H)
-        H_k = torch.masked_select(H_all, wisdom_mask.bool())  # (N, L, H), (N, L, H) -> (N * K * H)
-        H_k = H_k.reshape(N, self.hparams['k'], H)  # (N * K * H) -> (N, K, H)
-        return H_k
-
-    def H_desc(self, H_all: torch.Tensor) -> torch.Tensor:
-        """
-
-        :return H_desc (N, L - (K + 3), H)
-        """
-        N, L, H = H_all.size()
-        desc_mask = self.desc_mask.unsqueeze(2).expand(H_all.shape)
-        H_desc = torch.masked_select(H_all, desc_mask.bool())  # (N, L, H), (N, L, H) -> (N * (L - (K + 3)) * H)
-        H_desc = H_desc.reshape(N, L - (self.hparams['k'] + 3), H)  # (N * (L - (K + 3)) * H) -> (N, L - (K + 3), H)
-        return H_desc
-
-    def S_wisdom_literal(self, H_k: torch.Tensor) -> torch.Tensor:
-        """
-        To be used for both RDAlpha & RDBeta
-        :param H_k: (N, K, H)
-        :return: S_wisdom_literal (N, |W|)
-        """
-        S_vocab = self.mlm.cls(H_k)  # bmm; (N, K, H) * (H, |V|) -> (N, K, |V|)
-        indices = self.idiom2subwords.T.repeat(S_vocab.shape[0], 1, 1)  # (|W|, K) -> (N, K, |W|)
-        S_wisdom_literal = S_vocab.gather(dim=-1, index=indices)  # (N, K, |V|) -> (N, K, |W|)
-        S_wisdom_literal = S_wisdom_literal.sum(dim=1)  # (N, K, |W|) -> (N, |W|)
-        return S_wisdom_literal
-
-    def S_wisdom(self, H_all: torch.Tensor) -> torch.Tensor:
-        """
-        :param H_all: (N, L, H)
-        :return S_wisdom: (N, |W|)
-        """
-        raise NotImplementedError("An RD class must implement S_wisdom")
-
-    def P_wisdom(self, X: torch.Tensor) -> torch.Tensor:
-        """
-        :param X: (N, 3, L)
-        :return P_wisdom: (N, |W|), normalized over dim 1.
-        """
-        H_all = self.forward(X)  # (N, 3, L) -> (N, L, H)
-        S_wisdom = self.S_wisdom(H_all)  # (N, L, H) -> (N, W)
-        P_wisdom = F.softmax(S_wisdom, dim=1)  # (N, W) -> (N, W)
-        return P_wisdom
-
-    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
-        X, y = batch
-        H_all = self.forward(X)  # (N, 3, L) -> (N, L, H)
-        S_wisdom = self.S_wisdom(H_all)  # (N, L, H) -> (N, |W|)
-        loss = F.cross_entropy(S_wisdom, y)  # (N, |W|), (N,) -> (N,)
-        loss = loss.sum()  # (N,) -> (1,)
-        # so that the metrics accumulate over the course of this epoch
-        # why dict? - just a boilerplate
-        return {
-            # you cannot change the keyword for the loss
-            "loss": loss,
-        }
-
-    def on_train_batch_end(self, outputs: dict, *args, **kwargs) -> None:
-        # watch the loss for this batch
-        self.log("Train/Loss", outputs['loss'])
-
-    def training_epoch_end(self, outputs: List[dict]) -> None:
-        # to see an average performance over the batches in this specific epoch
-        avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
-        self.log("Train/Average Loss", avg_loss)
-
-    def
-
-    def
-
-    def
-
-        avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
-        self.log("Validation/Average Loss", avg_loss)
+        given a batch, forward returns a batch of hidden vectors
+        :param X: (N, 3, L). input_ids, token_type_ids, and what was the last one...?
+        :return: (N, L, H)
+        """
+        pass
+
+    def step(self):
+        pass
+
+    def predict(self):
+        pass
+
+    def training_step(self):
+        pass
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
         """
@@ -162,7 +65,7 @@ class RD(pl.LightningModule):
         return cls.__name__.lower()
 
 
-class Alpha(
+class Alpha(Idiomifier):
     """
     @eubinecto
     The first prototype.
idiomify/paths.py
CHANGED
@@ -5,14 +5,14 @@ ARTIFACTS_DIR = ROOT_DIR / "artifacts"
 CONFIG_YAML = ROOT_DIR / "config.yaml"
 
 
-def idiom2def_dir(ver: str) -> Path:
-    return ARTIFACTS_DIR / f"idiom2def_{ver}"
-
-
 def idioms_dir(ver: str) -> Path:
     return ARTIFACTS_DIR / f"idioms_{ver}"
 
 
+def literal2idiom(ver: str) -> Path:
+    return ARTIFACTS_DIR / f"literal2idiom_{ver}"
+
+
 def alpha_dir(ver: str) -> Path:
     return ARTIFACTS_DIR / f"alpha_{ver}"
 
idiomify/urls.py
CHANGED
@@ -7,5 +7,10 @@ EPIE_MUTABLE_IDIOMS_TAGS_URL = "https://raw.githubusercontent.com/prateeksaxena2
 EPIE_MUTABLE_IDIOMS_URL = "https://raw.githubusercontent.com/prateeksaxena2809/EPIE_Corpus/master/Formal_Idioms_Corpus/Formal_Idioms_Candidates.txt"  # noqa
 EPIE_MUTABLE_IDIOMS_CONTEXTS_URL = "https://github.com/prateeksaxena2809/EPIE_Corpus/blob/master/Formal_Idioms_Corpus/Formal_Idioms_Words.txt"  # noqa
 
+# PIE dataset (Zhou, 2021)
+# https://aclanthology.org/2021.mwe-1.5/
+# right, let's just work on it.
+PIE_URL = "https://raw.githubusercontent.com/zhjjn/MWE_PIE/main/data_cleaned.csv"
+
 
 
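Note: a quick way to peek at the PIE csv before relying on the row indices used by the upload scripts; the header columns are not spelled out anywhere in this commit, so printing them is the easiest check:

import csv
import requests
from idiomify.urls import PIE_URL

text = requests.get(PIE_URL).text
reader = csv.reader(line for line in text.split("\n") if line)
print(next(reader))  # the header row, i.e. the column names of data_cleaned.csv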
main_infer.py
CHANGED
@@ -1,36 +1,37 @@
-
-
-from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
-from transformers import BertTokenizer
-from termcolor import colored
+# we disable them for now.
+# import argparse
+# from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
+# from transformers import BertTokenizer
+# from termcolor import colored
+#
+#
+# def main():
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument("--model", type=str,
+#                         default="alpha")
+#     parser.add_argument("--ver", type=str,
+#                         default="eng2eng")
+#     parser.add_argument("--sent", type=str,
+#                         default="to avoid getting to the point")
+#     args = parser.parse_args()
+#     config = fetch_config()[args.model][args.ver]
+#     config.update(vars(args))
+#     idioms = fetch_idioms(config['idioms_ver'])
+#     rd = fetch_rd(config['model'], config['ver'])
+#     rd.eval()
+#     tokenizer = BertTokenizer.from_pretrained(config['bert'])
+#     X = T.inputs([("", config['sent'])], tokenizer, config['k'])
+#     probs = rd.P_wisdom(X).squeeze().tolist()
+#     wisdom2prob = [
+#         (wisdom, prob)
+#         for wisdom, prob in zip(idioms, probs)
+#     ]
+#     # sort and append
+#     res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
+#     print(f"query: {colored(text=config['sent'], color='blue')}")
+#     for idx, (idiom, prob) in enumerate(res):
+#         print(idx, idiom, prob)
+#
+#
+# if __name__ == '__main__':
+#     main()
main_upload_idioms.py
CHANGED
@@ -1,12 +1,40 @@
 """
-Here,
-
-ver b:
+Here, what should you do here?
+just upload all idioms here - name it as epie.
 """
+import os
+from idiomify.paths import ROOT_DIR
+from idiomify.fetchers import fetch_pie
+import argparse
+import wandb
 
 
 def main():
-
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="pie_v0",
+                        choices=["pie_v0", "pie_v1"])
+    config = vars(parser.parse_args())
+
+    # get the idioms here
+    if config['ver'] == "pie_v0":
+        # only the first 106, and this is for piloting
+        idioms = set([row[0] for row in fetch_pie()[:106]])
+    elif config['ver'] == "pie_v1":
+        # just include all
+        idioms = set([row[0] for row in fetch_pie()])
+    else:
+        raise NotImplementedError
+    idioms = list(idioms)
+
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        artifact = wandb.Artifact(name="idioms", type="dataset")
+        txt_path = ROOT_DIR / "all.txt"
+        with open(txt_path, 'w') as fh:
+            for idiom in idioms:
+                fh.write(idiom + "\n")
+        artifact.add_file(txt_path)
+        run.log_artifact(artifact, aliases=["latest", config['ver']])
+        os.remove(txt_path)
 
 
 if __name__ == '__main__':
main_upload_literal2idiom.py
ADDED
@@ -0,0 +1,46 @@
+"""
+Here, what should you do here?
+just upload all idioms here - name it as epie.
+"""
+import csv
+import os
+from idiomify.paths import ROOT_DIR
+from idiomify.fetchers import fetch_pie
+import argparse
+import wandb
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ver", type=str, default="pie_v0",
+                        choices=["pie_v0", "pie_v1"])
+    config = vars(parser.parse_args())
+
+    # get the idioms here
+    if config['ver'] == "pie_v0":
+        # only the first 106, and we use this just for piloting
+        literal2idiom = [
+            (row[3], row[2]) for row in fetch_pie()[:106]
+        ]
+    elif config['ver'] == "pie_v1":
+        # just include all
+        literal2idiom = [
+            (row[3], row[2]) for row in fetch_pie()
+        ]
+    else:
+        raise NotImplementedError
+
+    with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
+        artifact = wandb.Artifact(name="literal2idiom", type="dataset")
+        tsv_path = ROOT_DIR / "all.tsv"
+        with open(tsv_path, 'w') as fh:
+            writer = csv.writer(fh, delimiter="\t")
+            for row in literal2idiom:
+                writer.writerow(row)
+        artifact.add_file(tsv_path)
+        run.log_artifact(artifact, aliases=["latest", config['ver']])
+        os.remove(tsv_path)
+
+
+if __name__ == '__main__':
+    main()
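Note: a minimal round-trip check for the artifact this script uploads, assuming the pie_v0 upload has already been run; it relies only on fetch_literal2idiom as defined in idiomify/fetchers.py above:

from idiomify.fetchers import fetch_literal2idiom

pairs = fetch_literal2idiom("pie_v0")  # downloads eubinecto/idiomify/literal2idiom:pie_v0
print(len(pairs))   # 106 for the pilot version
print(pairs[0])     # (literal sentence, idiomatic sentence)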
main_upload_tokenizer.py
DELETED
@@ -1,13 +0,0 @@
-"""
-Build & upload a tokenizer to wandb.
-You need this if you were to add more tokens there.
-"""
-
-
-def main():
-    pass
-    # TODO: fetch the dataset from wandb first!
-
-
-if __name__ == '__main__':
-    main()