hchen725 commited on
Commit
01313fd
1 Parent(s): 57f02a4

Add function to get number of model embeddings

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +4 -0
geneformer/perturber_utils.py CHANGED
@@ -156,6 +156,10 @@ def quant_layers(model):
156
  return int(max(layer_nums)) + 1
157
 
158
 
 
 
 
 
159
  def get_model_input_size(model):
160
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
161
 
 
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
+ def get_model_embedding_dimensions(model):
160
+ return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
161
+
162
+
163
  def get_model_input_size(model):
164
  return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
165