PeteBleackley commited on
Commit
7a9be99
·
1 Parent(s): c8625dc

Modified CombinedCorpus to use PyTorch

Browse files
Files changed (1) hide show
  1. qarac/corpora/CombinedCorpus.py +20 -35
qarac/corpora/CombinedCorpus.py CHANGED
@@ -6,14 +6,11 @@ Created on Wed Sep 20 14:12:34 2023
6
  @author: peter
7
  """
8
 
9
- import itertools
10
  import collections
11
- import numpy
12
- import tensorflow
13
- import keras
14
  from qarac.corpora import CorpusLoader, CorpusRepeater
15
 
16
- class CombinedCorpus(keras.utils.Sequence):
17
 
18
  def __init__(self,tokenizer,**kwargs):
19
  """
@@ -82,23 +79,7 @@ class CombinedCorpus(keras.utils.Sequence):
82
  """
83
  return self.n_batches
84
 
85
- def __getitem__(self,n):
86
- """
87
- Retrieves a batch of data
88
-
89
- Parameters
90
- ----------
91
- n : int
92
- index of batch to retrieve
93
-
94
- Returns
95
- -------
96
- tupe(dict,dict)
97
- Batch of data
98
-
99
- """
100
-
101
- return self.batch(next(self.batches))
102
 
103
  def samples(self):
104
  """
@@ -123,14 +104,14 @@ class CombinedCorpus(keras.utils.Sequence):
123
  Y.update(y)
124
  yield (X,Y)
125
 
126
- def make_batches(self):
127
  batch = []
128
  n=0
129
  for sample in self.samples():
130
  batch.append(sample)
131
  n+=1
132
  if n==32:
133
- yield(batch)
134
  batch = []
135
  n=0
136
 
@@ -149,9 +130,9 @@ class CombinedCorpus(keras.utils.Sequence):
149
 
150
  Returns
151
  -------
152
- X : dict[str,tensorflow.Tensor]
153
  Batched input samples
154
- Y : dict[str,tensorflow.Tensor]
155
  Batched output samples
156
 
157
  """
@@ -167,12 +148,16 @@ class CombinedCorpus(keras.utils.Sequence):
167
 
168
  X={key:self.pad(value,self.max_lengths[key])
169
  for (key,value) in X.items()}
170
- Y={key:tensorflow.constant(value) if key=='consistency' else self.pad(value,
171
- self.max_lengths[key],
172
- False)
173
  for (key,value) in Y.items()}
174
- Y['question_answering'] = tensorflow.zeros((n,768))
175
- return (X,Y)
 
 
 
 
176
 
177
  def pad(self,batch,maxlen,inputs=True):
178
  """
@@ -191,12 +176,12 @@ class CombinedCorpus(keras.utils.Sequence):
191
  """
192
  for sample in batch:
193
  sample.pad(maxlen,pad_id=self.pad_token)
194
- input_ids = tensorflow.constant([sample.ids
195
- for sample in batch])
196
  result = input_ids
197
  if inputs:
198
- attention_mask = tensorflow.constant(numpy.not_equal(input_ids.numpy(),
199
- self.pad_token).astype(int))
200
  result = {'input_ids':input_ids,
201
  'attention_mask':attention_mask}
202
  return result
 
6
  @author: peter
7
  """
8
 
 
9
  import collections
10
+ import torch
 
 
11
  from qarac.corpora import CorpusLoader, CorpusRepeater
12
 
13
+ class CombinedCorpus(torch.utils.data.IterableDataset()):
14
 
15
  def __init__(self,tokenizer,**kwargs):
16
  """
 
79
  """
80
  return self.n_batches
81
 
82
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  def samples(self):
85
  """
 
104
  Y.update(y)
105
  yield (X,Y)
106
 
107
+ def __iter__(self):
108
  batch = []
109
  n=0
110
  for sample in self.samples():
111
  batch.append(sample)
112
  n+=1
113
  if n==32:
114
+ yield(self.batch(batch))
115
  batch = []
116
  n=0
117
 
 
130
 
131
  Returns
132
  -------
133
+ X : dict[str,torch.Tensor]
134
  Batched input samples
135
+ Y : dict[str,torch.Tensor]
136
  Batched output samples
137
 
138
  """
 
148
 
149
  X={key:self.pad(value,self.max_lengths[key])
150
  for (key,value) in X.items()}
151
+ Y={key:torch.tensor(value) if key=='consistency' else self.pad(value,
152
+ self.max_lengths[key],
153
+ False)
154
  for (key,value) in Y.items()}
155
+ Y['question_answering'] = torch.zeros((n,768))
156
+ return (X,tuple([Y[key]
157
+ for key in ('encode_decode',
158
+ 'question_answering',
159
+ 'reasoning',
160
+ 'consistency')]))
161
 
162
  def pad(self,batch,maxlen,inputs=True):
163
  """
 
176
  """
177
  for sample in batch:
178
  sample.pad(maxlen,pad_id=self.pad_token)
179
+ input_ids = torch.tensor([sample.ids
180
+ for sample in batch])
181
  result = input_ids
182
  if inputs:
183
+ attention_mask = torch.not_equal(input_ids,
184
+ self.pad_token)
185
  result = {'input_ids':input_ids,
186
  'attention_mask':attention_mask}
187
  return result