eubinecto committed
Commit ec156ad · Parent(s): 539e83f
demo completed, disabled check pointing
Browse files
- idiomify/datamodules.py +3 -3
- idiomify/fetchers.py +2 -3
- idiomify/models.py +4 -4
- idiomify/tensors.py +11 -10
- main_infer.py +6 -4
- main_train.py +1 -0
idiomify/datamodules.py
CHANGED
@@ -2,9 +2,9 @@ import torch
 from typing import Tuple, Optional, List
 from torch.utils.data import Dataset, DataLoader
 from pytorch_lightning import LightningDataModule
-from transformers import BertTokenizer
 from idiomify.fetchers import fetch_idiom2def
 from idiomify import tensors as T
+from transformers import BertTokenizer


 class IdiomifyDataset(Dataset):
@@ -66,8 +66,8 @@ class IdiomifyDataModule(LightningDataModule):
         """
         # --- set up the builders --- #
         # build the datasets
-        X = T.inputs(
-        y = T.targets(self.idioms)
+        X = T.inputs(self.idiom2def, self.tokenizer, self.config['k'])
+        y = T.targets(self.idiom2def, self.idioms)
         self.dataset = IdiomifyDataset(X, y)

     def train_dataloader(self) -> DataLoader:
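
The setup step now hands the raw (idiom, definition) pairs, the tokenizer, and k straight to the tensor builders. A minimal usage sketch of driving the datamodule end to end; the constructor arguments and the "alpha"/"eng2eng" config keys are assumptions, only the hook and loader names come from LightningDataModule and the diff:

# hypothetical driver: IdiomifyDataModule's constructor signature is assumed here.
from transformers import BertTokenizer
from idiomify.fetchers import fetch_config
from idiomify.datamodules import IdiomifyDataModule

config = fetch_config()["alpha"]["eng2eng"]          # assumed keys, as in main_infer.py
tokenizer = BertTokenizer.from_pretrained(config['bert'])
datamodule = IdiomifyDataModule(config, tokenizer)   # assumed signature
datamodule.prepare_data()
datamodule.setup()                                   # builds X = T.inputs(...), y = T.targets(...)
for X, y in datamodule.train_dataloader():
    print(X.shape, y.shape)                          # one batch of encodings and idiom indices
    break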
idiomify/fetchers.py
CHANGED
@@ -2,10 +2,10 @@ import csv
 import yaml
 import wandb
 from typing import Tuple, List
-from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer
 from idiomify.models import Alpha, Gamma, RD
 from idiomify.paths import idiom2def_dir, CONFIG_YAML, idioms_dir, alpha_dir
 from idiomify import tensors as T
+from transformers import AutoModelForMaskedLM, AutoConfig, BertTokenizer


 # dataset
@@ -16,7 +16,6 @@ def fetch_idiom2def(ver: str) -> List[Tuple[str, str]]:
     tsv_path = artifact_path / "all.tsv"
     with open(tsv_path, 'r') as fh:
         reader = csv.reader(fh, delimiter="\t")
-        next(reader)
         return [
             (row[0], row[1])
             for row in reader
@@ -32,7 +31,7 @@ def fetch_idioms(ver: str) -> List[str]:
         reader = csv.reader(fh, delimiter="\t")
         next(reader)
         return [
-
+            row[0]
             for row in reader
         ]
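
The behavioural change hidden in the import shuffle: fetch_idiom2def no longer skips the first row of all.tsv (every line is treated as an (idiom, definition) pair), while fetch_idioms still skips a header and now keeps only the first column. A stripped-down sketch of the two readers, with hypothetical function names and a local path in place of the W&B artifact download:

import csv
from typing import List, Tuple

def read_idiom2def(tsv_path: str) -> List[Tuple[str, str]]:
    with open(tsv_path, 'r') as fh:
        reader = csv.reader(fh, delimiter="\t")
        # no next(reader): the file is assumed to carry no header row any more
        return [(row[0], row[1]) for row in reader]

def read_idioms(tsv_path: str) -> List[str]:
    with open(tsv_path, 'r') as fh:
        reader = csv.reader(fh, delimiter="\t")
        next(reader)  # this file still starts with a header row
        return [row[0] for row in reader]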
idiomify/models.py
CHANGED
@@ -3,9 +3,9 @@ The reverse dictionary models below are based off of: https://github.com/yhcc/Be
 """
 from typing import Tuple, List, Optional
 import torch
-import pytorch_lightning as pl
-from transformers.models.bert.modeling_bert import BertForMaskedLM
 from torch.nn import functional as F
+import pytorch_lightning as pl
+from transformers import BertForMaskedLM


 class RD(pl.LightningModule):
@@ -135,7 +135,7 @@ class RD(pl.LightningModule):

     def training_epoch_end(self, outputs: List[dict]) -> None:
         # to see an average performance over the batches in this specific epoch
-        avg_loss = torch.stack([output['loss'] for output in outputs]).mean()
+        avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
         self.log("Train/Average Loss", avg_loss)

     def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
@@ -146,7 +146,7 @@ class RD(pl.LightningModule):

     def validation_epoch_end(self, outputs: List[dict]) -> None:
         # to see an average performance over the batches in this specific epoch
-        avg_loss = torch.stack([output['loss'] for output in outputs]).mean()
+        avg_loss = torch.stack([output['loss'].detach() for output in outputs]).mean()
         self.log("Validation/Average Loss", avg_loss)

     def configure_optimizers(self) -> torch.optim.Optimizer:
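
Both epoch-end hooks now detach each batch loss before stacking, so the logged average no longer keeps every batch's autograd graph alive for the whole epoch. A tiny self-contained illustration (not repo code) of what the detach buys:

# illustrative only: three fake "batches" whose losses still carry autograd graphs
import torch

outputs = [{'loss': (torch.randn(4, requires_grad=True) ** 2).mean()} for _ in range(3)]

# without .detach(), the stacked average would stay attached to all three graphs
avg_loss = torch.stack([o['loss'].detach() for o in outputs]).mean()
print(avg_loss.requires_grad)  # False: safe to log, the graphs can be freed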
idiomify/tensors.py
CHANGED
@@ -3,18 +3,18 @@ all the functions for building tensors are defined here.
 builders must accept device as one of the parameters.
 """
 import torch
-from typing import List
+from typing import List, Tuple
 from transformers import BertTokenizer


 def idiom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
     mask_id = tokenizer.mask_token_id
     pad_id = tokenizer.pad_token_id
-    # temporarily disable single-token status of the
-
-    encodings = tokenizer(text=
+    # temporarily disable single-token status of the idioms
+    idioms = [idiom.split(" ") for idiom in idioms]
+    encodings = tokenizer(text=idioms,
                           add_special_tokens=False,
-                          # should set this to True, as we already have the
+                          # should set this to True, as we already have the idioms split.
                           is_split_into_words=True,
                           padding='max_length',
                           max_length=k,  # set to k
@@ -24,10 +24,11 @@ def idiom2subwords(idioms: List[str], tokenizer: BertTokenizer, k: int) -> torch
     return input_ids


-def inputs(
-
+def inputs(idiom2def: List[Tuple[str, str]], tokenizer: BertTokenizer, k: int) -> torch.Tensor:
+    defs = [definition for _, definition in idiom2def]
+    lefts = [" ".join(["[MASK]"] * k)] * len(defs)
     encodings = tokenizer(text=lefts,
-                          text_pair=
+                          text_pair=defs,
                           return_tensors="pt",
                           add_special_tokens=True,
                           truncation=True,
@@ -47,9 +48,9 @@ def inputs(definitions: List[str], tokenizer: BertTokenizer, k: int) -> torch.Te
                        desc_mask], dim=1)


-def targets(idioms: List[str]) -> torch.Tensor:
+def targets(idiom2def: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
     return torch.LongTensor([
         idioms.index(idiom)
-        for idiom in
+        for idiom, _ in idiom2def
     ])
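
inputs() now takes the (idiom, definition) pairs directly: each definition is paired with k [MASK] tokens as the first segment, and targets() maps each pair's idiom to its index in the idiom list. A small usage sketch of the new builders; the checkpoint name, k value, and example pair are placeholders, not values from the repo config:

# a minimal usage sketch under the new signatures
from transformers import BertTokenizer
from idiomify import tensors as T

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")   # example checkpoint
idiom2def = [("beat around the bush", "to avoid getting to the point")]
idioms = ["beat around the bush"]

X = T.inputs(idiom2def, tokenizer, k=5)   # segment A: five [MASK]s, segment B: the definition
y = T.targets(idiom2def, idioms)          # tensor([0]): index of each pair's idiom in the idiom list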
main_infer.py
CHANGED
@@ -1,6 +1,6 @@
 import argparse
-from idiomify.fetchers import fetch_config, fetch_idioms, fetch_rd
 from idiomify import tensors as T
+from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms
 from transformers import BertTokenizer


@@ -11,14 +11,15 @@ def main():
     parser.add_argument("--ver", type=str,
                         default="eng2eng")
     parser.add_argument("--sent", type=str,
-                        default="avoid getting to the point")
+                        default="to avoid getting to the point")
     args = parser.parse_args()
     config = fetch_config()[args.model][args.ver]
     config.update(vars(args))
-    tokenizer = BertTokenizer.from_pretrained(config['bert'])
     idioms = fetch_idioms(config['idioms_ver'])
-    X = T.inputs([config['sent']], tokenizer, config['k'])
     rd = fetch_rd(config['model'], config['ver'])
+    rd.eval()
+    tokenizer = BertTokenizer.from_pretrained(config['bert'])
+    X = T.inputs([("", config['sent'])], tokenizer, config['k'])
     probs = rd.P_wisdom(X).squeeze().tolist()
     wisdom2prob = [
         (wisdom, prob)
@@ -26,6 +27,7 @@ def main():
     ]
     # sort and append
     res = list(sorted(wisdom2prob, key=lambda x: x[1], reverse=True))
+    print(f"query: {config['sent']}")
     for idx, (idiom, prob) in enumerate(res):
         print(idx, idiom, prob)
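
The demo now loads the model, switches it to eval mode, and wraps the query sentence as a single ("", sentence) pair so it can reuse T.inputs. A condensed sketch of that flow; argparse is elided, the "alpha" model key is assumed, the zip over (idioms, probs) is a guess at how wisdom2prob is built, and torch.no_grad() is an extra safeguard not present in the script:

import torch
from transformers import BertTokenizer
from idiomify import tensors as T
from idiomify.fetchers import fetch_config, fetch_rd, fetch_idioms

model, ver, sent = "alpha", "eng2eng", "to avoid getting to the point"
config = fetch_config()[model][ver]
idioms = fetch_idioms(config['idioms_ver'])
rd = fetch_rd(model, ver)
rd.eval()                                            # turn off dropout for the demo
tokenizer = BertTokenizer.from_pretrained(config['bert'])
X = T.inputs([("", sent)], tokenizer, config['k'])   # the query plays the "definition" role
with torch.no_grad():
    probs = rd.P_wisdom(X).squeeze().tolist()
print(f"query: {sent}")
for idx, (idiom, prob) in enumerate(sorted(zip(idioms, probs), key=lambda x: x[1], reverse=True)):
    print(idx, idiom, prob)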
main_train.py
CHANGED
@@ -49,6 +49,7 @@ def main():
                          log_every_n_steps=config['log_every_n_steps'],
                          gpus=torch.cuda.device_count(),
                          default_root_dir=str(ROOT_DIR),
+                         enable_checkpointing=False,
                          logger=logger)
     # start training
     trainer.fit(model=rd, datamodule=datamodule)
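
The one-line change passes enable_checkpointing=False to the Lightning Trainer, so no .ckpt files are written during the run. A minimal sketch of a Trainer configured the same way; every argument other than that flag is a placeholder, and the flag requires pytorch-lightning >= 1.5 (older releases spelled it checkpoint_callback=False):

import torch
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=1,                    # placeholder value
    gpus=torch.cuda.device_count(),
    enable_checkpointing=False,      # skip writing .ckpt files during fit()
    logger=False,                    # the repo passes a W&B logger here instead
)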