ctheodoris hchen725 commited on
Commit
c90d791
1 Parent(s): 57f02a4

Add function to get number of model embeddings (#364)

Browse files

- Add function to get number of model embeddings (01313fd5b2416d697ae4d21db9f1ec80c91ab8bc)
- extract emb_dims and input_size from config (89fc737b00596ae9b4f226da087e872022bf5a26)


Co-authored-by: Han Chen <[email protected]>

Files changed (1) hide show
  1. geneformer/perturber_utils.py +5 -1
geneformer/perturber_utils.py CHANGED
@@ -156,8 +156,12 @@ def quant_layers(model):
156
  return int(max(layer_nums)) + 1
157
 
158
 
 
 
 
 
159
  def get_model_input_size(model):
160
- return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
161
 
162
 
163
  def flatten_list(megalist):
 
156
  return int(max(layer_nums)) + 1
157
 
158
 
159
+ def get_model_emb_dims(model):
160
+ return model.config.hidden_size
161
+
162
+
163
  def get_model_input_size(model):
164
+ return model.config.max_position_embeddings
165
 
166
 
167
  def flatten_list(megalist):