---
license: cc
---

PhyloGPN is a convolutional neural network that takes encoded DNA sequences as input and outputs rate matrix parameters for [Felsenstein's 1981 model](https://en.wikipedia.org/wiki/Models_of_DNA_evolution#F81_model_(Felsenstein_1981)) (the F81 model, for short). It was trained to maximize the likelihood of columns in the [Zoonomia alignment](https://cglgenomics.ucsc.edu/november-2023-nature-zoonomia-with-expanded-primates-alignment/) given a phylogenetic tree. The stationary distribution of the substitution process described by the F81 model indicates the relative viability of each allele at any given locus. As a result, PhyloGPN is formally a genomic language model. It can be used for transfer learning.
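
The link between F81 rate parameters and the stationary distribution can be illustrated with a minimal, standard-library-only sketch. It assumes a parameterization in which each off-diagonal rate matrix entry equals the exponentiated log rate of the target nucleotide; this is an assumption about how the outputs might be interpreted, not a statement about PhyloGPN's internals. The helper names and the example values are hypothetical.

```python
import math

NUCLEOTIDES = "ACGT"


def f81_rate_matrix(log_rates):
    """Build a 4x4 F81-style rate matrix from log rate parameters.

    Assumption: the off-diagonal entry Q[i][j] equals exp(log_rates[j]),
    and each diagonal entry makes its row sum to zero.
    """
    rates = [math.exp(log_rates[n]) for n in NUCLEOTIDES]
    q = [[0.0] * 4 for _ in range(4)]
    for i in range(4):
        for j in range(4):
            if i != j:
                q[i][j] = rates[j]
        q[i][i] = -sum(q[i])  # diagonal is still 0.0 when the sum is taken
    return q


def stationary_distribution(log_rates):
    """Normalize the rates with a numerically stable softmax."""
    m = max(log_rates.values())
    weights = {n: math.exp(log_rates[n] - m) for n in NUCLEOTIDES}
    total = sum(weights.values())
    return {n: w / total for n, w in weights.items()}


# Hypothetical log rate parameters for a single site
log_rates = {"A": 0.5, "C": -1.0, "G": 0.0, "T": -0.3}
pi = stationary_distribution(log_rates)
q = f81_rate_matrix(log_rates)

# Stationarity check: the row vector pi times Q should be numerically zero
residual = [
    sum(pi[NUCLEOTIDES[i]] * q[i][j] for i in range(4)) for j in range(4)
]
print(pi)
print(residual)
```

Under this assumption, applying a softmax over the four log rate parameters at a site gives that site's stationary nucleotide distribution.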
The following Python snippet shows how to obtain embeddings and log rate parameters from PhyloGPN for each site in a batch of sequences. Note that PhyloGPN is designed as a sliding window function: it takes a batch of \\(b\\) encoded sequences of any given length \\(\ell \geq 481\\) as input and yields outputs for the \\(b \times (\ell - 480)\\) central positions.
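
The window arithmetic described above can be checked with a small standalone sketch. The function names are hypothetical, and the `"-"` pad character is only a placeholder for illustration; in practice the pad token should come from the tokenizer.

```python
RECEPTIVE_FIELD = 481  # minimum input length stated in the text


def output_length(input_length, receptive_field=RECEPTIVE_FIELD):
    """Number of central positions that receive an output."""
    if input_length < receptive_field:
        raise ValueError(
            f"input length {input_length} is shorter than {receptive_field}"
        )
    return input_length - (receptive_field - 1)


def pad_to_keep_length(seq, pad_char="-", receptive_field=RECEPTIVE_FIELD):
    """Pad both sides so that every original base gets an output.

    pad_char is a placeholder; the real pad token comes from the tokenizer.
    """
    half = receptive_field // 2  # 240 positions on each side
    return pad_char * half + seq + pad_char * half


padded = pad_to_keep_length("TATAAA")
print(len(padded), output_length(len(padded)))
```

Padding each side by half the receptive field (rounded down) adds exactly 480 positions, so the output length matches the length of the unpadded sequence.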
```python
import torch
from transformers import AutoModel, AutoTokenizer

checkpoint = "songlab/PhyloGPN"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True)

# Example data
seqs = [
    "TATAAA",
    "GGCCAATCT",
    "CACGTG",
    "AGGTCACGT",
    "GCCAGCC",
    "GGGGATTTCC",
]

# Output length is input length minus 480 (the receptive field size minus 1),
# so pad each side with 240 pad tokens to get one output per original base
pad_token = tokenizer.pad_token
pad_size = 481 // 2


def pad_sequence(seq):
    return pad_token * pad_size + seq + pad_token * pad_size


padded_seqs = [pad_sequence(seq) for seq in seqs]
# padding=True pads shorter sequences so the batch has a common length
input_tensor = tokenizer(padded_seqs, return_tensors="pt", padding=True)["input_ids"]

with torch.no_grad():
    padded_embeddings = model.get_embeddings(input_tensor)
    padded_logits = model(input_tensor)  # These are log rate parameters for the F81 model

# Keep only the first `length` outputs of each sequence; the rest come from batch padding
embeddings = []
logits = []

for i in range(len(seqs)):
    length = len(seqs[i])
    embeddings.append(padded_embeddings[i, :length])
    logits.append({})

    for k in "ACGT":
        logits[-1][k] = padded_logits[k][i, :length]
```