File size: 10,899 Bytes
1335bda
 
 
 
 
 
5d3f7a9
1335bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9103754
1335bda
9103754
1335bda
9103754
1335bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d3f7a9
1335bda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53fe6a0
1335bda
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import transformers
from transformers import PreTrainedTokenizerFast
import tranception
import datasets
from tranception import config, model_pytorch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

# Module-level tokenizer, loaded from the tokenizer file at the relative path below
# and configured with the special tokens listed here.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    unk_token="[UNK]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    mask_token="[MASK]",
)
#######################################################################################################################################
###############################################  HELPER FUNCTIONS  ####################################################################
#######################################################################################################################################

AA_vocab = "ACDEFGHIKLMNPQRSTVWY"
def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
  """Enumerate every single-residue substitution of `sequence` inside a position range.

  Args:
    sequence: Starting protein sequence, one character per residue.
    AA_vocab: Iterable of candidate replacement residues.
    mutation_range_start: 1-indexed first position to mutate (inclusive). None means position 1.
    mutation_range_end: 1-indexed last position to mutate (inclusive). None means len(sequence).

  Returns:
    A pandas DataFrame with two columns:
      'mutant': label made of the original residue, its 1-indexed position in the
        full sequence, and the replacement residue (e.g. "C2A").
      'mutated_sequence': the full sequence with that single substitution applied.
  """
  mutation_range_start = 1 if mutation_range_start is None else int(mutation_range_start)
  mutation_range_end = len(sequence) if mutation_range_end is None else int(mutation_range_end)
  sequence_list = list(sequence)
  first_index = mutation_range_start-1
  rows = []
  # Enumerate with the absolute 0-based offset so that both the substituted residue and the
  # position in the label refer to the full sequence, not to the slice being iterated.
  for index,current_AA in enumerate(sequence[first_index:mutation_range_end],start=first_index):
    for mutated_AA in AA_vocab:
      if current_AA!=mutated_AA:
        mutated_sequence = sequence_list.copy()
        mutated_sequence[index] = mutated_AA
        rows.append((current_AA+str(index+1)+mutated_AA,"".join(mutated_sequence)))
  return pd.DataFrame(rows,columns=['mutant','mutated_sequence'])

def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
  """Draw the substitution-score heatmap and save it as a PNG.

  Args:
    scores: DataFrame with columns 'mutant', 'position', 'target_AA' and 'avg_score'.
    sequence: Starting protein sequence the mutants were derived from.
    AA_vocab: Iterable of residues; one heatmap row label set is built per residue.
    mutation_range_start: 1-indexed first position shown (inclusive). None means position 1.
    mutation_range_end: 1-indexed last position shown (inclusive). None means len(sequence).

  Returns:
    The `matplotlib.pyplot` module, with the heatmap drawn on a newly created figure.
    The figure is also written to 'fitness_scoring_substitution_matrix.png'.
  """
  if mutation_range_start is None: mutation_range_start=1
  # The original assigned mutation_range_start here, leaving the end as None.
  if mutation_range_end is None: mutation_range_end=len(sequence)
  piv=scores.pivot(index='position',columns='target_AA',values='avg_score').transpose().round(4)
  # NOTE(review): figure width scales with the full sequence length, not with the
  # number of positions in the mutation range -- confirm this is intended.
  fig, ax = plt.subplots(figsize=(len(sequence)*1.2,20))
  # One pass to index scores by mutant label instead of filtering the frame per cell.
  score_by_mutant = dict(zip(scores['mutant'],scores['avg_score']))
  cell_labels = []
  for target_AA in AA_vocab:
    for position in range(mutation_range_start,mutation_range_end+1):
      mutant = sequence[position-1]+str(position)+target_AA
      # Cells with no scored mutant (e.g. the unchanged residue) are labelled 0.0.
      value = float(score_by_mutant.get(mutant,0.0))
      cell_labels.append("{} \n {:.4f}".format(mutant,value))
  # Assumes the rows of `piv` follow AA_vocab order and its columns cover every
  # position in the range, so the label grid lines up with the data -- TODO confirm.
  labels = np.asarray(cell_labels).reshape(len(AA_vocab),mutation_range_end-mutation_range_start+1)
  heat = sns.heatmap(piv,annot=labels,fmt="",cmap='RdYlGn',linewidths=0.30,ax=ax,
              vmin=np.percentile(scores.avg_score,2),vmax=np.percentile(scores.avg_score,98),
              cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'})
  # The last axes of the figure is the colorbar; enlarge its label.
  heat.figure.axes[-1].yaxis.label.set_size(20)
  heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
  heat.set_xlabel("Sequence position", fontsize = 20)
  heat.set_ylabel("Amino Acid mutation", fontsize = 20)
  plt.savefig('fitness_scoring_substitution_matrix.png')
  return plt

def suggest_mutations(scores):
  """Build the plain-text mutation recommendations.

  Args:
    scores: DataFrame with at least the columns 'mutant', 'position' and 'avg_score'.

  Returns:
    A string listing up to five mutants with the highest 'avg_score', followed by up to
    five positions ranked by the mean 'avg_score' of their positive-scoring mutants.
  """
  intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
  #Best mutants
  top_mutants=list(scores.sort_values(by=['avg_score'],ascending=False).head(5).mutant)
  mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))
  #Best positions
  positive_scores = scores[scores.avg_score > 0]
  # Average the score column only: a frame-wide mean would also try to average the
  # string columns (e.g. 'mutant'), which fails on recent pandas versions.
  positive_scores_position_avg = positive_scores.groupby('position')['avg_score'].mean()
  top_positions=list(positive_scores_position_avg.sort_values(ascending=False).head(5).index.astype(str))
  position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))
  return intro_message+mutant_recos+position_recos

