Add function to get number of model embeddings
Browse files
geneformer/perturber_utils.py
CHANGED
@@ -156,6 +156,10 @@ def quant_layers(model):
|
|
156 |
return int(max(layer_nums)) + 1
|
157 |
|
158 |
|
|
|
|
|
|
|
|
|
159 |
def get_model_input_size(model):
|
160 |
return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
161 |
|
|
|
156 |
return int(max(layer_nums)) + 1
|
157 |
|
158 |
|
159 |
+
def get_model_embedding_dimensions(model):
|
160 |
+
return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[2].strip().replace(")", ""))
|
161 |
+
|
162 |
+
|
163 |
def get_model_input_size(model):
|
164 |
return int(re.split("\(|,", str(model.bert.embeddings.position_embeddings))[1])
|
165 |
|