[#2] evaluating m-1-2 works. config.yaml simplified.
- config.yaml +12 -13
- explore/explore_bart_logits_shape.py +1 -1
- explore/explore_idiomifydatamodule.py +1 -1
- idiomify/{data.py → datamodules.py} +1 -1
- idiomify/fetchers.py +2 -2
- idiomify/models.py +3 -21
- idiomify/pipeline.py +24 -0
- main_eval.py +5 -5
- main_infer.py +4 -4
- main_train.py +2 -2
- main_upload_idioms.py +1 -1
- main_upload_literal2idiomatic.py +1 -1
config.yaml
CHANGED

```diff
@@ -1,20 +1,19 @@
+idiomifier:
   ver: m-1-2
   desc: just overfitting the model, but on the entire PIE dataset.
   bart: facebook/bart-base
   lr: 0.0001
   literal2idiomatic_ver: d-1-2
-  max_epochs: …
+  max_epochs: 2
-  batch_size: …
+  batch_size: 40
   shuffle: true
 
-# for building & uploading datasets or …
-seed: 104
+# for building & uploading datasets or tokenizer
+idioms:
+  ver: d-1-2
+  description: the set of idioms in the training set of literal2idiomatic_d-1-2.
+literal2idiomatic:
+  ver: d-1-2
+  description: PIE data split into train & test set (80 / 20 split). There is no validation set because I don't intend to do any hyperparameter tuning on this thing.
+  train_ratio: 0.8
+  seed: 104
```
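The three top-level blocks map one-to-one onto the entry points below. A minimal sketch of what `fetch_config` presumably does (assumption: it simply parses `config.yaml` from the repo root; the real function lives in `idiomify/fetchers.py`):

```python
# hedged sketch of fetch_config, assuming it just parses config.yaml
# from the project root; the real implementation is in idiomify/fetchers.py
import yaml


def fetch_config() -> dict:
    with open("config.yaml") as fh:
        return yaml.safe_load(fh)


# each entry point then picks its own block:
#   fetch_config()['idiomifier']         # train / eval / infer scripts
#   fetch_config()['idioms']             # main_upload_idioms.py
#   fetch_config()['literal2idiomatic']  # main_upload_literal2idiomatic.py
```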
explore/explore_bart_logits_shape.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 from transformers import BartTokenizer, BartForConditionalGeneration
 
-from data import IdiomifyDataModule
+from datamodules import IdiomifyDataModule
 
 
 CONFIG = {
```
explore/explore_idiomifydatamodule.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 from transformers import BartTokenizer
-from idiomify.data import IdiomifyDataModule
+from idiomify.datamodules import IdiomifyDataModule
 
 
 CONFIG = {
```
idiomify/{data.py → datamodules.py}
RENAMED

```diff
@@ -84,6 +84,6 @@ class IdiomifyDataModule(LightningDataModule):
         return DataLoader(self.train_dataset, batch_size=self.config['batch_size'],
                           shuffle=self.config['shuffle'], num_workers=self.config['num_workers'])
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         return DataLoader(self.test_dataset, batch_size=self.config['batch_size'],
                           shuffle=False, num_workers=self.config['num_workers'])
```
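For context, the renamed module stays config-driven end to end. A hedged sketch of constructing it outside a training script (the keys mirror the `idiomifier` block of config.yaml; passing `None` instead of a wandb Run, and calling the standard `prepare_data`/`setup` hooks directly, are assumptions):

```python
# hedged sketch: driving the datamodule by hand; run=None is an assumption,
# main_eval.py always passes a live wandb Run
from transformers import BartTokenizer
from idiomify.datamodules import IdiomifyDataModule

config = {"literal2idiomatic_ver": "d-1-2", "batch_size": 40,
          "shuffle": True, "num_workers": 2}
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
datamodule = IdiomifyDataModule(config, tokenizer, None)
datamodule.prepare_data()  # standard LightningDataModule hooks
datamodule.setup()
srcs, tgts_r, tgts = next(iter(datamodule.test_dataloader()))  # shapes per models.py: (N, 2, L)
```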
idiomify/fetchers.py
CHANGED

```diff
@@ -53,9 +53,9 @@ def fetch_idiomifier(ver: str, run: Run = None) -> Idiomifier:
     The current Idiomifier then turns into a pipeline.
     """
     if run:
-        artifact = run.use_artifact(f"…
+        artifact = run.use_artifact(f"idiomifier:{ver}", type="model")
     else:
-        artifact = wandb.Api().artifact(f"eubinecto/idiomify/…
+        artifact = wandb.Api().artifact(f"eubinecto/idiomify/idiomifier:{ver}", type="model")
     config = artifact.metadata
     artifact_dir = artifact.download(root=seq2seq_dir(ver))
     ckpt_path = path.join(artifact_dir, "model.ckpt")
```
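With the checkpoint now published under the single artifact name `idiomifier`, grabbing a version is one call (this mirrors main_infer.py below):

```python
from idiomify.fetchers import fetch_idiomifier

model = fetch_idiomifier("m-1-2")  # resolves eubinecto/idiomify/idiomifier:m-1-2
model.eval()  # "this is crucial" (main_infer.py): disables dropout before generation
```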
idiomify/models.py
CHANGED

```diff
@@ -48,19 +48,19 @@ class Idiomifier(pl.LightningModule):  # noqa
         "loss": loss
     }
 
-    def on_train_batch_end(self, outputs: dict, **kwargs):
+    def on_train_batch_end(self, outputs: dict, *args, **kwargs):
         self.log("Train/Loss", outputs['loss'])
 
     def on_train_epoch_end(self, *args, **kwargs) -> None:
         self.log("Train/Accuracy", self.acc_train.compute())
         self.acc_train.reset()
 
-    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], **kwargs):
+    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], *args, **kwargs):
         srcs, tgts_r, tgts = batch  # (N, 2, L_s), (N, 2, L_t), (N, 2, L_t)
         logits = self.forward(srcs, tgts_r).transpose(1, 2)  # ... -> (N, L, |V|) -> (N, |V|, L)
         self.acc_test.update(logits.detach(), target=tgts.detach())
 
-    def on_test_epoch_end(self, **kwargs) -> None:
+    def on_test_epoch_end(self, *args, **kwargs) -> None:
         self.log("Test/Accuracy", self.acc_test.compute())
         self.acc_test.reset()

@@ -72,21 +72,3 @@ class Idiomifier(pl.LightningModule):  # noqa
         # The authors used Adam, so we might as well use it as well.
         return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
-
-
-# for inference
-class Pipeline:
-
-    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
-        self.model = model
-        self.builder = SourcesBuilder(tokenizer)
-
-    def __call__(self, src: str, max_length=100) -> str:
-        srcs = self.builder(literal2idiomatic=[(src, "")])
-        pred_ids = self.model.bart.generate(
-            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
-            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
-            decoder_start_token_id=self.model.hparams['bos_token_id'],
-            max_length=max_length,
-        ).squeeze()  # -> (N, L_t) -> (L_t)
-        tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
-        return tgt
```
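The signature fixes look cosmetic but are load-bearing: PyTorch Lightning invokes these hooks with extra positional arguments (e.g. `batch`, `batch_idx`), and a `**kwargs`-only override cannot absorb a positional extra. A minimal illustration:

```python
# why *args matters: a positional extra breaks a **kwargs-only signature
def old_hook(outputs: dict, **kwargs):         # only absorbs keyword extras
    return outputs["loss"]


def new_hook(outputs: dict, *args, **kwargs):  # also absorbs positional extras
    return outputs["loss"]


new_hook({"loss": 0.1}, 3)    # ok: a batch_idx-style positional lands in *args
# old_hook({"loss": 0.1}, 3)  # TypeError: takes 1 positional argument but 2 were given
```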
idiomify/pipeline.py
ADDED

```diff
@@ -0,0 +1,24 @@
+
+# for inference
+from transformers import BartTokenizer
+
+from builders import SourcesBuilder
+from models import Idiomifier
+
+
+class Pipeline:
+
+    def __init__(self, model: Idiomifier, tokenizer: BartTokenizer):
+        self.model = model
+        self.builder = SourcesBuilder(tokenizer)
+
+    def __call__(self, src: str, max_length=100) -> str:
+        srcs = self.builder(literal2idiomatic=[(src, "")])
+        pred_ids = self.model.bart.generate(
+            inputs=srcs[:, 0],  # (N, 2, L) -> (N, L)
+            attention_mask=srcs[:, 1],  # (N, 2, L) -> (N, L)
+            decoder_start_token_id=self.model.hparams['bos_token_id'],
+            max_length=max_length,
+        ).squeeze()  # -> (N, L_t) -> (L_t)
+        tgt = self.builder.tokenizer.decode(pred_ids, skip_special_tokens=True)
+        return tgt
```
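A hedged usage sketch of the new module. The import path assumes `idiomify.pipeline` (main_infer.py as committed imports `Pipeline` from `idiomify.models`), and the input sentence is made up:

```python
from transformers import BartTokenizer

from idiomify.fetchers import fetch_idiomifier
from idiomify.pipeline import Pipeline  # assumption: the package exposes the new module

model = fetch_idiomifier("m-1-2")
model.eval()  # crucial: disables dropout before generation
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
idiomify = Pipeline(model, tokenizer)
print(idiomify("I decided to give the project another try."))  # -> an idiomatic paraphrase
```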
main_eval.py
CHANGED

```diff
@@ -5,22 +5,22 @@ import wandb
 import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from transformers import BartTokenizer
-from idiomify.data import IdiomifyDataModule
+from idiomify.datamodules import IdiomifyDataModule
 from idiomify.fetchers import fetch_config, fetch_idiomifier
-from paths import ROOT_DIR
+from idiomify.paths import ROOT_DIR
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--num_workers", type=int, default=os.cpu_count())
+    parser.add_argument("--fast_dev_run", action="store_true", default=False)
     args = parser.parse_args()
-    config = fetch_config()['…']
+    config = fetch_config()['idiomifier']
     config.update(vars(args))
-    # prepare the model
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
     # prepare the datamodule
     with wandb.init(entity="eubinecto", project="idiomify", config=config) as run:
-        model = fetch_idiomifier(config['ver'], run)
+        model = fetch_idiomifier(config['ver'], run)  # fetch a pre-trained model
         datamodule = IdiomifyDataModule(config, tokenizer, run)
         logger = WandbLogger(log_model=False)
         trainer = pl.Trainer(fast_dev_run=config['fast_dev_run'],
```
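To make the data flow explicit, the merged config that drives the evaluation run looks roughly like this (values taken from config.yaml and the argparse defaults above):

```python
import os

# sketch of config after fetch_config()['idiomifier'] + config.update(vars(args))
config = {
    # from the idiomifier block of config.yaml
    "ver": "m-1-2", "bart": "facebook/bart-base", "lr": 0.0001,
    "literal2idiomatic_ver": "d-1-2", "max_epochs": 2,
    "batch_size": 40, "shuffle": True,
    # merged in from the CLI defaults
    "num_workers": os.cpu_count(), "fast_dev_run": False,
}
```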
main_infer.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 import argparse
-from idiomify.models import …
+from idiomify.models import Pipeline
 from idiomify.fetchers import fetch_config, fetch_idiomifier
 from transformers import BartTokenizer
 
@@ -10,14 +10,14 @@ def main():
                         default="If there's any good to loosing my job,"
                                 " it's that I'll now be able to go to school full-time and finish my degree earlier.")
     args = parser.parse_args()
-    config = fetch_config()['…']
+    config = fetch_config()['idiomifier']
     config.update(vars(args))
     model = fetch_idiomifier(config['ver'])
     model.eval()  # this is crucial
     tokenizer = BartTokenizer.from_pretrained(config['bart'])
-
+    pipeline = Pipeline(model, tokenizer)
     src = config['src']
-    tgt = …
+    tgt = pipeline(src=config['src'])
     print(src, "\n->", tgt)
```
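Presumably the script is then run as `python3 main_infer.py` to idiomify the default sentence, or with `--src "..."` to idiomify arbitrary input (the flag name is inferred from the `config['src']` lookup above); it prints the source followed by `->` and the idiomified paraphrase.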
main_train.py
CHANGED

```diff
@@ -6,7 +6,7 @@ import pytorch_lightning as pl
 from termcolor import colored
 from pytorch_lightning.loggers import WandbLogger
 from transformers import BartTokenizer, BartForConditionalGeneration
-from idiomify.data import IdiomifyDataModule
+from idiomify.datamodules import IdiomifyDataModule
 from idiomify.fetchers import fetch_config
 from idiomify.models import Idiomifier
 from idiomify.paths import ROOT_DIR

@@ -19,7 +19,7 @@ def main():
     parser.add_argument("--fast_dev_run", action="store_true", default=False)
     parser.add_argument("--upload", dest='upload', action='store_true', default=False)
     args = parser.parse_args()
-    config = fetch_config()['…']
+    config = fetch_config()['idiomifier']
     config.update(vars(args))
     if not config['upload']:
         print(colored("WARNING: YOU CHOSE NOT TO UPLOAD. NOTHING BUT LOGS WILL BE SAVED TO WANDB", color="red"))
```
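As the warning suggests, `python3 main_train.py` alone only sends logs to wandb; pass `--upload` to also push the trained checkpoint as the `idiomifier:{ver}` artifact, and `--fast_dev_run` for a single-batch smoke test.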
main_upload_idioms.py
CHANGED

```diff
@@ -9,7 +9,7 @@ from idiomify.paths import ROOT_DIR
 
 
 def main():
-    config = fetch_config()['…']
+    config = fetch_config()['idioms']
     train_df, _ = fetch_literal2idiomatic(config['ver'])
     idioms = train_df['Idiom'].tolist()
     idioms = list(set(idioms))
```
main_upload_literal2idiomatic.py
CHANGED

```diff
@@ -12,7 +12,7 @@ def main():
 
     # here, we use all of them, while splitting them into train & test
     pie_df = fetch_pie()
-    config = fetch_config()['…']
+    config = fetch_config()['literal2idiomatic']
     train_df, test_df = pie_df.pipe(cleanse)\
         .pipe(upsample, seed=config['seed'])\
         .pipe(stratified_split, ratio=config['train_ratio'], seed=config['seed'])
```
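The `seed` and `train_ratio` consumed here are the ones now nested under `literal2idiomatic` in config.yaml. A hedged sketch of what `stratified_split` plausibly does, assuming stratification is per idiom so that every idiom appears in both splits (the `Idiom` column name is taken from main_upload_idioms.py; the real helper lives alongside `cleanse` and `upsample`):

```python
# hedged sketch of a per-idiom stratified 80/20 split; not the actual implementation
import pandas as pd


def stratified_split(df: pd.DataFrame, ratio: float, seed: int):
    train_parts, test_parts = [], []
    for _, group in df.groupby("Idiom"):
        shuffled = group.sample(frac=1, random_state=seed)  # shuffle within each idiom
        cut = int(len(shuffled) * ratio)                    # e.g. 0.8 -> 80/20
        train_parts.append(shuffled.iloc[:cut])
        test_parts.append(shuffled.iloc[cut:])
    return pd.concat(train_parts), pd.concat(test_parts)
```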