import math

import torch

from esm_utils import get_latents


def mask_for_de_novo(sequence_length):
    """Return a fully masked sequence of the requested length."""
    return "<mask>" * sequence_length

def generate_de_novo(sequence_length, tokenizer, model):
    """Generate a sequence of the given length by unmasking every position in one pass."""
    masked_sequence = mask_for_de_novo(sequence_length)
    inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    pred_tokens = []
    for i in range(len(mask_token_indices)):
        # Sample from the top-3 tokens at each masked position.
        topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
        probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
        predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
        predicted_token_id = topk_indices[predicted_index].item()
        predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)
        # Fall back to glycine if a special token decodes to an empty string,
        # mirroring generate_scaffold and keeping the sequence length fixed.
        pred_tokens.append('G' if predicted_token == '' else predicted_token)

    generated_sequence = ''.join(pred_tokens)
    perplexity = calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices)

    return (generated_sequence, perplexity)


def mask_for_scaffold(sequence, generate_type):
    """Mask the residues to be redesigned: with "uppercase", uppercase residues are masked
    and the remaining context is upcased; with "lowercase", lowercase residues are masked
    and the context is left as-is."""
    if generate_type == "uppercase":
        sequence = ''.join(["<mask>" if residue.isupper() else residue.upper() for residue in sequence])
    elif generate_type == "lowercase":
        sequence = ''.join(["<mask>" if residue.islower() else residue for residue in sequence])
    return sequence
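
# Illustrative example of the masking convention above, for the sequence "aBcD":
#   mask_for_scaffold("aBcD", "uppercase")  ->  "A<mask>C<mask>"
#   mask_for_scaffold("aBcD", "lowercase")  ->  "<mask>B<mask>D"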


def generate_scaffold(sequence, generate_type, tokenizer, model):
    """Fill in the masked positions of a scaffold sequence and return the completed
    sequence along with the masked token indices."""
    masked_sequence = mask_for_scaffold(sequence, generate_type)
    inputs = tokenizer(masked_sequence, return_tensors='pt').to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    pred_tokens = []
    for i in range(len(mask_token_indices)):
        # Sample from the top-3 tokens at each masked position.
        topk_logits, topk_indices = logits_at_masks[i].topk(k=3, dim=-1)
        probabilities = torch.nn.functional.softmax(topk_logits, dim=-1)
        predicted_index = torch.distributions.categorical.Categorical(probabilities).sample()
        predicted_token_id = topk_indices[predicted_index].item()
        predicted_token = tokenizer.decode([predicted_token_id], skip_special_tokens=True)

        # Fall back to glycine if a special token decodes to an empty string.
        pred_tokens.append('G' if predicted_token == '' else predicted_token)

    # Splice the predicted residues back into the scaffold, one mask at a time.
    generated_sequence = masked_sequence
    for token in pred_tokens:
        generated_sequence = generated_sequence.replace("<mask>", token, 1)

    return generated_sequence, mask_token_indices


def calculate_perplexity(model, tokenizer, generated_sequence, mask_token_indices):
    """Pseudo-perplexity of the generated residues: re-mask each generated position in
    turn, average the masked-LM loss over those positions, and exponentiate."""
    total_loss = 0.0
    tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)

    for i in mask_token_indices:
        masked_input = tensor_input.clone()
        masked_input[0, i] = tokenizer.mask_token_id

        # Score only the re-masked position; all other labels are ignored (-100).
        labels = torch.full(tensor_input.shape, -100).to(model.device)
        labels[0, i] = tensor_input[0, i]

        with torch.no_grad():
            outputs = model(masked_input, labels=labels)
            total_loss += outputs.loss.item()

    num_mask_tokens = len(mask_token_indices)
    if num_mask_tokens == 0:
        # Sentinel value for sequences with nothing to score.
        perplexity = 10000
    else:
        avg_loss = total_loss / num_mask_tokens
        perplexity = math.exp(avg_loss)

    return perplexity
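
# Illustrative example: per-position losses of 2.0, 2.5, and 1.5 average to 2.0,
# giving a pseudo-perplexity of exp(2.0) ≈ 7.39.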


def calculate_cosine_sim(original_sequence, generated_sequence, tokenizer, esm_model, device):
    """Mean per-position cosine similarity between the ESM latents of the original and generated sequences."""
    og_embeddings = get_latents(esm_model, tokenizer, original_sequence.upper(), device)
    new_embeddings = get_latents(esm_model, tokenizer, generated_sequence, device)

    sequence_similarity = torch.nn.functional.cosine_similarity(og_embeddings, new_embeddings, dim=-1)
    cosine_similarity = torch.mean(sequence_similarity).item()
    return cosine_similarity
    

def calculate_hamming_dist(original_sequence, generated_sequence):
    """Number of positions at which the sequences differ (case-insensitive); assumes equal lengths."""
    generated_sequence = generated_sequence.upper()
    original_sequence = original_sequence.upper()
    return sum(original_sequence[i] != generated_sequence[i] for i in range(len(original_sequence)))
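

# Minimal usage sketch (illustrative): the checkpoint name below is an assumption; any
# ESM-2-style masked-LM checkpoint whose tokenizer uses "<mask>" should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    checkpoint = "facebook/esm2_t33_650M_UR50D"  # assumed checkpoint, for demonstration only
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForMaskedLM.from_pretrained(checkpoint)
    model.eval()

    # De novo generation: sample a 50-residue sequence and report its pseudo-perplexity.
    sequence, perplexity = generate_de_novo(50, tokenizer, model)
    print(sequence, perplexity)

    # Scaffold generation: redesign the lowercase residues of a partially specified sequence.
    scaffold = "MKTayvAWncLLqRST"  # hypothetical input; lowercase positions are regenerated
    generated, mask_idx = generate_scaffold(scaffold, "lowercase", tokenizer, model)
    print(generated)
    print("hamming distance:", calculate_hamming_dist(scaffold, generated))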