probablybots commited on
Commit
09f8392
1 Parent(s): 0183e81

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -8
README.md CHANGED
@@ -59,8 +59,8 @@ mgen test --model SequenceClassification --model.backbone aido_dna_7b --data Seq
59
  ```python
60
  from modelgenerator.tasks import Embed
61
  model = Embed.from_config({"model.backbone": "aido_dna_7b"}).eval()
62
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
63
- embedding = model(collated_batch)
64
  print(embedding.shape)
65
  print(embedding)
66
  ```
@@ -69,8 +69,8 @@ print(embedding)
69
  import torch
70
  from modelgenerator.tasks import SequenceClassification
71
  model = SequenceClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 2}).eval()
72
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
73
- logits = model(collated_batch)
74
  print(logits)
75
  print(torch.argmax(logits, dim=-1))
76
  ```
@@ -79,8 +79,8 @@ print(torch.argmax(logits, dim=-1))
79
  import torch
80
  from modelgenerator.tasks import TokenClassification
81
  model = TokenClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 3}).eval()
82
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
83
- logits = model(collated_batch)
84
  print(logits)
85
  print(torch.argmax(logits, dim=-1))
86
  ```
@@ -88,8 +88,8 @@ print(torch.argmax(logits, dim=-1))
88
  ```python
89
  from modelgenerator.tasks import SequenceRegression
90
  model = SequenceRegression.from_config({"model.backbone": "aido_dna_7b"}).eval()
91
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
92
- logits = model(collated_batch)
93
  print(logits)
94
  ```
95
 
 
59
  ```python
60
  from modelgenerator.tasks import Embed
61
  model = Embed.from_config({"model.backbone": "aido_dna_7b"}).eval()
62
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
63
+ embedding = model(transformed_batch)
64
  print(embedding.shape)
65
  print(embedding)
66
  ```
 
69
  import torch
70
  from modelgenerator.tasks import SequenceClassification
71
  model = SequenceClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 2}).eval()
72
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
73
+ logits = model(transformed_batch)
74
  print(logits)
75
  print(torch.argmax(logits, dim=-1))
76
  ```
 
79
  import torch
80
  from modelgenerator.tasks import TokenClassification
81
  model = TokenClassification.from_config({"model.backbone": "aido_dna_7b", "model.n_classes": 3}).eval()
82
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
83
+ logits = model(transformed_batch)
84
  print(logits)
85
  print(torch.argmax(logits, dim=-1))
86
  ```
 
88
  ```python
89
  from modelgenerator.tasks import SequenceRegression
90
  model = SequenceRegression.from_config({"model.backbone": "aido_dna_7b"}).eval()
91
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
92
+ logits = model(transformed_batch)
93
  print(logits)
94
  ```
95