jennzhuge commited on
Commit
89a88ac
·
1 Parent(s): a0e49d5
Files changed (1) hide show
  1. app.py +62 -16
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import json
2
  import pandas as pd
 
3
  import gradio as gr
4
  # from transformers import PreTrainedTokenizerFast, BertForMaskedLM
5
  from datasets import load_dataset
6
  import infer
 
 
7
 
8
  with open("default_inputs.json", "r") as default_inputs_file:
9
  DEFAULT_INPUTS = json.load(default_inputs_file)
@@ -42,6 +45,8 @@ def preprocess():
42
  def predict_genus():
43
  data = preprocess()
44
  out = infer.infer_dna(data)
 
 
45
 
46
  results = []
47
 
@@ -54,35 +59,73 @@ def predict_genus():
54
 
55
  return results
56
 
57
- def tsne():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- return plots
 
 
 
 
 
 
 
 
60
 
61
 
62
  with gr.Blocks() as demo:
63
  # Header section
64
  gr.Markdown("# DNA Identifier Tool")
65
- gr.Markdown("Welcome to Lofi Amazon Beats' DNA Identifier Tool")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  with gr.Tab("Genus Prediction"):
68
- gr.Markdown("Enter a DNA sequence and the coordinates at which its sample was taken to get a genus prediction. Click 'I'm feeling lucky' to see a prediction for a random sequence.")
69
 
70
  # Collect inputs for app (DNA and location)
71
- with gr.Row():
72
- with gr.Column():
73
- inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
74
 
75
- with gr.Column():
76
- with gr.Row():
77
- inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
78
- with gr.Row():
79
- inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
80
 
81
- with gr.Row():
82
- btn_run = gr.Button("Predict")
83
 
84
- btn_defaults = gr.Button("I'm feeling lucky")
85
- btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
86
 
87
  with gr.Row():
88
  gr.Markdown('Make plot or table for Top 5 species')
@@ -97,6 +140,9 @@ with gr.Blocks() as demo:
97
  with gr.Row() as row:
98
  with gr.Column():
99
  gr.Markdown("Plot of your DNA sequence among other known species clusters.")
 
 
 
100
 
101
  with gr.Column():
102
  gr.Markdown("Plot of the five most common species at your sample coordinate.")
 
1
  import json
2
  import pandas as pd
3
+ import numpy as np
4
  import gradio as gr
5
  # from transformers import PreTrainedTokenizerFast, BertForMaskedLM
6
  from datasets import load_dataset
7
  import infer
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.manifold import TSNE
10
 
11
  with open("default_inputs.json", "r") as default_inputs_file:
12
  DEFAULT_INPUTS = json.load(default_inputs_file)
 
45
  def predict_genus():
46
  data = preprocess()
47
  out = infer.infer_dna(data)
48
+
49
+
50
 
51
  results = []
52
 
 
59
 
60
  return results
61
 
62
def tsne_DNA(data, genuses):
    """Build a 2-D t-SNE scatter plot of embeddings for the most common genera.

    Args:
        data: pandas DataFrame with a "genus" column and an "embeddings"
            column. Each embedding is parsed by dropping its first and last
            character and splitting the rest on whitespace into floats
            (presumably a string such as "[0.1 0.2 ...]" -- TODO confirm
            against the dataset).
        genuses: Currently unused; kept so existing callers keep working.

    Returns:
        The matplotlib Figure containing the scatter plot, one point per
        sample, coloured by genus.
    """
    top_k = 5

    # Keep only the rows that belong to the top_k genera with the most samples.
    top_genuses = data["genus"].value_counts().head(top_k).index
    df = data[data["genus"].isin(top_genuses)]

    # Parse the embeddings into float arrays. The parsed values are kept in a
    # local Series rather than written back into `data`, so the caller's frame
    # is left untouched and the function can be called again on the same input.
    embeddings = df["embeddings"].apply(
        lambda x: np.array(list(map(float, x[1:-1].split())))
    )
    X = np.stack(embeddings.tolist())
    y = df["genus"].tolist()

    # Project the embeddings down to two dimensions.
    tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, n_iter=1000, random_state=0)
    X_tsne = tsne.fit_transform(X)

    # Map each genus name to an integer code (codes follow the sorted order of
    # the distinct names) so it can be used as a colour value.
    _, y_encoded = np.unique(y, return_inverse=True)

    # Draw on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", so the points always land on the figure we return.
    plot, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7)

    return plot
87
 
88
 
89
  with gr.Blocks() as demo:
90
  # Header section
91
  gr.Markdown("# DNA Identifier Tool")
92
+ gr.Markdown("Welcome to Lofi Amazon Beats' DNA Identifier Tool. Please enter a DNA sequence and the coordinates at which its sample was taken to get started. Click 'I'm feeling lucky' to see use a random sequence.")
93
+ with gr.Row():
94
+ with gr.Column():
95
+ inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
96
+
97
+ with gr.Column():
98
+ with gr.Row():
99
+ inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
100
+ with gr.Row():
101
+ inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
102
+
103
+ with gr.Row():
104
+ btn_run = gr.Button("Predict")
105
+
106
+ btn_defaults = gr.Button("I'm feeling lucky")
107
+ btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
108
+
109
 
110
  with gr.Tab("Genus Prediction"):
111
+ # gr.Markdown("Enter a DNA sequence and the coordinates at which its sample was taken to get a genus prediction. Click 'I'm feeling lucky' to see a prediction for a random sequence.")
112
 
113
  # Collect inputs for app (DNA and location)
114
+ # with gr.Row():
115
+ # with gr.Column():
116
+ # inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
117
 
118
+ # with gr.Column():
119
+ # with gr.Row():
120
+ # inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
121
+ # with gr.Row():
122
+ # inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
123
 
124
+ # with gr.Row():
125
+ # btn_run = gr.Button("Predict")
126
 
127
+ # btn_defaults = gr.Button("I'm feeling lucky")
128
+ # btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])
129
 
130
  with gr.Row():
131
  gr.Markdown('Make plot or table for Top 5 species')
 
140
  with gr.Row() as row:
141
  with gr.Column():
142
  gr.Markdown("Plot of your DNA sequence among other known species clusters.")
143
+ plot = gr.Plot("")
144
+
145
+ btn_run.click(fn=tsne_DNA, inputs=[inp_dna, genus_out])
146
 
147
  with gr.Column():
148
  gr.Markdown("Plot of the five most common species at your sample coordinate.")