import torch
import transformers
from transformers import PreTrainedTokenizerFast
import tranception
import datasets
from tranception import config, model_pytorch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
                                    unk_token="[UNK]",
                                    sep_token="[SEP]",
                                    pad_token="[PAD]",
                                    cls_token="[CLS]",
                                    mask_token="[MASK]"
                                    )

#######################################################################################################################################
############################################### HELPER FUNCTIONS ####################################################################
#######################################################################################################################################
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"

# Hugging Face Hub checkpoint loaded for each selectable model size.
_MODEL_CHECKPOINTS = {
    "Small": "PascalNotin/Tranception_Small",
    "Medium": "PascalNotin/Tranception_Medium",
    "Large": "PascalNotin/Tranception_Large",
}


def _resolve_mutation_range(sequence, mutation_range_start=None, mutation_range_end=None):
    """Return the 1-based, inclusive (start, end) mutation range for `sequence`.

    A missing start defaults to 1 and a missing end to len(sequence). The start is
    clamped to >= 1 and the end to <= len(sequence), so out-of-bounds input never
    wraps around (negative slicing) or indexes past the end of the sequence.
    The returned range may be empty (start > end).
    """
    start = 1 if mutation_range_start is None else max(int(mutation_range_start), 1)
    end = len(sequence) if mutation_range_end is None else min(int(mutation_range_end), len(sequence))
    return start, end


def create_all_single_mutants(sequence, AA_vocab=AA_vocab, mutation_range_start=None, mutation_range_end=None):
    """Enumerate every single amino acid substitution within a mutation range.

    Args:
        sequence: starting protein sequence.
        AA_vocab: amino acids to substitute in.
        mutation_range_start: first mutated position (1-based, inclusive); defaults to 1.
        mutation_range_end: last mutated position (1-based, inclusive); defaults to len(sequence).

    Returns:
        DataFrame with columns ``mutant`` (wild-type AA + 1-based position + new AA,
        e.g. ``"M1A"``) and ``mutated_sequence``. Substitutions that would leave the
        residue unchanged are skipped.
    """
    start, end = _resolve_mutation_range(sequence, mutation_range_start, mutation_range_end)
    sequence_list = list(sequence)
    all_single_mutants = {}
    for position in range(start, end + 1):
        current_AA = sequence[position - 1]
        for mutated_AA in AA_vocab:
            if current_AA != mutated_AA:
                mutated_sequence = sequence_list.copy()
                # Use the absolute position, so that both the edited residue and the
                # mutant name are correct when the range does not start at 1.
                mutated_sequence[position - 1] = mutated_AA
                all_single_mutants[current_AA + str(position) + mutated_AA] = "".join(mutated_sequence)
    return pd.DataFrame(list(all_single_mutants.items()), columns=['mutant', 'mutated_sequence'])


