tokenizer-uncropped-input_ids

#275
Files changed (1) hide show
  1. geneformer/tokenizer.py +16 -13
geneformer/tokenizer.py CHANGED
@@ -297,7 +297,7 @@ class TranscriptomeTokenizer:
297
 
298
  return tokenized_cells, file_cell_metadata
299
 
300
- def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False):
301
  print("Creating dataset.")
302
  # create dict for dataset creation
303
  dataset_dict = {"input_ids": tokenized_cells}
@@ -312,21 +312,24 @@ class TranscriptomeTokenizer:
312
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
313
  else:
314
  output_dataset = Dataset.from_dict(dataset_dict)
 
 
 
 
 
 
315
 
316
- # truncate dataset
317
- def truncate(example):
318
- example["input_ids"] = example["input_ids"][:2048]
319
- return example
320
-
321
- output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
322
 
323
- # measure lengths of dataset
324
- def measure_length(example):
325
- example["length"] = len(example["input_ids"])
326
  return example
327
 
328
- output_dataset_truncated_w_length = output_dataset_truncated.map(
329
- measure_length, num_proc=self.nproc
 
330
  )
 
 
331
 
332
- return output_dataset_truncated_w_length
 
297
 
298
  return tokenized_cells, file_cell_metadata
299
 
300
+ def create_dataset(self, tokenized_cells, cell_metadata, use_generator=False, keep_uncropped_input_ids=False):
301
  print("Creating dataset.")
302
  # create dict for dataset creation
303
  dataset_dict = {"input_ids": tokenized_cells}
 
312
  output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
313
  else:
314
  output_dataset = Dataset.from_dict(dataset_dict)
315
+
316
+ def format_cell_features(example):
317
+ # Store original uncropped input_ids in separate feature
318
+ if keep_uncropped_input_ids:
319
+ example['input_ids_uncropped'] = example['input_ids']
320
+ example['length_uncropped'] = len(example['input_ids'])
321
 
322
+ # Truncate/Crop input_ids to size 2,048
323
+ example['input_ids'] = example['input_ids'][0:2048]
324
+ example['length'] = len(example['input_ids'])
 
 
 
325
 
 
 
 
326
  return example
327
 
328
+ output_dataset_truncated = output_dataset.map(
329
+ format_cell_features,
330
+ num_proc=self.nproc
331
  )
332
+ return output_dataset_truncated
333
+
334
 
335
+