jamieb-nvs
commited on
Commit
•
caec285
1
Parent(s):
1366905
Adding option to keep uncropped input_ids as a feature
Browse files- geneformer/tokenizer.py +16 -13
geneformer/tokenizer.py
CHANGED
@@ -297,7 +297,7 @@ class TranscriptomeTokenizer:
|
|
297 |
|
298 |
return tokenized_cells, file_cell_metadata
|
299 |
|
300 |
-
def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
|
301 |
print("Creating dataset.")
|
302 |
# create dict for dataset creation
|
303 |
dataset_dict = {"input_ids": tokenized_cells}
|
@@ -312,21 +312,24 @@ class TranscriptomeTokenizer:
|
|
312 |
output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
|
313 |
else:
|
314 |
output_dataset = Dataset.from_dict(dataset_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
315 |
|
316 |
-
|
317 |
-
|
318 |
-
example[
|
319 |
-
return example
|
320 |
-
|
321 |
-
output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
|
322 |
|
323 |
-
# measure lengths of dataset
|
324 |
-
def measure_length(example):
|
325 |
-
example["length"] = len(example["input_ids"])
|
326 |
return example
|
327 |
|
328 |
-
|
329 |
-
|
|
|
330 |
)
|
|
|
|
|
331 |
|
332 |
-
|
|
|
297 |
|
298 |
return tokenized_cells, file_cell_metadata
|
299 |
|
300 |
+
def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False, keep_uncropped_input_ids=False):
|
301 |
print("Creating dataset.")
|
302 |
# create dict for dataset creation
|
303 |
dataset_dict = {"input_ids": tokenized_cells}
|
|
|
312 |
output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
|
313 |
else:
|
314 |
output_dataset = Dataset.from_dict(dataset_dict)
|
315 |
+
|
316 |
+
def format_cell_features(example):
|
317 |
+
# Store original uncropped input_ids in separate feature
|
318 |
+
if keep_uncropped_input_ids:
|
319 |
+
example['input_ids_uncropped'] = example['input_ids']
|
320 |
+
example['length_uncropped'] = len(example['input_ids'])
|
321 |
|
322 |
+
# Truncate/Crop input_ids to size 2,048
|
323 |
+
example['input_ids'] = example['input_ids'][0:2048]
|
324 |
+
example['length'] = len(example['input_ids'])
|
|
|
|
|
|
|
325 |
|
|
|
|
|
|
|
326 |
return example
|
327 |
|
328 |
+
output_dataset_truncated = output_dataset.map(
|
329 |
+
format_cell_features,
|
330 |
+
num_proc=self.nproc
|
331 |
)
|
332 |
+
return output_dataset_truncated
|
333 |
+
|
334 |
|
335 |
+
|