Commit 3a94209
1 Parent(s): 9169bfd
Fixed bug in gen_attention_mask with len > max_len (#158)
- Fixed bug in gen_attention_mask with len > max_len (7c77bae654e0d93a27e1988e107b3258902b3d05)
Co-authored-by: David Wen <[email protected]>
geneformer/in_silico_perturber.py CHANGED
@@ -342,7 +342,6 @@ def quant_cos_sims(model,
         max_range = min(i+forward_batch_size, total_batch_length)
 
         perturbation_minibatch = perturbation_batch.select([i for i in range(i, max_range)])
-
         # determine if need to pad or truncate batch
         minibatch_length_set = set(perturbation_minibatch["length"])
         minibatch_lengths = perturbation_minibatch["length"]
@@ -354,12 +353,14 @@ def quant_cos_sims(model,
 
         if needs_pad_or_trunc == True:
             max_len = min(max(minibatch_length_set),model_input_size)
+            print(max_len)
             def pad_or_trunc_example(example):
                 example["input_ids"] = pad_or_truncate_encoding(example["input_ids"],
                                                                 pad_token_id,
                                                                 max_len)
                 return example
             perturbation_minibatch = perturbation_minibatch.map(pad_or_trunc_example, num_proc=nproc)
+
         perturbation_minibatch.set_format(type="torch")
 
         input_data_minibatch = perturbation_minibatch["input_ids"]
@@ -570,6 +571,8 @@ def gen_attention_mask(minibatch_encoding, max_len = None):
     original_lens = minibatch_encoding["length"]
     attention_mask = [[1]*original_len
                       +[0]*(max_len - original_len)
+                      if original_len <= max_len
+                      else [1]*max_len
                       for original_len in original_lens]
     return torch.tensor(attention_mask).to("cuda")
 
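In the second hunk above, each minibatch is padded or truncated to max_len = min(max(minibatch_length_set), model_input_size) before batching. The helper below is only a stand-in sketch of what pad_or_truncate_encoding is assumed to do (the real implementation lives elsewhere in in_silico_perturber.py and may differ); it illustrates why encodings longer than max_len come back with exactly max_len tokens:

# Hypothetical stand-in for pad_or_truncate_encoding (assumed behavior, not the repository's code)
def pad_or_truncate_encoding(input_ids, pad_token_id, max_len):
    if len(input_ids) >= max_len:
        return input_ids[:max_len]                                   # truncate encodings longer than max_len
    return input_ids + [pad_token_id] * (max_len - len(input_ids))   # pad shorter encodings up to max_len

# Example with assumed values: pad_token_id = 0, max_len = 4
print(pad_or_truncate_encoding([5, 6], 0, 4))           # -> [5, 6, 0, 0]
print(pad_or_truncate_encoding([5, 6, 7, 8, 9], 0, 4))  # -> [5, 6, 7, 8]

Because over-length encodings are cut down to max_len, the attention mask built in gen_attention_mask has to be capped at max_len as well; that mismatch is what the third hunk fixes.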
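A minimal sketch (made-up lengths, plain Python, no torch or GPU required) of how the patched comprehension behaves when a sequence length exceeds max_len:

max_len = 4                  # assumed model input size for this minibatch
original_lens = [2, 4, 6]    # assumed sequence lengths; the 6 triggers the old bug

# Old logic: [1]*6 + [0]*(4 - 6) leaves the row at length 6, since [0]*(-2) is empty
old_mask = [[1]*n + [0]*(max_len - n) for n in original_lens]
assert [len(row) for row in old_mask] == [4, 4, 6]   # ragged rows, which torch.tensor() rejects

# Patched logic: rows for sequences longer than max_len become all-ones of length max_len
new_mask = [[1]*n + [0]*(max_len - n)
            if n <= max_len
            else [1]*max_len
            for n in original_lens]
assert [len(row) for row in new_mask] == [4, 4, 4]   # rectangular, matches the truncated input_ids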