jennzhuge commited on
Commit
e86736e
·
1 Parent(s): a83006f
Files changed (2) hide show
  1. app.py +39 -13
  2. xgboost_infer.py → infer.py +7 -0
app.py CHANGED
@@ -1,25 +1,51 @@
1
  import json
 
2
  import gradio as gr
 
 
 
3
 
 
4
 
5
  with open("default_inputs.json", "r") as default_inputs_file:
6
  DEFAULT_INPUTS = json.load(default_inputs_file)
7
 
8
-
9
  def set_default_inputs():
10
  return (DEFAULT_INPUTS["dna_sequence"],
11
  DEFAULT_INPUTS["latitude"],
12
  DEFAULT_INPUTS["longitude"])
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def predict_genus():
15
- dna_df = pd.read_csv(dna_file.name)
16
- dnaenv_df = pd.read_csv(dnaenv_file.name)
17
 
18
  results = []
19
-
20
- # envdna_genuses = predict_genus_dna_env(dnaenv_df)
21
- # dna_genuses = predict_genus_dna(dna_df)
22
- # images = [get_genus_image(genus) for genus in top_5_genuses]
23
 
24
  genuses = xgboost_infer.infer()
25
 
@@ -38,11 +64,11 @@ with gr.Blocks() as demo:
38
  gr.Markdown("Welcome to Lofi Amazon Beats' DNA Identifier Tool")
39
 
40
  with gr.Tab("Genus Prediction"):
41
- gr.Markdown("Input a DNA sequence and the coordinates at which its sample was taken to predict the genus of the DNA. Click 'I'm feeling lucky' to see our predictio for a random sequence.")
42
 
43
  # Collect inputs for app (DNA and location)
44
  with gr.Row():
45
- inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (will be automatically truncated to 660 characters)")
46
  with gr.Row():
47
  inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
48
  inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
@@ -57,11 +83,11 @@ with gr.Blocks() as demo:
57
  gr.Markdown('Make plot or table for Top 5 species')
58
 
59
  with gr.Column():
60
- genus_out = gr.Dataframe(headers=["DNA", "Coord", "DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
61
- btn_run.click(predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)
62
 
63
  with gr.Tab('DNA Embedding Space Similarity Visualizer'):
64
- gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples or clues.")
65
-
66
 
67
  demo.launch()
 
1
  import json
2
+ import pandas as pd
3
  import gradio as gr
4
+ from transformers import PreTrainedTokenizerFast, BertForMaskedLM
5
+ from datasets import load_dataset
6
+ import xgboost_infer
7
 
8
+ embeddings_train = load_dataset("LofiAmazon/BOLD-Embeddings-Ecolayers-Amazon", split='train').to_pandas()
9
 
10
  with open("default_inputs.json", "r") as default_inputs_file:
11
  DEFAULT_INPUTS = json.load(default_inputs_file)
12
 
 
13
  def set_default_inputs():
14
  return (DEFAULT_INPUTS["dna_sequence"],
15
  DEFAULT_INPUTS["latitude"],
16
  DEFAULT_INPUTS["longitude"])
17
 
18
+ def preprocess():
19
+ ''' prepares app input for the genus prediction model
20
+ '''
21
+ # preprocess DNA seq
22
+ # Replace all symbols in nucraw which are not A, C, G, T with N
23
+ inp_dna = inp_dna.str.replace("[^ACGT]", "N", regex=True)
24
+ # Truncate trailing Ns from nucraw
25
+ inp_dna = inp_dna.str.replace("N+$", "", regex=True)
26
+ # Insert spaces between all k-mers
27
+ inp_dna = inp_dna.apply(lambda x: " ".join([x[i:i+4] for i in range(0, len(x), 4)]))
28
+
29
+ # load model to calculate new embeddings
30
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model, force_download=True)
31
+ tokenizer.add_special_tokens({"pad_token": "<UNK>"})
32
+ bert_model = BertForMaskedLM.from_pretrained(model, force_download=True)
33
+ embed = bert_model.predic(inp_dna)
34
+
35
+ # format lat and lon into coords
36
+ coords = (inp_lat, inp_lng)
37
+ # Grab rasters from the tifs
38
+ ecoLayers = load_dataset("LofiAmazon/Global-Ecolayers")
39
+ temp = pd.DataFrame([coords, embed], columns = ['coord', 'embeddings'])
40
+ data = pd.merge(temp, ecoLayers, on='coord', how='left')
41
+
42
+ return data
43
+
44
  def predict_genus():
45
+ data = preprocess()
46
+ out = xgboost_infer.infer_dna(data)
47
 
48
  results = []
 
 
 
 
49
 
50
  genuses = xgboost_infer.infer()
51
 
 
64
  gr.Markdown("Welcome to Lofi Amazon Beats' DNA Identifier Tool")
65
 
66
  with gr.Tab("Genus Prediction"):
67
+ gr.Markdown("Enter a DNA sequence and the coordinates at which its sample was taken to get a genus prediction. Click 'I'm feeling lucky' to see a prediction for a random sequence.")
68
 
69
  # Collect inputs for app (DNA and location)
70
  with gr.Row():
71
+ inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
72
  with gr.Row():
73
  inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
74
  inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
 
83
  gr.Markdown('Make plot or table for Top 5 species')
84
 
85
  with gr.Column():
86
+ genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
87
+ btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)
88
 
89
  with gr.Tab('DNA Embedding Space Similarity Visualizer'):
90
+ gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
91
+
92
 
93
  demo.launch()
xgboost_infer.py → infer.py RENAMED
@@ -6,10 +6,16 @@ from sklearn.preprocessing import LabelEncoder
6
  from datasets import load_dataset
7
  import pickle
8
 
 
9
  def infer_dna(args):
10
  ecoDf = pd.read_csv(args['input_path'], sep='\t')
11
  dnaEmbeds = load_dataset("LofiAmazon/BOLD-Embeddings", split='train')
12
 
 
 
 
 
 
13
  modelDNA = load_checkpoint()
14
  modelDNAEnv = load_checkpoint()
15
 
@@ -49,6 +55,7 @@ def infer_dna(args):
49
  y_dna_probs = modelDNAEnv.predict_proba(X_dna)
50
  DNAEnvGenuses = {}
51
  for i in range(len()):
 
52
  topProbs = np.argsort(y_dna_probs[i], axis=1)[:,-3:]
53
  topClasses = modelDNA.classes_[topProbs]
54
 
 
6
  from datasets import load_dataset
7
  import pickle
8
 
9
+
10
  def infer_dna(args):
11
  ecoDf = pd.read_csv(args['input_path'], sep='\t')
12
  dnaEmbeds = load_dataset("LofiAmazon/BOLD-Embeddings", split='train')
13
 
14
+ # load model to calculate new embeddings
15
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model, force_download=True)
16
+ tokenizer.add_special_tokens({"pad_token": "<UNK>"})
17
+ bert_model = BertForMaskedLM.from_pretrained(model, force_download=True)
18
+
19
  modelDNA = load_checkpoint()
20
  modelDNAEnv = load_checkpoint()
21
 
 
55
  y_dna_probs = modelDNAEnv.predict_proba(X_dna)
56
  DNAEnvGenuses = {}
57
  for i in range(len()):
58
+
59
  topProbs = np.argsort(y_dna_probs[i], axis=1)[:,-3:]
60
  topClasses = modelDNA.classes_[topProbs]
61