---
license: cc
---
PhyloGPN is a convolutional neural network that takes encoded DNA sequences as input and outputs rate matrix parameters for [Felsenstein's 1981 model](https://en.wikipedia.org/wiki/Models_of_DNA_evolution#F81_model_(Felsenstein_1981)) (the F81 model, for short). It was trained to maximize the likelihood of columns in the [Zoonomia alignment](https://cglgenomics.ucsc.edu/november-2023-nature-zoonomia-with-expanded-primates-alignment/) given a phylogenetic tree. The stationary distribution of the substitution process described by the F81 model indicates the relative viability of each allele at any given locus. Because PhyloGPN thus assigns a probability distribution over the four nucleotides to every site, it is formally a genomic language model. It can be used for transfer learning.
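To make the link between rate parameters and allele viability concrete, here is a minimal plain-Python sketch. It assumes, without confirmation from this card, that the four log rates are the logarithms of the off-diagonal F81 rates \\(r_j\\), so that \\(Q_{ij} = r_j\\) for \\(i \neq j\\) (equivalently \\(r_j = \mu \pi_j\\) in the usual notation). Under that assumption the stationary distribution is \\(\pi_j = r_j / R\\) with \\(R = \sum_k r_k\\), i.e. a softmax over the log rates, and the transition probabilities along a branch of length \\(t\\) have the closed form \\(P_{ij}(t) = e^{-Rt} \delta_{ij} + (1 - e^{-Rt}) \pi_j\\). The helper names `f81_stationary` and `f81_transition` are hypothetical and not part of the PhyloGPN API.

```python
import math

NUCLEOTIDES = "ACGT"


def f81_stationary(log_rates):
    """Return (pi, R) for one site, given a dict of log rates keyed by nucleotide.

    Assumes (hypothetically) the F81 parameterization Q[i][j] = exp(log_rates[j])
    for i != j, so pi_j = r_j / R with R = sum_k r_k: a softmax of the log rates.
    """
    # Subtract the max before exponentiating for numerical stability;
    # the shift cancels in the ratio r_j / R.
    shift = max(log_rates[k] for k in NUCLEOTIDES)
    scaled = {k: math.exp(log_rates[k] - shift) for k in NUCLEOTIDES}
    norm = sum(scaled.values())
    pi = {k: scaled[k] / norm for k in NUCLEOTIDES}
    total_rate = norm * math.exp(shift)  # R = sum_k exp(log_rates[k])
    return pi, total_rate


def f81_transition(log_rates, t):
    """P(t)[a][b]: probability of observing base b after branch length t, starting from a.

    Uses the closed form P_ab(t) = exp(-R t) * [a == b] + (1 - exp(-R t)) * pi_b.
    """
    pi, total_rate = f81_stationary(log_rates)
    decay = math.exp(-total_rate * t)
    return {
        a: {b: (decay if a == b else 0.0) + (1.0 - decay) * pi[b] for b in NUCLEOTIDES}
        for a in NUCLEOTIDES
    }


# Hypothetical log rates for a single site (not real PhyloGPN output)
site = {"A": math.log(0.4), "C": math.log(0.1), "G": math.log(0.1), "T": math.log(0.4)}
pi, total_rate = f81_stationary(site)
print({k: round(p, 3) for k, p in pi.items()})  # {'A': 0.4, 'C': 0.1, 'G': 0.1, 'T': 0.4}
```

Under this assumption only the relative sizes of the rates determine \\(\pi\\); their sum \\(R\\) only controls how quickly a lineage forgets its starting base.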

The following Python snippet shows how to obtain embeddings and log rate parameters from PhyloGPN for each site in a batch of sequences. Note that PhyloGPN is designed as a sliding-window function: given a batch of \\(b\\) encoded sequences of a common length \\(\ell \geq 481\\), it yields outputs for the \\(\ell - 480\\) central positions of each sequence, that is, \\(b \times (\ell - 480)\\) positions in total.

```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "songlab/PhyloGPN"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

# Example data
seqs = [
    "TATAAA",
    "GGCCAATCT",
    "CACGTG",
    "AGGTCACGT",
    "GCCAGCC",
    "GGGGATTTCC"
]

# Output length is input length minus 480 (the receptive field size minus 1),
# so pad both ends of each sequence with 481 // 2 = 240 pad tokens to obtain
# exactly one output per original site
pad_token = tokenizer.pad_token
pad_size = 481 // 2

def pad_sequence(seq):
    return pad_token * pad_size + seq + pad_token * pad_size

padded_seqs = [pad_sequence(seq) for seq in seqs]
input_tensor = tokenizer(padded_seqs, return_tensors="pt", padding=True)["input_ids"]

with torch.no_grad():
    padded_embeddings = model.get_embeddings(input_tensor)
    # Dict keyed by "A", "C", "G", "T"; each value holds per-site log rate parameters for the F81 model
    padded_logits = model(input_tensor)

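embeddings = []
logits = []

# Because of the symmetric padding, the first len(seq) outputs in row i line up
# with the sites of seqs[i]; any remaining positions come from batch padding and are dropped
for i, seq in enumerate(seqs):
    length = len(seq)
    embeddings.append(padded_embeddings[i, :length])
    logits.append({k: padded_logits[k][i, :length] for k in "ACGT"})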
```
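The length bookkeeping in the snippet above can be checked without loading the model. The plain-Python sketch below assumes only what this card states: each output sees a window of 481 input positions and sits at its center, so output \\(j\\) corresponds to input position \\(j + 240\\). The helper names `num_outputs` and `input_window` are hypothetical and not part of the PhyloGPN API.

```python
RECEPTIVE_FIELD = 481  # from the model card: output length = input length - 480
HALF_WINDOW = RECEPTIVE_FIELD // 2  # 240 positions of context on each side of a site


def num_outputs(input_length):
    """Number of per-site outputs for one input sequence of the given length."""
    if input_length < RECEPTIVE_FIELD:
        raise ValueError(f"need at least {RECEPTIVE_FIELD} positions, got {input_length}")
    return input_length - (RECEPTIVE_FIELD - 1)


def input_window(output_index):
    """Half-open range [start, end) of input positions visible to an output,
    plus the input position it is centered on."""
    start = output_index
    end = output_index + RECEPTIVE_FIELD
    center = output_index + HALF_WINDOW
    return start, end, center


print(num_outputs(1000))  # 520
print(input_window(0))  # (0, 481, 240)
# A 6-base sequence padded with 240 positions on each side yields 6 outputs
print(num_outputs(6 + 2 * HALF_WINDOW))  # 6
```

Padding 240 positions on each side therefore yields exactly one output per original site, which is why the snippet keeps the first `len(seq)` outputs of each row.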