def get_mutated_protein(sequence,mutant):
  """Apply a single substitution to `sequence` and return the resulting string.

  `mutant` is read as: first character ignored, middle characters the 1-indexed
  position, last character the replacement residue (e.g. "C2A").
  """
  residues = [*sequence]
  target_index = int(mutant[1:-1]) - 1
  replacement = mutant[-1]
  residues[target_index] = replacement
  return "".join(residues)

def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
  """Score all single substitutions of `sequence` in a range and build the outputs.

  Args:
    sequence: Starting protein sequence.
    mutation_range_start: 1-indexed first position to mutate (inclusive). None means position 1.
    mutation_range_end: 1-indexed last position to mutate (inclusive). None means
      len(sequence); values past the end of the sequence are clamped to it.
    model_type: One of "Small", "Medium" or "Large"; selects the pretrained checkpoint.
    scoring_mirror: Forwarded to `model.score_mutants`.
    batch_size_inference: Forwarded to `model.score_mutants`.
    num_workers: Forwarded to `model.score_mutants`.
    AA_vocab: Iterable of candidate replacement residues.

  Returns:
    A tuple (heatmap, recommendations): the value returned by
    `create_scoring_matrix_visual` and the text returned by `suggest_mutations`.

  Raises:
    ValueError: If `model_type` is not a known size, or the mutation range is empty
      or starts before position 1.
  """
  checkpoints = {
    "Small": "PascalNotin/Tranception_Small",
    "Medium": "PascalNotin/Tranception_Medium",
    "Large": "PascalNotin/Tranception_Large",
  }
  # Fail early with a clear message instead of hitting an unbound `model` below.
  if model_type not in checkpoints:
    raise ValueError("Unknown model_type {!r}; expected one of {}".format(model_type,", ".join(checkpoints)))
  # Normalise the range once so both helpers receive the same concrete bounds.
  mutation_range_start = 1 if mutation_range_start is None else int(mutation_range_start)
  mutation_range_end = len(sequence) if mutation_range_end is None else min(int(mutation_range_end),len(sequence))
  if not 1 <= mutation_range_start <= mutation_range_end:
    raise ValueError("Invalid mutation range {}-{} for a sequence of length {}".format(mutation_range_start,mutation_range_end,len(sequence)))
  model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=checkpoints[model_type])
  model.config.tokenizer = tokenizer
  all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
  scores = model.score_mutants(DMS_data=all_single_mutants, 
                                    target_seq=sequence, 
                                    scoring_mirror=scoring_mirror, 
                                    batch_size_inference=batch_size_inference,  
                                    num_workers=num_workers, 
                                    indel_mode=False
                                    )
  # Re-attach the mutant labels, then split each label into its position and target residue.
  scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
  scores["position"]=scores["mutant"].map(lambda x: int(x[1:-1]))
  scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
  score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
  return score_heatmap,suggest_mutations(scores)

#######################################################################################################################################
###############################################  GRADIO INTERFACE  ####################################################################
#######################################################################################################################################

# Static text for the Gradio page: heading, introductory description and footer link.
title = "Interactive in silico directed evolution with Tranception"
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitution Vs the starting sequence, and outputs a fitness heatmap and recommandations to guide the selection of the mutation to apply. Note: The current version does not currently leverage homologs retrieval at inference time to boost fitness prediction performance."
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
# Example entries, each a single string of the form "<NAME>: <sequence>".
# NOTE(review): not currently passed to gr.Interface (the `examples=` argument is commented out below).
examples=[
['A4_HUMAN: MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
['ADRB2_HUMAN: MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
['AMIE_PSEAE: MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
['P53_HUMAN: MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD']
]

# Input components, in the order they are passed to the scoring function.
model_size_selection = gr.Radio(
    label="Tranception model size (larger models are more accurate but are slower at inference)",
    choices=["Small","Medium","Large"],
    value="Small",
)
protein_sequence_input = gr.Textbox(
    lines=1,
    label="Input protein sequence (see below for examples; default = RL40A_YEAST)",
    value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK",
)
mutation_range_start = gr.Number(
    label="Start of mutation range (min value = 1)",
    value=1,
    precision=0,
)
mutation_range_end = gr.Number(
    label="End of mutation range (leave empty for full lenth)",
    value=10,
    precision=0,
)
scoring_mirror = gr.Checkbox(
    label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)",
)

# Output components.
# TODO: find a way to make the plot output scrollable.
output_plot = gr.Plot(
    label="Fitness scores for all single amino acid substitutions in mutation range",
)
output_recommendations = gr.Textbox(
    label="Mutation recommendations",
)

# Build the app and start serving it. The inputs are listed in the positional order of
# score_and_create_matrix_all_singles' first five parameters.
gr.Interface(
    fn=score_and_create_matrix_all_singles,
    inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
    # Use the labelled components defined above rather than the anonymous "plot"/"text"
    # shortcuts, so the outputs carry their descriptive labels.
    outputs=[output_plot,output_recommendations],
    title=title,
    description=description,
    article=article,
    # NOTE(review): examples stay disabled; each entry in `examples` is a single
    # "<NAME>: <sequence>" string, which does not match the five inputs above.
    #examples=examples,
    allow_flagging="never"
).launch(debug=True)