Christina Theodoris committed on
Commit
57b9778
1 Parent(s): acd253c

Update tokenizer to allow tokenization without custom cell attributes

Browse files
Files changed (1) hide show
  1. geneformer/tokenizer.py +21 -12
geneformer/tokenizer.py CHANGED
@@ -42,7 +42,7 @@ def tokenize_cell(gene_vector, gene_tokens):
42
  class TranscriptomeTokenizer:
43
  def __init__(
44
  self,
45
- custom_attr_name_dict,
46
  nproc=1,
47
  gene_median_file=GENE_MEDIAN_FILE,
48
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
@@ -52,7 +52,7 @@ class TranscriptomeTokenizer:
52
 
53
  Parameters
54
  ----------
55
- custom_attr_name_dict : dict
56
  Dictionary of custom attributes to be added to the dataset.
57
  Keys are the names of the attributes in the loom file.
58
  Values are the names of the attributes in the dataset.
@@ -106,8 +106,9 @@ class TranscriptomeTokenizer:
106
 
107
  def tokenize_files(self, loom_data_directory):
108
  tokenized_cells = []
109
- loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
110
- cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
 
111
 
112
  # loops through directories to tokenize .loom files
113
  for loom_file_path in loom_data_directory.glob("*.loom"):
@@ -116,15 +117,19 @@ class TranscriptomeTokenizer:
116
  loom_file_path
117
  )
118
  tokenized_cells += file_tokenized_cells
119
- for k in loom_cell_attr:
120
- cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
 
 
 
121
 
122
  return tokenized_cells, cell_metadata
123
 
124
  def tokenize_file(self, loom_file_path):
125
- file_cell_metadata = {
126
- attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
127
- }
 
128
 
129
  with lp.connect(str(loom_file_path)) as data:
130
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
@@ -181,15 +186,19 @@ class TranscriptomeTokenizer:
181
  ]
182
 
183
  # add custom attributes for subview to dict
184
- for k in file_cell_metadata.keys():
185
- file_cell_metadata[k] += subview.ca[k].tolist()
 
 
 
186
 
187
  return tokenized_cells, file_cell_metadata
188
 
189
  def create_dataset(self, tokenized_cells, cell_metadata):
190
  # create dict for dataset creation
191
  dataset_dict = {"input_ids": tokenized_cells}
192
- dataset_dict.update(cell_metadata)
 
193
 
194
  # create dataset
195
  output_dataset = Dataset.from_dict(dataset_dict)
 
42
  class TranscriptomeTokenizer:
43
  def __init__(
44
  self,
45
+ custom_attr_name_dict=None,
46
  nproc=1,
47
  gene_median_file=GENE_MEDIAN_FILE,
48
  token_dictionary_file=TOKEN_DICTIONARY_FILE,
 
52
 
53
  Parameters
54
  ----------
55
+ custom_attr_name_dict : None, dict
56
  Dictionary of custom attributes to be added to the dataset.
57
  Keys are the names of the attributes in the loom file.
58
  Values are the names of the attributes in the dataset.
 
106
 
107
  def tokenize_files(self, loom_data_directory):
108
  tokenized_cells = []
109
+ if self.custom_attr_name_dict is not None:
110
+ loom_cell_attr = [attr_key for attr_key in self.custom_attr_name_dict.keys()]
111
+ cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.values()}
112
 
113
  # loops through directories to tokenize .loom files
114
  for loom_file_path in loom_data_directory.glob("*.loom"):
 
117
  loom_file_path
118
  )
119
  tokenized_cells += file_tokenized_cells
120
+ if self.custom_attr_name_dict is not None:
121
+ for k in loom_cell_attr:
122
+ cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]
123
+ else:
124
+ cell_metadata = None
125
 
126
  return tokenized_cells, cell_metadata
127
 
128
  def tokenize_file(self, loom_file_path):
129
+ if self.custom_attr_name_dict is not None:
130
+ file_cell_metadata = {
131
+ attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
132
+ }
133
 
134
  with lp.connect(str(loom_file_path)) as data:
135
  # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
 
186
  ]
187
 
188
  # add custom attributes for subview to dict
189
+ if self.custom_attr_name_dict is not None:
190
+ for k in file_cell_metadata.keys():
191
+ file_cell_metadata[k] += subview.ca[k].tolist()
192
+ else:
193
+ file_cell_metadata = None
194
 
195
  return tokenized_cells, file_cell_metadata
196
 
197
  def create_dataset(self, tokenized_cells, cell_metadata):
198
  # create dict for dataset creation
199
  dataset_dict = {"input_ids": tokenized_cells}
200
+ if self.custom_attr_name_dict is not None:
201
+ dataset_dict.update(cell_metadata)
202
 
203
  # create dataset
204
  output_dataset = Dataset.from_dict(dataset_dict)