import torch
from transformers import AutoTokenizer, AutoModel


def load_esm2_model(model_name):
    """Load an ESM-2 checkpoint and its matching tokenizer from the Hugging Face Hub."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model


def get_latents(model, tokenizer, sequence):
    """Return per-residue latent representations for a single protein sequence."""
    inputs = tokenizer(sequence, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Drop the batch dimension: shape (sequence length incl. special tokens, hidden size)
    return outputs.last_hidden_state.squeeze(0)
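

# Example usage (a minimal sketch): "facebook/esm2_t6_8M_UR50D" is one of the
# publicly released ESM-2 checkpoints and is assumed here purely for illustration;
# any ESM-2 model name on the Hugging Face Hub should work. The sequence below is
# a toy amino-acid string, not real data.
if __name__ == "__main__":
    tokenizer, model = load_esm2_model("facebook/esm2_t6_8M_UR50D")
    latents = get_latents(model, tokenizer, "MKTAYIAKQR")
    # ESM-2 tokenization adds CLS and EOS tokens, so the first dimension is
    # sequence length + 2; the second is the model's hidden size.
    print(latents.shape)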