hchen725 committed on
Commit
be7ceb5
1 Parent(s): 3b78204

Update geneformer/tokenizer.py

Browse files
Files changed (1) hide show
  1. geneformer/tokenizer.py +27 -22
geneformer/tokenizer.py CHANGED
@@ -63,6 +63,16 @@ logger = logging.getLogger(__name__)
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  def rank_genes(gene_vector, gene_tokens):
68
  """
@@ -100,15 +110,15 @@ def sum_ensembl_ids(
100
  assert (
101
  "ensembl_id" in data.ra.keys()
102
  ), "'ensembl_id' column missing from data.ra.keys()"
 
 
 
103
  gene_ids_in_dict = [
104
  gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
105
  ]
106
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
107
- token_genes_unique = True
108
- else:
109
- token_genes_unique = False
110
  if collapse_gene_ids is False:
111
- if token_genes_unique:
 
112
  return data_directory
113
  else:
114
  raise ValueError("Error: data Ensembl IDs non-unique.")
@@ -116,13 +126,11 @@ def sum_ensembl_ids(
116
  gene_ids_collapsed = [
117
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
118
  ]
119
- gene_ids_collapsed_in_dict = [
120
- gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
121
- ]
122
 
123
- if (
124
- len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
125
- ) and token_genes_unique:
 
126
  return data_directory
127
  else:
128
  dedup_filename = data_directory.with_name(
@@ -198,28 +206,25 @@ def sum_ensembl_ids(
198
  assert (
199
  "ensembl_id" in data.var.columns
200
  ), "'ensembl_id' column missing from data.var"
 
 
 
201
  gene_ids_in_dict = [
202
  gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
203
  ]
204
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
205
- token_genes_unique = True
206
- else:
207
- token_genes_unique = False
208
  if collapse_gene_ids is False:
209
- if token_genes_unique:
 
210
  return data
211
  else:
212
  raise ValueError("Error: data Ensembl IDs non-unique.")
213
 
 
214
  gene_ids_collapsed = [
215
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
216
  ]
217
- gene_ids_collapsed_in_dict = [
218
- gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
219
- ]
220
- if (
221
- len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))
222
- ) and token_genes_unique:
223
  return data
224
 
225
  else:
 
63
 
64
  from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
65
 
66
+ def rename_attr(data_ra_or_ca, old_name, new_name):
67
+ """ Rename attributes
68
+ Args:
69
+ data_ra_or_ca: loom row attributes (data.ra) or column attributes (data.ca)
70
+ old_name (str): old name of attribute
71
+ new_name (str): new name of attribute
72
+ """
73
+ data_ra_or_ca[new_name] = data_ra_or_ca[old_name]
74
+ if new_name != old_name:
75
+ del data_ra_or_ca[old_name]
76
 
77
  def rank_genes(gene_vector, gene_tokens):
78
  """
 
110
  assert (
111
  "ensembl_id" in data.ra.keys()
112
  ), "'ensembl_id' column missing from data.ra.keys()"
113
+
114
+ # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
115
+ # The IDs are compared to gene_token_dict as-is here; no mapping steps are performed
116
  gene_ids_in_dict = [
117
  gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
118
  ]
 
 
 
 
119
  if collapse_gene_ids is False:
120
+
121
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
122
  return data_directory
123
  else:
124
  raise ValueError("Error: data Ensembl IDs non-unique.")
 
126
  gene_ids_collapsed = [
127
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
128
  ]
 
 
 
129
 
130
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed)):
131
+ # Keep original Ensembl IDs as `ensembl_id_original`
132
+ rename_attr(data.ra, "ensembl_id", "ensembl_id_original")
133
+ data.ra["ensembl_id"] = gene_ids_collapsed
134
  return data_directory
135
  else:
136
  dedup_filename = data_directory.with_name(
 
206
  assert (
207
  "ensembl_id" in data.var.columns
208
  ), "'ensembl_id' column missing from data.var"
209
+
210
+ # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
211
+ # The IDs are compared to gene_token_dict as-is here; no mapping steps are performed
212
  gene_ids_in_dict = [
213
  gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
214
  ]
 
 
 
 
215
  if collapse_gene_ids is False:
216
+
217
+ if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
218
  return data
219
  else:
220
  raise ValueError("Error: data Ensembl IDs non-unique.")
221
 
222
+ # Check for when collapse_gene_ids is True
223
  gene_ids_collapsed = [
224
  gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
225
  ]
226
+ if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed)):
227
+ data.var.ensembl_id = data.var.ensembl_id.map(gene_mapping_dict)
 
 
 
 
228
  return data
229
 
230
  else: