import torch | |
from transformers import AutoTokenizer, AutoModel | |
def load_esm2_model(model_name):
    """Load a pretrained tokenizer and model for the given identifier.

    Args:
        model_name: Identifier passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModel.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.
    """
    # Tuple elements are evaluated left to right, so the tokenizer is still
    # loaded before the model.
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )
def get_latents(model, tokenizer, sequence):
    """Return the model's final hidden states for ``sequence``.

    The sequence is tokenized into PyTorch tensors, run through ``model``
    with gradient tracking disabled, and the ``last_hidden_state`` of the
    output is returned with its leading dimension squeezed.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``last_hidden_state`` tensor.
            If it exposes a ``device`` attribute, the inputs are moved
            there before the forward pass.
        tokenizer: Callable invoked as ``tokenizer(sequence,
            return_tensors="pt")`` that returns a mapping of tensors.
        sequence: Input handed to the tokenizer unchanged.

    Returns:
        ``last_hidden_state.squeeze(0)``; the tensor stays on whatever
        device the model produced it on. Dimension 0 is only removed when
        it has size 1.
        NOTE(review): presumably (tokens, hidden_dim) for a single
        sequence, special tokens included -- confirm against the tokenizer.
    """
    inputs = tokenizer(sequence, return_tensors="pt")

    # Tokenizers produce CPU tensors; a model that was moved to another
    # device would otherwise fail with a device-mismatch error.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.last_hidden_state.squeeze(0)