probablybots commited on
Commit
e16f512
·
verified ·
1 Parent(s): 4dfc52f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -8
README.md CHANGED
@@ -18,8 +18,8 @@ mgen test --model SequenceClassification --model.backbone aido_dna_300m --data S
18
  ```python
19
  from modelgenerator.tasks import Embed
20
  model = Embed.from_config({"model.backbone": "aido_dna_300m"}).eval()
21
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
22
- embedding = model(collated_batch)
23
  print(embedding.shape)
24
  print(embedding)
25
  ```
@@ -28,8 +28,8 @@ print(embedding)
28
  import torch
29
  from modelgenerator.tasks import SequenceClassification
30
  model = SequenceClassification.from_config({"model.backbone": "aido_dna_300m", "model.n_classes": 2}).eval()
31
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
32
- logits = model(collated_batch)
33
  print(logits)
34
  print(torch.argmax(logits, dim=-1))
35
  ```
@@ -38,8 +38,8 @@ print(torch.argmax(logits, dim=-1))
38
  import torch
39
  from modelgenerator.tasks import TokenClassification
40
  model = TokenClassification.from_config({"model.backbone": "aido_dna_300m", "model.n_classes": 3}).eval()
41
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
42
- logits = model(collated_batch)
43
  print(logits)
44
  print(torch.argmax(logits, dim=-1))
45
  ```
@@ -47,7 +47,7 @@ print(torch.argmax(logits, dim=-1))
47
  ```python
48
  from modelgenerator.tasks import SequenceRegression
49
  model = SequenceRegression.from_config({"model.backbone": "aido_dna_300m"}).eval()
50
- collated_batch = model.collate({"sequences": ["ACGT", "AGCT"]})
51
- logits = model(collated_batch)
52
  print(logits)
53
  ```
 
18
  ```python
19
  from modelgenerator.tasks import Embed
20
  model = Embed.from_config({"model.backbone": "aido_dna_300m"}).eval()
21
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
22
+ embedding = model(transformed_batch)
23
  print(embedding.shape)
24
  print(embedding)
25
  ```
 
28
  import torch
29
  from modelgenerator.tasks import SequenceClassification
30
  model = SequenceClassification.from_config({"model.backbone": "aido_dna_300m", "model.n_classes": 2}).eval()
31
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
32
+ logits = model(transformed_batch)
33
  print(logits)
34
  print(torch.argmax(logits, dim=-1))
35
  ```
 
38
  import torch
39
  from modelgenerator.tasks import TokenClassification
40
  model = TokenClassification.from_config({"model.backbone": "aido_dna_300m", "model.n_classes": 3}).eval()
41
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
42
+ logits = model(transformed_batch)
43
  print(logits)
44
  print(torch.argmax(logits, dim=-1))
45
  ```
 
47
  ```python
48
  from modelgenerator.tasks import SequenceRegression
49
  model = SequenceRegression.from_config({"model.backbone": "aido_dna_300m"}).eval()
50
+ transformed_batch = model.transform({"sequences": ["ACGT", "AGCT"]})
51
+ logits = model(transformed_batch)
52
  print(logits)
53
  ```