def create_scoring_matrix_visual(scores, sequence, AA_vocab=AA_vocab, mutation_range_start=None, mutation_range_end=None):
    """Draw an annotated heatmap of substitution scores over the mutation range.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``position``,
            ``target_AA`` and ``avg_score``.
        sequence: starting protein sequence (used to name each cell's mutant).
        AA_vocab: amino acids shown as heatmap rows, in this order.
        mutation_range_start / mutation_range_end: 1-based inclusive range of
            positions shown as columns (defaults: 1 and len(sequence)).

    Returns:
        The ``matplotlib.pyplot`` module holding the current figure. The figure
        is also saved to ``fitness_scoring_substitution_matrix.png``.
    """
    start, end = _resolve_mutation_range(sequence, mutation_range_start, mutation_range_end)
    positions = list(range(start, end + 1))
    # Rows = target AA, columns = position. The pivot alone sorts rows alphabetically
    # and drops AAs/positions with no score. Reindexing pins the grid to exactly
    # AA_vocab x positions, so it always lines up with `labels` below.
    piv = (scores.pivot(index='position', columns='target_AA', values='avg_score')
           .transpose()
           .reindex(index=list(AA_vocab), columns=positions)
           .round(4))
    # Dictionary lookup instead of filtering the whole DataFrame once per cell.
    score_by_mutant = dict(zip(scores.mutant, scores.avg_score))

    def cell_label(target_AA, position):
        mutant = sequence[position - 1] + str(position) + target_AA
        # Cells without a score (e.g. the unchanged wild-type residue) are labelled 0.0.
        return "{} \n {:.4f}".format(mutant, float(score_by_mutant.get(mutant, 0.0)))

    labels = np.array([[cell_label(target_AA, position) for position in positions] for target_AA in AA_vocab],
                      dtype=object)

    # pyplot keeps every figure alive until it is closed, so release figures left
    # over from previous calls before drawing a new one.
    plt.close('all')
    # Width scales with the number of displayed positions rather than the full
    # sequence length, so long proteins do not produce an oversized canvas.
    fig, ax = plt.subplots(figsize=(len(positions) * 1.2, 20))
    # The colour scale is clipped to the 2nd-98th percentiles so that a few extreme
    # scores do not wash out the rest of the map.
    heat = sns.heatmap(piv, annot=labels, fmt="", cmap='RdYlGn', linewidths=0.30,
                       vmin=np.percentile(scores.avg_score, 2), vmax=np.percentile(scores.avg_score, 98),
                       cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'}, ax=ax)
    heat.figure.axes[-1].yaxis.label.set_size(20)
    heat.set_title("Higher predicted scores (green) imply higher protein fitness", fontsize=30, pad=40)
    heat.set_xlabel("Sequence position", fontsize=20)
    heat.set_ylabel("Amino Acid mutation", fontsize=20)
    plt.savefig('fitness_scoring_substitution_matrix.png')
    return plt


def suggest_mutations(scores):
    """Build a text summary of promising mutations from a scored mutant table.

    Lists the 5 mutants with the highest ``avg_score``. It then lists the 5
    positions with the highest mean score, using only mutants with a positive
    ``avg_score`` when computing that mean.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
    # Best mutants
    top_mutants = list(scores.sort_values(by=['avg_score'], ascending=False).head(5).mutant)
    mutant_recos = "The 5 single mutants with highest predicted fitness are:\n {} \n\n".format(", ".join(top_mutants))
    # Best positions. Only the score column is aggregated, because averaging the
    # string columns raises TypeError in recent pandas.
    positive_scores = scores[scores.avg_score > 0]
    positive_scores_position_avg = positive_scores.groupby('position')['avg_score'].mean()
    top_positions = list(positive_scores_position_avg.sort_values(ascending=False).head(5).index.astype(str))
    position_recos = "The 5 positions with the highest average fitness increase are:\n {}".format(", ".join(top_positions))
    return intro_message + mutant_recos + position_recos


def get_mutated_protein(sequence, mutant):
    """Apply one substitution written as wild-type AA + 1-based position + new AA (e.g. ``"M1A"``)."""
    mutated_sequence = list(sequence)
    mutated_sequence[int(mutant[1:-1]) - 1] = mutant[-1]
    return ''.join(mutated_sequence)


def score_and_create_matrix_all_singles(sequence, mutation_range_start=None, mutation_range_end=None, model_type="Small", scoring_mirror=False, batch_size_inference=20, num_workers=0, AA_vocab=AA_vocab):
    """Score all single substitutions in a range with Tranception and visualise them.

    Args:
        sequence: starting protein sequence. An optional leading ``"NAME:"`` label,
            whitespace and lowercase letters are tolerated.
        mutation_range_start / mutation_range_end: 1-based inclusive range of
            positions to mutate (defaults: whole sequence; clamped to its bounds).
        model_type: one of "Small", "Medium", "Large".
        scoring_mirror, batch_size_inference, num_workers: forwarded to
            ``model.score_mutants``.
        AA_vocab: amino acids to substitute in.

    Returns:
        (heatmap, recommendations): the pyplot module holding the heatmap figure,
        and the text produced by ``suggest_mutations``.

    Raises:
        ValueError: if `model_type` is unknown or the mutation range is empty.
    """
    if model_type not in _MODEL_CHECKPOINTS:
        raise ValueError("Unknown model_type {!r}; expected one of: {}".format(
            model_type, ", ".join(_MODEL_CHECKPOINTS)))
    # Drop an optional "NAME:" label and all whitespace, and use uppercase AA codes.
    sequence = "".join(sequence.split(":")[-1].split()).upper()
    start, end = _resolve_mutation_range(sequence, mutation_range_start, mutation_range_end)
    if start > end:
        raise ValueError("Empty mutation range: start={} end={} for a sequence of length {}".format(
            start, end, len(sequence)))

    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(
        pretrained_model_name_or_path=_MODEL_CHECKPOINTS[model_type], use_auth_token=True)
    model.config.tokenizer = tokenizer
    all_single_mutants = create_all_single_mutants(sequence, AA_vocab, start, end)
    scores = model.score_mutants(DMS_data=all_single_mutants,
                                 target_seq=sequence,
                                 scoring_mirror=scoring_mirror,
                                 batch_size_inference=batch_size_inference,
                                 num_workers=num_workers,
                                 indel_mode=False
                                 )
    # Re-attach mutant names to the model's per-sequence scores, then split each
    # name into its position and target amino acid.
    scores = pd.merge(scores, all_single_mutants, on="mutated_sequence", how="left")
    scores["position"] = scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
    score_heatmap = create_scoring_matrix_visual(scores, sequence, AA_vocab, start, end)
    return score_heatmap, suggest_mutations(scores)

#######################################################################################################################################
############################################### GRADIO INTERFACE ####################################################################
#######################################################################################################################################
title = "Interactive in silico directed evolution with Tranception"
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitutions vs. the starting sequence, and outputs a fitness heatmap and recommendations to guide the selection of the mutation to apply. Note: The current version does not yet leverage homolog retrieval at inference time to boost fitness prediction performance."
# Kept on a single line: a plain string literal cannot span raw newlines.
article = "Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval"

# Example starting sequences. Each row only pre-fills the sequence input. The
# protein names are kept as comments rather than inside the strings, so that
# the examples can be scored as-is.
examples = [
    # A4_HUMAN
    ['MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN'],
    # ADRB2_HUMAN
    ['MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL'],
    # AMIE_PSEAE
    ['MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA'],
    # P53_HUMAN
    ['MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'],
]

model_size_selection = gr.Radio(label="Tranception model size", choices=["Small", "Medium", "Large"], value="Small")
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)", value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)", value=1, precision=0)
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full length)", value=10, precision=0)
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")

# TODO: find a way to make the output plot scrollable.
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
output_recommendations = gr.Textbox(label="Mutation recommendations")

# Input order must match score_and_create_matrix_all_singles(sequence, start, end, model_type, scoring_mirror).
demo = gr.Interface(
    fn=score_and_create_matrix_all_singles,
    inputs=[protein_sequence_input, mutation_range_start, mutation_range_end, model_size_selection, scoring_mirror],
    outputs=[output_plot, output_recommendations],
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
    allow_flagging="never"
)
demo.launch(debug=True)