Spaces:
Build error
Build error
import functools

import torch
import transformers
from transformers import PreTrainedTokenizerFast
import tranception
import datasets
from tranception import config, model_pytorch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Shared fast tokenizer, loaded from the tokenizer file of the local tranception
# checkout (the "./" path is resolved against the current working directory).
_TOKENIZER_SPECIAL_TOKENS = {
    "unk_token": "[UNK]",
    "sep_token": "[SEP]",
    "pad_token": "[PAD]",
    "cls_token": "[CLS]",
    "mask_token": "[MASK]",
}
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tranception/utils/tokenizers/Basic_tokenizer",
    **_TOKENIZER_SPECIAL_TOKENS,
)
####################################################################################################################################### | |
############################################### HELPER FUNCTIONS #################################################################### | |
####################################################################################################################################### | |
# Candidate replacement residues (the 20 standard one-letter amino-acid codes).
AA_vocab = "ACDEFGHIKLMNPQRSTVWY"


def create_all_single_mutants(sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Enumerate every single amino-acid substitution of ``sequence`` within a window.

    Args:
        sequence: starting protein sequence, one letter per residue.
        AA_vocab: candidate replacement residues.
        mutation_range_start: first 1-based position to mutate (default 1).
        mutation_range_end: last 1-based position to mutate, inclusive
            (default: the last residue of ``sequence``).

    Returns:
        A DataFrame with columns ``mutant`` -- wild-type residue, absolute
        1-based position and new residue, e.g. ``"C2A"`` -- and
        ``mutated_sequence``. Substitutions of a residue by itself are skipped.

    Raises:
        ValueError: if ``mutation_range_start`` is smaller than 1.
    """
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    # UI widgets may hand over floats; slicing needs ints.
    mutation_range_start = int(mutation_range_start)
    mutation_range_end = int(mutation_range_end)
    if mutation_range_start < 1:
        raise ValueError("mutation_range_start must be >= 1, got {}".format(mutation_range_start))

    sequence_list = list(sequence)
    all_single_mutants = {}
    # Start enumerate at mutation_range_start so `position` is the absolute 1-based
    # position in `sequence`, not the offset inside the sliced window.
    for position, current_AA in enumerate(sequence[mutation_range_start-1:mutation_range_end], start=mutation_range_start):
        for mutated_AA in AA_vocab:
            if current_AA != mutated_AA:
                mutated_sequence = sequence_list.copy()
                mutated_sequence[position-1] = mutated_AA
                all_single_mutants[current_AA+str(position)+mutated_AA] = "".join(mutated_sequence)
    return pd.DataFrame(list(all_single_mutants.items()), columns=['mutant','mutated_sequence'])
def create_scoring_matrix_visual(scores,sequence,AA_vocab=AA_vocab,mutation_range_start=None,mutation_range_end=None):
    """Draw an annotated heatmap of single-substitution scores over a window.

    Args:
        scores: DataFrame with at least the columns ``mutant`` (e.g. ``"A12C"``),
            ``position`` (1-based int), ``target_AA`` and ``avg_score``.
        sequence: the starting protein sequence the mutants were derived from.
        AA_vocab: amino acids shown as heatmap rows, in this order.
        mutation_range_start: first 1-based position to plot (default 1).
        mutation_range_end: last 1-based position to plot, inclusive (default:
            the end of ``sequence``; values past the end are clamped).

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the heatmap.
        The figure is also saved to ``fitness_scoring_substitution_matrix.png``.
    """
    # Resolve the plotted window before anything (figure size, grid) depends on it.
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    mutation_range_start = int(mutation_range_start)
    mutation_range_end = min(int(mutation_range_end), len(sequence))
    positions = list(range(mutation_range_start, mutation_range_end + 1))
    target_AAs = list(AA_vocab)

    # Rows = target amino acid, columns = position. The reindex pins the grid to exactly
    # len(AA_vocab) x len(positions) so it always matches `labels` below; a bare pivot
    # drops any amino acid / position that has no scored mutant.
    piv = (
        scores.pivot(index='position', columns='target_AA', values='avg_score')
        .transpose()
        .reindex(index=target_AAs, columns=positions)
        .round(4)
    )

    # Build the lookup once instead of scanning `scores` with a boolean mask per cell.
    score_by_mutant = dict(zip(scores.mutant, scores.avg_score))
    labels = []
    for target_AA in target_AAs:
        for position in positions:
            mutant = sequence[position-1] + str(position) + target_AA
            # Cells with no scored mutant (e.g. the wild-type residue) are labelled 0.0.
            labels.append("{} \n {:.4f}".format(mutant, float(score_by_mutant.get(mutant, 0.0))))
    labels = np.asarray(labels).reshape(len(target_AAs), len(positions))

    # Size the figure by the plotted window, not the whole sequence: a long protein with a
    # short window would otherwise produce a canvas wider than matplotlib can render.
    fig, ax = plt.subplots(figsize=(max(len(positions), 1)*1.2, 20))
    heat = sns.heatmap(piv, ax=ax, annot=labels, fmt="", cmap='RdYlGn', linewidths=0.30,
                       vmin=np.percentile(scores.avg_score,2), vmax=np.percentile(scores.avg_score,98),
                       cbar_kws={'label': 'Log likelihood ratio (mutant / starting sequence)'})
    # Last axes of the figure = colorbar added by sns.heatmap; enlarge its label.
    heat.figure.axes[-1].yaxis.label.set_size(20)
    heat.set_title("Higher predicted scores (green) imply higher protein fitness",fontsize=30, pad=40)
    heat.set_xlabel("Sequence position", fontsize = 20)
    heat.set_ylabel("Amino Acid mutation", fontsize = 20)
    plt.savefig('fitness_scoring_substitution_matrix.png')
    return plt
def suggest_mutations(scores):
    """Summarise the most promising substitutions as user-facing text.

    Args:
        scores: DataFrame with at least the columns ``mutant``, ``position`` and
            ``avg_score``; other (e.g. string) columns are ignored.

    Returns:
        A message listing up to 5 mutants with the highest ``avg_score`` and up to
        5 positions with the highest mean ``avg_score`` over their positive-scoring
        substitutions.
    """
    intro_message = "The following mutations may be sensible options to improve fitness: \n\n"
    # Best mutants: highest individual predicted scores.
    top_mutants = list(scores.sort_values(by=['avg_score'], ascending=False).head(5).mutant)
    mutant_recos = "The {} single mutants with highest predicted fitness are:\n {} \n\n".format(len(top_mutants), ", ".join(top_mutants))
    # Best positions: mean over the beneficial (score > 0) substitutions at each position.
    # Select avg_score before averaging -- pandas >= 2.0 raises TypeError when mean()
    # reaches the string columns (mutant, target_AA, mutated_sequence).
    position_avg = scores.loc[scores.avg_score > 0].groupby('position')['avg_score'].mean()
    top_positions = list(position_avg.sort_values(ascending=False).head(5).index.astype(str))
    position_recos = "The {} positions with the highest average fitness increase are:\n {}".format(len(top_positions), ", ".join(top_positions))
    return intro_message + mutant_recos + position_recos
def get_mutated_protein(sequence,mutant):
    """Apply one substitution to ``sequence`` and return the new sequence.

    ``mutant`` has the form <old residue><1-based position><new residue>, e.g.
    ``"C2A"``. Only the position and the new residue are used; the old residue
    letter is not checked against ``sequence``.
    """
    new_residue = mutant[-1]
    residues = [*sequence]
    residues[int(mutant[1:-1]) - 1] = new_residue
    return str().join(residues)
# Hub checkpoint for each model size offered in the UI.
_TRANCEPTION_CHECKPOINTS = {
    "Small": "PascalNotin/Tranception_Small",
    "Medium": "PascalNotin/Tranception_Medium",
    "Large": "PascalNotin/Tranception_Large",
}


@functools.lru_cache(maxsize=1)
def _load_tranception_model(model_type):
    """Load the checkpoint for ``model_type`` once and reuse it across requests.

    ``maxsize=1`` keeps only the most recently used size resident, so switching
    between sizes does not accumulate several models in memory.
    """
    try:
        checkpoint = _TRANCEPTION_CHECKPOINTS[model_type]
    except KeyError:
        raise ValueError("Unknown model_type {!r}; expected one of {}".format(
            model_type, ", ".join(_TRANCEPTION_CHECKPOINTS))) from None
    model = tranception.model_pytorch.TranceptionLMHeadModel.from_pretrained(pretrained_model_name_or_path=checkpoint)
    model.config.tokenizer = tokenizer
    return model


def score_and_create_matrix_all_singles(sequence,mutation_range_start=None,mutation_range_end=None,model_type="Small",scoring_mirror=False,batch_size_inference=20,num_workers=0,AA_vocab=AA_vocab):
    """Score every single substitution in a window and build the heatmap and advice.

    Args:
        sequence: starting protein sequence.
        mutation_range_start: first 1-based position to mutate (default 1).
        mutation_range_end: last 1-based position to mutate, inclusive (default and
            upper bound: the end of ``sequence``).
        model_type: "Small", "Medium" or "Large".
        scoring_mirror, batch_size_inference, num_workers: forwarded unchanged to
            ``model.score_mutants``.
        AA_vocab: candidate replacement residues.

    Returns:
        ``(plot, text)``: the pyplot module holding the heatmap, and the
        recommendation message.

    Raises:
        ValueError: for an unknown ``model_type`` or a start position below 1.
    """
    # Resolve the window once so the mutant table and the heatmap cover the same
    # positions; clamp the end so it can never index past the sequence.
    if mutation_range_start is None:
        mutation_range_start = 1
    if mutation_range_end is None:
        mutation_range_end = len(sequence)
    mutation_range_start = int(mutation_range_start)
    mutation_range_end = min(int(mutation_range_end), len(sequence))

    model = _load_tranception_model(model_type)
    all_single_mutants = create_all_single_mutants(sequence,AA_vocab,mutation_range_start,mutation_range_end)
    scores = model.score_mutants(DMS_data=all_single_mutants,
                                 target_seq=sequence,
                                 scoring_mirror=scoring_mirror,
                                 batch_size_inference=batch_size_inference,
                                 num_workers=num_workers,
                                 indel_mode=False
                                 )
    # Re-attach the mutant names, then split them into position / target residue.
    scores = pd.merge(scores,all_single_mutants,on="mutated_sequence",how="left")
    scores["position"] = scores["mutant"].map(lambda x: int(x[1:-1]))
    scores["target_AA"] = scores["mutant"].map(lambda x: x[-1])
    score_heatmap = create_scoring_matrix_visual(scores,sequence,AA_vocab,mutation_range_start,mutation_range_end)
    return score_heatmap,suggest_mutations(scores)
####################################################################################################################################### | |
############################################### GRADIO INTERFACE #################################################################### | |
####################################################################################################################################### | |
# Text shown in the Gradio page header (title, description) and footer (article).
title = "Interactive in silico directed evolution with Tranception"
description = "Perform in silico directed evolution with Tranception to iteratively improve the fitness of a starting protein sequence one mutation at a time. At each step, the Tranception model computes the log likelihood ratios of all possible single amino acid substitutions vs. the starting sequence, and outputs a fitness heatmap and recommendations to guide the selection of the mutation to apply. Note: The current version does not leverage homolog retrieval at inference time to boost fitness prediction performance."
article = "<p style='text-align: center'><a href='https://proceedings.mlr.press/v162/notin22a.html' target='_blank'>Tranception: Protein Fitness Prediction with Autoregressive Transformers and Inference-time Retrieval</a></p>"
# Sample starting sequences, each stored as a one-element row "<NAME>: <SEQUENCE>".
# NOTE(review): not passed to the Interface (its `examples=` argument is commented
# out), and the "<NAME>: " prefix would need stripping before use as a sequence input.
_EXAMPLE_PROTEINS = (
    ("A4_HUMAN", "MLPGLALLLLAAWTARALEVPTDGNAGLLAEPQIAMFCGRLNMHMNVQNGKWDSDPSGTKTCIDTKEGILQYCQEVYPELQITNVVEANQPVTIQNWCKRGRKQCKTHPHFVIPYRCLVGEFVSDALLVPDKCKFLHQERMDVCETHLHWHTVAKETCSEKSTNLHDYGMLLPCGIDKFRGVEFVCCPLAEESDNVDSADAEEDDSDVWWGGADTDYADGSEDKVVEVAEEEEVAEVEEEEADDDEDDEDGDEVEEEAEEPYEEATERTTSIATTTTTTTESVEEVVREVCSEQAETGPCRAMISRWYFDVTEGKCAPFFYGGCGGNRNNFDTEEYCMAVCGSAMSQSLLKTTQEPLARDPVKLPTTAASTPDAVDKYLETPGDENEHAHFQKAKERLEAKHRERMSQVMREWEEAERQAKNLPKADKKAVIQHFQEKVESLEQEAANERQQLVETHMARVEAMLNDRRRLALENYITALQAVPPRPRHVFNMLKKYVRAEQKDRQHTLKHFEHVRMVDPKKAAQIRSQVMTHLRVIYERMNQSLSLLYNVPAVAEEIQDEVDELLQKEQNYSDDVLANMISEPRISYGNDALMPSLTETKTTVELLPVNGEFSLDDLQPWHSFGADSVPANTENEVEPVDARPAADRGLTTRPGSGLTNIKTEEISEVKMDAEFRHDSGYEVHHQKLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEVDAAVTPEERHLSKMQQNGYENPTYKFFEQMQN"),
    ("ADRB2_HUMAN", "MGQPGNGSAFLLAPNGSHAPDHDVTQERDEVWVVGMGIVMSLIVLAIVFGNVLVITAIAKFERLQTVTNYFITSLACADLVMGLAVVPFGAAHILMKMWTFGNFWCEFWTSIDVLCVTASIETLCVIAVDRYFAITSPFKYQSLLTKNKARVIILMVWIVSGLTSFLPIQMHWYRATHQEAINCYANETCCDFFTNQAYAIASSIVSFYVPLVIMVFVYSRVFQEAKRQLQKIDKSEGRFHVQNLSQVEQDGRTGHGLRRSSKFCLKEHKALKTLGIIMGTFTLCWLPFFIVNIVHVIQDNLIRKEVYILLNWIGYVNSGFNPLIYCRSPDFRIAFQELLCLRRSSLKAYGNGYSSNGNTGEQSGYHVEQEKENKLLCEDLPGTEDFVGHQGTVPSDNIDSQGRNCSTNDSLL"),
    ("AMIE_PSEAE", "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA"),
    ("P53_HUMAN", "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPRVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"),
)
examples = [[name + ": " + seq] for name, seq in _EXAMPLE_PROTEINS]
# ---- Input components (wired positionally into the Interface's `inputs` list) ----
model_size_selection = gr.Radio(label="Tranception model size (larger models are more accurate but are slower at inference)", choices=["Small","Medium","Large"], value="Small")
protein_sequence_input = gr.Textbox(lines=1, label="Input protein sequence (see below for examples; default = RL40A_YEAST)",value="MQIFVKTLTGKTITLEVESSDTIDNVKSKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGGIIEPSLKALASKYNCDKSVCRKCYARLPPRATNCRKRKCGHTNQLRPKKKLK")
# precision=0 makes the Number widgets return whole numbers (1-based positions).
mutation_range_start = gr.Number(label="Start of mutation range (min value = 1)",value=1,precision=0)
mutation_range_end = gr.Number(label="End of mutation range (leave empty for full length)",value=10,precision=0)
scoring_mirror = gr.Checkbox(label="Score protein from both directions (leads to more robust fitness predictions, but doubles inference time)")
# ---- Output components ----
# TODO: make the plot output scrollable for long mutation ranges.
output_plot = gr.Plot(label="Fitness scores for all single amino acid substitutions in mutation range")
output_recommendations = gr.Textbox(label="Mutation recommendations")
# Build the app. Inputs are passed positionally to score_and_create_matrix_all_singles:
# (sequence, mutation_range_start, mutation_range_end, model_type, scoring_mirror).
demo = gr.Interface(
    fn=score_and_create_matrix_all_singles,
    inputs=[protein_sequence_input,mutation_range_start,mutation_range_end,model_size_selection,scoring_mirror],
    # Use the labelled output components rather than the bare "plot"/"text" shortcuts,
    # which created unlabelled outputs and left output_plot/output_recommendations unused.
    outputs=[output_plot, output_recommendations],
    title=title,
    description=description,
    article=article,
    # examples=examples,  # disabled: each example row holds one "NAME: SEQUENCE" string, not one value per input
    allow_flagging="never",
)

# Launch only when run as a script, so importing the module (e.g. by tooling that
# looks up `demo`) does not start a server as a side effect.
if __name__ == "__main__":
    demo.launch(debug=True)