Spaces:
Build error
Build error
PeteBleackley
committed on
Commit
·
7a9be99
1
Parent(s):
c8625dc
Modified CombinedCorpus to use PyTorch
Browse files
- qarac/corpora/CombinedCorpus.py +20 -35
qarac/corpora/CombinedCorpus.py
CHANGED
@@ -6,14 +6,11 @@ Created on Wed Sep 20 14:12:34 2023
|
|
6 |
@author: peter
|
7 |
"""
|
8 |
|
9 |
-
import itertools
|
10 |
import collections
|
11 |
-
import
|
12 |
-
import tensorflow
|
13 |
-
import keras
|
14 |
from qarac.corpora import CorpusLoader, CorpusRepeater
|
15 |
|
16 |
-
class CombinedCorpus(
|
17 |
|
18 |
def __init__(self,tokenizer,**kwargs):
|
19 |
"""
|
@@ -82,23 +79,7 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
82 |
"""
|
83 |
return self.n_batches
|
84 |
|
85 |
-
|
86 |
-
"""
|
87 |
-
Retrieves a batch of data
|
88 |
-
|
89 |
-
Parameters
|
90 |
-
----------
|
91 |
-
n : int
|
92 |
-
index of batch to retrieve
|
93 |
-
|
94 |
-
Returns
|
95 |
-
-------
|
96 |
-
tupe(dict,dict)
|
97 |
-
Batch of data
|
98 |
-
|
99 |
-
"""
|
100 |
-
|
101 |
-
return self.batch(next(self.batches))
|
102 |
|
103 |
def samples(self):
|
104 |
"""
|
@@ -123,14 +104,14 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
123 |
Y.update(y)
|
124 |
yield (X,Y)
|
125 |
|
126 |
-
def
|
127 |
batch = []
|
128 |
n=0
|
129 |
for sample in self.samples():
|
130 |
batch.append(sample)
|
131 |
n+=1
|
132 |
if n==32:
|
133 |
-
yield(batch)
|
134 |
batch = []
|
135 |
n=0
|
136 |
|
@@ -149,9 +130,9 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
149 |
|
150 |
Returns
|
151 |
-------
|
152 |
-
X : dict[str,
|
153 |
Batched input samples
|
154 |
-
Y : dict[str,
|
155 |
Batched output samples
|
156 |
|
157 |
"""
|
@@ -167,12 +148,16 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
167 |
|
168 |
X={key:self.pad(value,self.max_lengths[key])
|
169 |
for (key,value) in X.items()}
|
170 |
-
Y={key:
|
171 |
-
|
172 |
-
|
173 |
for (key,value) in Y.items()}
|
174 |
-
Y['question_answering'] =
|
175 |
-
return (X,Y
|
|
|
|
|
|
|
|
|
176 |
|
177 |
def pad(self,batch,maxlen,inputs=True):
|
178 |
"""
|
@@ -191,12 +176,12 @@ class CombinedCorpus(keras.utils.Sequence):
|
|
191 |
"""
|
192 |
for sample in batch:
|
193 |
sample.pad(maxlen,pad_id=self.pad_token)
|
194 |
-
input_ids =
|
195 |
-
|
196 |
result = input_ids
|
197 |
if inputs:
|
198 |
-
attention_mask =
|
199 |
-
|
200 |
result = {'input_ids':input_ids,
|
201 |
'attention_mask':attention_mask}
|
202 |
return result
|
|
|
6 |
@author: peter
|
7 |
"""
|
8 |
|
|
|
9 |
import collections
|
10 |
+
import torch
|
|
|
|
|
11 |
from qarac.corpora import CorpusLoader, CorpusRepeater
|
12 |
|
13 |
+
class CombinedCorpus(torch.utils.data.IterableDataset()):
|
14 |
|
15 |
def __init__(self,tokenizer,**kwargs):
|
16 |
"""
|
|
|
79 |
"""
|
80 |
return self.n_batches
|
81 |
|
82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
def samples(self):
|
85 |
"""
|
|
|
104 |
Y.update(y)
|
105 |
yield (X,Y)
|
106 |
|
107 |
+
def __iter__(self):
|
108 |
batch = []
|
109 |
n=0
|
110 |
for sample in self.samples():
|
111 |
batch.append(sample)
|
112 |
n+=1
|
113 |
if n==32:
|
114 |
+
yield(self.batch(batch))
|
115 |
batch = []
|
116 |
n=0
|
117 |
|
|
|
130 |
|
131 |
Returns
|
132 |
-------
|
133 |
+
X : dict[str,torch.Tensor]
|
134 |
Batched input samples
|
135 |
+
Y : dict[str,torch.Tensor]
|
136 |
Batched output samples
|
137 |
|
138 |
"""
|
|
|
148 |
|
149 |
X={key:self.pad(value,self.max_lengths[key])
|
150 |
for (key,value) in X.items()}
|
151 |
+
Y={key:torch.tensor(value) if key=='consistency' else self.pad(value,
|
152 |
+
self.max_lengths[key],
|
153 |
+
False)
|
154 |
for (key,value) in Y.items()}
|
155 |
+
Y['question_answering'] = torch.zeros((n,768))
|
156 |
+
return (X,tuple([Y[key]
|
157 |
+
for key in ('encode_decode',
|
158 |
+
'question_answering',
|
159 |
+
'reasoning',
|
160 |
+
'consistency')]))
|
161 |
|
162 |
def pad(self,batch,maxlen,inputs=True):
|
163 |
"""
|
|
|
176 |
"""
|
177 |
for sample in batch:
|
178 |
sample.pad(maxlen,pad_id=self.pad_token)
|
179 |
+
input_ids = torch.tensor([sample.ids
|
180 |
+
for sample in batch])
|
181 |
result = input_ids
|
182 |
if inputs:
|
183 |
+
attention_mask = torch.not_equal(input_ids,
|
184 |
+
self.pad_token)
|
185 |
result = {'input_ids':input_ids,
|
186 |
'attention_mask':attention_mask}
|
187 |
return result
|