Spaces:
Runtime error
Runtime error
Soumic
commited on
Commit
·
08f515b
1
Parent(s):
c334cb2
:lady_beetle: Trying to repair git related issues
Browse files- .gitignore +2 -1
- app.py +27 -2
.gitignore
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
lightning_logs/
|
2 |
-
*.pth
|
|
|
|
1 |
lightning_logs/
|
2 |
+
*.pth
|
3 |
+
my-awesome-model/
|
app.py
CHANGED
@@ -13,6 +13,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAtte
|
|
13 |
import torch
|
14 |
from torch import nn
|
15 |
from datasets import load_dataset
|
|
|
16 |
|
17 |
timber = logging.getLogger()
|
18 |
# logging.basicConfig(level=logging.DEBUG)
|
@@ -296,7 +297,9 @@ def create_conv_sequence(in_channel_num_of_nucleotides, num_filters, kernel_size
|
|
296 |
return nn.Sequential(conv1d, activation, pooling)
|
297 |
|
298 |
|
299 |
-
class Cnn1dClassifier(nn.Module
|
|
|
|
|
300 |
def __init__(self,
|
301 |
seq_len,
|
302 |
in_channel_num_of_nucleotides=4,
|
@@ -382,6 +385,7 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
|
|
382 |
data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
|
383 |
|
384 |
classifier_model = classifier_model #.to(DEVICE)
|
|
|
385 |
|
386 |
classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
|
387 |
m_optimizer=m_optimizer)
|
@@ -398,7 +402,21 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
|
|
398 |
timber.info("\n\n")
|
399 |
torch.save(classifier_module.state_dict(), model_save_path)
|
400 |
|
401 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
402 |
|
403 |
# start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
|
404 |
# start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
|
@@ -416,3 +434,10 @@ if __name__ == '__main__':
|
|
416 |
dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
|
417 |
|
418 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
import torch
|
14 |
from torch import nn
|
15 |
from datasets import load_dataset
|
16 |
+
from huggingface_hub import PyTorchModelHubMixin
|
17 |
|
18 |
timber = logging.getLogger()
|
19 |
# logging.basicConfig(level=logging.DEBUG)
|
|
|
297 |
return nn.Sequential(conv1d, activation, pooling)
|
298 |
|
299 |
|
300 |
+
class Cnn1dClassifier(nn.Module,
|
301 |
+
PyTorchModelHubMixin
|
302 |
+
):
|
303 |
def __init__(self,
|
304 |
seq_len,
|
305 |
in_channel_num_of_nucleotides=4,
|
|
|
385 |
data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
|
386 |
|
387 |
classifier_model = classifier_model #.to(DEVICE)
|
388 |
+
classifier_model = classifier_model.from_pretrained("my-awesome-model")
|
389 |
|
390 |
classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
|
391 |
m_optimizer=m_optimizer)
|
|
|
402 |
timber.info("\n\n")
|
403 |
torch.save(classifier_module.state_dict(), model_save_path)
|
404 |
|
405 |
+
# save locally
|
406 |
+
classifier_model.save_pretrained("my-awesome-model")
|
407 |
+
|
408 |
+
# push to the hub
|
409 |
+
classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model", commit_message=":tada: Push model using huggingface_hub")
|
410 |
+
|
411 |
+
# reload
|
412 |
+
model = classifier_model.from_pretrained("my-awesome-model")
|
413 |
+
# repo_url = "https://huggingface.co/fahimfarhan/mqtl-classifier-model"
|
414 |
+
#
|
415 |
+
# push_to_hub(
|
416 |
+
# model_file=classifier_model.file_name, # Replace with your model file path
|
417 |
+
# repo_url=repo_url,
|
418 |
+
# # config_file="config.json" # Optional, if you have a config file
|
419 |
+
# )
|
420 |
|
421 |
# start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
|
422 |
# start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
|
|
|
434 |
dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
|
435 |
|
436 |
pass
|
437 |
+
|
438 |
+
|
439 |
+
"""
|
440 |
+
lightning_logs/
|
441 |
+
*.pth
|
442 |
+
my-awesome-model
|
443 |
+
"""
|