ctheodoris commited on
Commit
89fc737
1 Parent(s): 01313fd

extract emb_dims and input_size from config

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +3 -3
geneformer/perturber_utils.py CHANGED
@@ -156,12 +156,12 @@ def quant_layers(model):
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
- def get_model_embedding_dimensions(model):
160
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
161
 
162
 
163
  def get_model_input_size(model):
164
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
165
 
166
 
167
  def flatten_list(megalist):
 
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
+ def get_model_emb_dims(model):
160
+ return model.config.hidden_size
161
 
162
 
163
  def get_model_input_size(model):
164
+ return model.config.max_position_embeddings
165
 
166
 
167
  def flatten_list(megalist):