Commit
•
c90d791
1
Parent(s):
57f02a4
Add function to get number of model embeddings (#364)
Browse files- Add function to get number of model embeddings (01313fd5b2416d697ae4d21db9f1ec80c91ab8bc)
- extract emb_dims and input_size from config (89fc737b00596ae9b4f226da087e872022bf5a26)
Co-authored-by: Han Chen <[email protected]>
geneformer/perturber_utils.py
CHANGED
@@ -156,8 +156,12 @@ def quant_layers(model):
|
|
156 |
return int(max(layer_nums)) + 1
|
157 |
|
158 |
|
|
|
|
|
|
|
|
|
159 |
def get_model_input_size(model):
|
160 |
-
return
|
161 |
|
162 |
|
163 |
def flatten_list(megalist):
|
|
|
156 |
return int(max(layer_nums)) + 1
|
157 |
|
158 |
|
159 |
+
def get_model_emb_dims(model):
|
160 |
+
return model.config.hidden_size
|
161 |
+
|
162 |
+
|
163 |
def get_model_input_size(model):
|
164 |
+
return model.config.max_position_embeddings
|
165 |
|
166 |
|
167 |
def flatten_list(megalist):
|