eubinecto commited on
Commit
3be6142
·
1 Parent(s): e9d1a5a

[#1] SRCBuilder, TGTBuilder implemented and tested

Browse files
explore/explore_src_builder.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer
2
+ from idiomify.builders import SRCBuilder
3
+
4
+ BATCH = [
5
+ ("I could die at any moment", "I could kick the bucket at any moment"),
6
+ ("Speak plainly", "Don't beat around the bush")
7
+ ]
8
+
9
+
10
+ def main():
11
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
12
+ builder = SRCBuilder(tokenizer)
13
+ src = builder(BATCH)
14
+ print(src)
15
+
16
+
17
+ if __name__ == '__main__':
18
+ main()
explore/explore_tgt_builder.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import BartTokenizer
2
+ from idiomify.builders import TGTBuilder
3
+
4
+ BATCH = [
5
+ ("I could die at any moment", "I could kick the bucket at any moment"),
6
+ ("Speak plainly", "Don't beat around the bush")
7
+ ]
8
+
9
+
10
+ def main():
11
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
12
+ builder = TGTBuilder(tokenizer)
13
+ tgt_r, tgt = builder(BATCH)
14
+ print(tgt_r)
15
+ print(tgt)
16
+
17
+
18
+ if __name__ == '__main__':
19
+ main()
idiomify/builders.py CHANGED
@@ -45,40 +45,39 @@ class Idiom2SubwordsBuilder(TensorBuilder):
45
  return input_ids
46
 
47
 
48
- class Idiom2ContextBuilder(TensorBuilder):
49
-
50
- def __call__(self, idiom2context: List[Tuple[str, str]]):
51
- """
52
- Given a list of tuples of idiom and context,
53
- it returns a tensor of shape (batch_size, 3, max_seq_len)
54
- :param idiom2context: List[Tuple[str, str]], a list of tuples of idiom and context
55
- :type idiom2context: List[Tuple[str, str]]
56
- :return: The input_ids, token_type_ids, and attention_mask for each context.
57
- """
58
- contexts = [context for _, context in idiom2context]
59
- encodings = self.tokenizer(text=contexts,
60
  return_tensors="pt",
61
- add_special_tokens=True,
62
- truncation=True,
63
  padding=True,
64
- verbose=True)
65
- return torch.stack([encodings['input_ids'],
66
- encodings['token_type_ids'],
67
- encodings['attention_mask']], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
- class TargetsBuilder(TensorBuilder):
71
 
72
- def __call__(self, idiom2sent: List[Tuple[str, str]], idioms: List[str]) -> torch.Tensor:
73
- """
74
- Given a list of idioms and a list of sentences, return a list of indices of the idioms in the sentences
75
- :param idiom2sent: A list of tuples, where each tuple is an idiom and its corresponding sentence
76
- :type idiom2sent: List[Tuple[str, str]]
77
- :param idioms: A list of idioms
78
- :type idioms: List[str]
79
- :return: A tensor of indices of the idioms in the list of idioms.
80
- """
81
- return torch.LongTensor([
82
- idioms.index(idiom)
83
- for idiom, _ in idiom2sent
84
- ])
 
45
  return input_ids
46
 
47
 
48
+ class SRCBuilder(TensorBuilder):
49
+ """
50
+ to be used for both training and inference
51
+ """
52
+ def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> torch.Tensor:
53
+ encodings = self.tokenizer(text=[literal for literal, _ in literal2idiomatic],
 
 
 
 
 
 
54
  return_tensors="pt",
 
 
55
  padding=True,
56
+ truncation=True,
57
+ add_special_tokens=True)
58
+ src = torch.stack([encodings['input_ids'],
59
+ encodings['attention_mask']], dim=1) # (N, 2, L)
60
+ return src # (N, 2, L)
61
+
62
+
63
+ class TGTBuilder(TensorBuilder):
64
+ """
65
+ This is to be used only for training. As for inference, we don't need this.
66
+ """
67
+ def __call__(self, literal2idiomatic: List[Tuple[str, str]]) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ encodings_r = self.tokenizer([
69
+ self.tokenizer.bos_token + idiomatic # starts with bos, but does not end with eos (right-shifted)
70
+ for _, idiomatic in literal2idiomatic
71
+ ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
72
+ encodings = self.tokenizer([
73
+ idiomatic + self.tokenizer.eos_token # no bos, but ends with eos
74
+ for _, idiomatic in literal2idiomatic
75
+ ], return_tensors="pt", add_special_tokens=False, padding=True, truncation=True)
76
+ tgt_r = torch.stack([encodings_r['input_ids'],
77
+ encodings_r['attention_mask']], dim=1) # (N, 2, L)
78
+ tgt = torch.stack([encodings['input_ids'],
79
+ encodings['attention_mask']], dim=1) # (N, 2, L)
80
+ return tgt_r, tgt
81
 
82
 
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
idiomify/datamodules.py CHANGED
@@ -3,7 +3,7 @@ from typing import Tuple, Optional, List
3
  from torch.utils.data import Dataset, DataLoader
4
  from pytorch_lightning import LightningDataModule
5
  from idiomify.fetchers import fetch_idiom2def, fetch_epie
6
- from idiomify.builders import Idiom2DefBuilder, Idiom2ContextBuilder, TargetsBuilder
7
  from transformers import BertTokenizer
8
 
9
 
@@ -67,7 +67,7 @@ class Idiom2DefDataModule(LightningDataModule):
67
  # --- set up the builders --- #
68
  # build the datasets
69
  X = Idiom2DefBuilder(self.tokenizer)(self.idiom2def, self.config['k'])
70
- y = TargetsBuilder(self.tokenizer)(self.idiom2def, self.idioms)
71
  self.dataset = IdiomifyDataset(X, y)
72
 
73
  def train_dataloader(self) -> DataLoader:
@@ -107,7 +107,7 @@ class Idiom2ContextsDataModule(LightningDataModule):
107
  def setup(self, stage: Optional[str] = None):
108
  # build the datasets
109
  X = Idiom2ContextBuilder(self.tokenizer)(self.idiom2context)
110
- y = TargetsBuilder(self.tokenizer)(self.idiom2context, self.idioms)
111
  self.dataset = IdiomifyDataset(X, y)
112
 
113
  def train_dataloader(self):
 
3
  from torch.utils.data import Dataset, DataLoader
4
  from pytorch_lightning import LightningDataModule
5
  from idiomify.fetchers import fetch_idiom2def, fetch_epie
6
+ from idiomify.builders import Idiom2DefBuilder, Idiom2ContextBuilder, LabelsBuilder
7
  from transformers import BertTokenizer
8
 
9
 
 
67
  # --- set up the builders --- #
68
  # build the datasets
69
  X = Idiom2DefBuilder(self.tokenizer)(self.idiom2def, self.config['k'])
70
+ y = LabelsBuilder(self.tokenizer)(self.idiom2def, self.idioms)
71
  self.dataset = IdiomifyDataset(X, y)
72
 
73
  def train_dataloader(self) -> DataLoader:
 
107
  def setup(self, stage: Optional[str] = None):
108
  # build the datasets
109
  X = Idiom2ContextBuilder(self.tokenizer)(self.idiom2context)
110
+ y = LabelsBuilder(self.tokenizer)(self.idiom2context, self.idioms)
111
  self.dataset = IdiomifyDataset(X, y)
112
 
113
  def train_dataloader(self):