Soumic commited on
Commit
08f515b
·
1 Parent(s): c334cb2

:lady_beetle: Trying to repair git related issues

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. app.py +27 -2
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  lightning_logs/
2
- *.pth
 
 
1
  lightning_logs/
2
+ *.pth
3
+ my-awesome-model/
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAtte
13
  import torch
14
  from torch import nn
15
  from datasets import load_dataset
 
16
 
17
  timber = logging.getLogger()
18
  # logging.basicConfig(level=logging.DEBUG)
@@ -296,7 +297,9 @@ def create_conv_sequence(in_channel_num_of_nucleotides, num_filters, kernel_size
296
  return nn.Sequential(conv1d, activation, pooling)
297
 
298
 
299
- class Cnn1dClassifier(nn.Module):
 
 
300
  def __init__(self,
301
  seq_len,
302
  in_channel_num_of_nucleotides=4,
@@ -382,6 +385,7 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
382
  data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
383
 
384
  classifier_model = classifier_model #.to(DEVICE)
 
385
 
386
  classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
387
  m_optimizer=m_optimizer)
@@ -398,7 +402,21 @@ def start(classifier_model, model_save_path, is_attention_model=False, m_optimiz
398
  timber.info("\n\n")
399
  torch.save(classifier_module.state_dict(), model_save_path)
400
 
401
- trainer.push_to_hub("fahimfarhan/mqtl-classifier-model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
  # start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
404
  # start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
@@ -416,3 +434,10 @@ if __name__ == '__main__':
416
  dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
417
 
418
  pass
 
 
 
 
 
 
 
 
13
  import torch
14
  from torch import nn
15
  from datasets import load_dataset
16
+ from huggingface_hub import PyTorchModelHubMixin
17
 
18
  timber = logging.getLogger()
19
  # logging.basicConfig(level=logging.DEBUG)
 
297
  return nn.Sequential(conv1d, activation, pooling)
298
 
299
 
300
+ class Cnn1dClassifier(nn.Module,
301
+ PyTorchModelHubMixin
302
+ ):
303
  def __init__(self,
304
  seq_len,
305
  in_channel_num_of_nucleotides=4,
 
385
  data_module = MqtlDataModule(train_ds=train_dataset, val_ds=val_dataset, test_ds=test_dataset)
386
 
387
  classifier_model = classifier_model #.to(DEVICE)
388
+ classifier_model = classifier_model.from_pretrained("my-awesome-model")
389
 
390
  classifier_module = MQtlClassifierLightningModule(classifier=classifier_model, regularization=2,
391
  m_optimizer=m_optimizer)
 
402
  timber.info("\n\n")
403
  torch.save(classifier_module.state_dict(), model_save_path)
404
 
405
+ # save locally
406
+ classifier_model.save_pretrained("my-awesome-model")
407
+
408
+ # push to the hub
409
+ classifier_model.push_to_hub(repo_id="fahimfarhan/mqtl-classifier-model", commit_message=":tada: Push model using huggingface_hub")
410
+
411
+ # reload
412
+ model = classifier_model.from_pretrained("my-awesome-model")
413
+ # repo_url = "https://huggingface.co/fahimfarhan/mqtl-classifier-model"
414
+ #
415
+ # push_to_hub(
416
+ # model_file=classifier_model.file_name, # Replace with your model file path
417
+ # repo_url=repo_url,
418
+ # # config_file="config.json" # Optional, if you have a config file
419
+ # )
420
 
421
  # start_interpreting_ig_and_dl(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
422
  # start_interpreting_with_dlshap(classifier_model, WINDOW, dataset_folder_prefix=dataset_folder_prefix)
 
434
  dataset_folder_prefix="inputdata/", is_debug=True, max_epochs=3)
435
 
436
  pass
437
+
438
+
439
+ """
440
+ lightning_logs/
441
+ *.pth
442
+ my-awesome-model
443
+ """