Spaces:
Runtime error
Runtime error
jennzhuge
committed on
Commit
·
8465e44
1
Parent(s):
040747a
added tsne graph, changed default coords
Browse files
- .gitignore +2 -1
- __pycache__/app.cpython-39.pyc +0 -0
- __pycache__/config.cpython-39.pyc +0 -0
- app.py +122 -31
- default_inputs.json +2 -2
.gitignore
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
.venv
|
2 |
flagged
|
3 |
*.tif
|
4 |
-
*.tiff
|
|
|
|
1 |
.venv
|
2 |
flagged
|
3 |
*.tif
|
4 |
+
*.tiff
|
5 |
+
.env
|
__pycache__/app.cpython-39.pyc
ADDED
Binary file (10.7 kB). View file
|
|
__pycache__/config.cpython-39.pyc
ADDED
Binary file (1.29 kB). View file
|
|
app.py
CHANGED
@@ -60,7 +60,8 @@ embeddings_model.eval()
|
|
60 |
classification_model.eval()
|
61 |
|
62 |
# Load datasets
|
63 |
-
amazon_ds = load_dataset(DATASETS["amazon"])
|
|
|
64 |
|
65 |
def set_default_inputs():
|
66 |
return (DEFAULT_INPUTS["dna_sequence"],
|
@@ -148,6 +149,22 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
|
|
148 |
index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
|
149 |
)
|
150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
fig, ax = plt.subplots()
|
152 |
ax.bar(top_k.index.astype(str), top_k.values)
|
153 |
ax.set_ylim(0, 1)
|
@@ -162,12 +179,12 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
|
|
162 |
return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
163 |
|
164 |
|
165 |
-
def cluster_dna(
|
166 |
-
df = amazon_ds
|
167 |
-
df = df[df["genus"].notna()]
|
168 |
-
|
169 |
genus_counts = df["genus"].value_counts()
|
170 |
-
top_genuses = genus_counts.head(
|
171 |
df = df[df["genus"].isin(top_genuses)]
|
172 |
tsne = TSNE(
|
173 |
n_components=2, perplexity=30, learning_rate=200,
|
@@ -180,16 +197,59 @@ def cluster_dna(top_k: float):
|
|
180 |
|
181 |
label_encoder = LabelEncoder()
|
182 |
y_encoded = label_encoder.fit_transform(y)
|
|
|
183 |
|
184 |
fig, ax = plt.subplots()
|
185 |
-
ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="
|
186 |
-
|
|
|
|
|
187 |
# Reduce unnecessary whitespace
|
188 |
ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
|
189 |
fig.canvas.draw()
|
190 |
|
191 |
return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
with gr.Blocks() as demo:
|
194 |
# Header section
|
195 |
gr.Markdown(("""
|
@@ -209,9 +269,9 @@ with gr.Blocks() as demo:
|
|
209 |
|
210 |
with gr.Column():
|
211 |
with gr.Row():
|
212 |
-
inp_lat = gr.Textbox(label="Latitude", placeholder="e.g.
|
213 |
with gr.Row():
|
214 |
-
inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -
|
215 |
|
216 |
with gr.Row():
|
217 |
btn_defaults = gr.Button("I'm feeling lucky")
|
@@ -224,13 +284,12 @@ with gr.Blocks() as demo:
|
|
224 |
A demo of predicting the genus of a DNA sequence using multiple
|
225 |
approaches (method dropdown):
|
226 |
|
227 |
-
- **fine_tuned_model**:
|
228 |
-
`LofiAmazon/BarcodeBERT-Finetuned-Amazon` which predicts the genus
|
229 |
based on the DNA sequence and environmental data.
|
230 |
- **cosine**: computes a cosine similarity between the DNA sequence
|
231 |
embedding generated by our model and the embeddings of known samples
|
232 |
-
that we precomputed and stored
|
233 |
-
DOES NOT examine ecological layer data.
|
234 |
""")
|
235 |
|
236 |
with gr.Row():
|
@@ -243,34 +302,66 @@ with gr.Blocks() as demo:
|
|
243 |
genus_output = gr.Image()
|
244 |
|
245 |
predict_button.click(
|
246 |
-
fn=
|
247 |
inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
|
248 |
outputs=genus_output
|
249 |
)
|
250 |
|
251 |
with gr.Tab("DNA Embedding Space Visualizer"):
|
252 |
gr.Markdown("""
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
with gr.Row():
|
261 |
with gr.Column():
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
visualize_button = gr.Button("Visualize Embedding Space")
|
267 |
-
with gr.Column():
|
268 |
visualize_output = gr.Image()
|
269 |
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
demo.launch()
|
|
|
60 |
classification_model.eval()
|
61 |
|
62 |
# Load datasets
|
63 |
+
amazon_ds = load_dataset(DATASETS["amazon"])['train'].to_pandas()
|
64 |
+
amazon_ds = amazon_ds[amazon_ds["genus"].notna()]
|
65 |
|
66 |
def set_default_inputs():
|
67 |
return (DEFAULT_INPUTS["dna_sequence"],
|
|
|
149 |
index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
|
150 |
)
|
151 |
|
152 |
+
# fig, ax = plt.subplots()
|
153 |
+
# ax.bar(top_k.index.astype(str), top_k.values)
|
154 |
+
# ax.set_ylim(0, 1)
|
155 |
+
# ax.set_title("Genus Prediction")
|
156 |
+
# ax.set_xlabel("Genus")
|
157 |
+
# ax.set_ylabel("Probability")
|
158 |
+
# ax.set_xticks(range(len(top_k)))
|
159 |
+
# ax.set_xticklabels(top_k.index.astype(str), rotation=90)
|
160 |
+
# fig.subplots_adjust(bottom=0.3)
|
161 |
+
# fig.canvas.draw()
|
162 |
+
|
163 |
+
# return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
164 |
+
return top_k
|
165 |
+
|
166 |
+
def genus_hist(method: str, dna_sequence: str, latitude: str, longitude: str):
|
167 |
+
top_k = predict_genus(method, dna_sequence, latitude, longitude)
|
168 |
fig, ax = plt.subplots()
|
169 |
ax.bar(top_k.index.astype(str), top_k.values)
|
170 |
ax.set_ylim(0, 1)
|
|
|
179 |
return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
180 |
|
181 |
|
182 |
+
def cluster_dna(k: float):
|
183 |
+
df = amazon_ds
|
184 |
+
# df = df[df["genus"].notna()]
|
185 |
+
k = int(k)
|
186 |
genus_counts = df["genus"].value_counts()
|
187 |
+
top_genuses = genus_counts.head(k).index
|
188 |
df = df[df["genus"].isin(top_genuses)]
|
189 |
tsne = TSNE(
|
190 |
n_components=2, perplexity=30, learning_rate=200,
|
|
|
197 |
|
198 |
label_encoder = LabelEncoder()
|
199 |
y_encoded = label_encoder.fit_transform(y)
|
200 |
+
classes = list(label_encoder.inverse_transform(range(len(df['genus'].unique()))))
|
201 |
|
202 |
fig, ax = plt.subplots()
|
203 |
+
plot = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="tab20", alpha=0.7)
|
204 |
+
handles, _ = plot.legend_elements(prop='colors')
|
205 |
+
ax.legend(handles, classes)
|
206 |
+
ax.set_title(f"DNA Embedding Space (of {str(k)} most common genera)")
|
207 |
# Reduce unnecessary whitespace
|
208 |
ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
|
209 |
fig.canvas.draw()
|
210 |
|
211 |
return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
212 |
|
213 |
+
def cluster_dna2(k: float, method: str, dna_sequence: str, latitude: str, longitude: str):
|
214 |
+
top_genuses = predict_genus(method, dna_sequence, latitude, longitude)
|
215 |
+
embed = get_embedding(dna_sequence).tolist()
|
216 |
+
# df = amazon_ds["train"].to_pandas()
|
217 |
+
df = amazon_ds
|
218 |
+
# df = df[df["genus"].notna()]
|
219 |
+
k = int(k)
|
220 |
+
# genus_counts = df["genus"].value_counts()
|
221 |
+
top_genuses = top_genuses.head(k).index
|
222 |
+
df = df[df["genus"].isin(top_genuses)]
|
223 |
+
tsne = TSNE(
|
224 |
+
n_components=2, perplexity=30, learning_rate=200,
|
225 |
+
n_iter=1000, random_state=0,
|
226 |
+
)
|
227 |
+
X = np.vstack([df['embeddings'].tolist(), embed])
|
228 |
+
# X = np.stack(df["embeddings"].tolist())
|
229 |
+
y = df["genus"].tolist()
|
230 |
+
|
231 |
+
X_tsne = tsne.fit_transform(X)
|
232 |
+
tsne_embed_space = X_tsne[:-1]
|
233 |
+
tsne_single = X_tsne[-1]
|
234 |
+
|
235 |
+
label_encoder = LabelEncoder()
|
236 |
+
y_encoded = label_encoder.fit_transform(y)
|
237 |
+
classes = list(label_encoder.inverse_transform(range(len(df['genus'].unique()))))
|
238 |
+
|
239 |
+
fig, ax = plt.subplots()
|
240 |
+
plot = ax.scatter(tsne_embed_space[:, 0], tsne_embed_space[:, 1], c=y_encoded, cmap="tab20", alpha=0.7)
|
241 |
+
ax.scatter(tsne_single[0], tsne_single[1], color='red', edgecolor='black')
|
242 |
+
handles, _ = plot.legend_elements(prop='colors')
|
243 |
+
ax.legend(handles, classes)
|
244 |
+
# ax.legend(loc='best')
|
245 |
+
ax.text(tsne_single[0], tsne_single[1], 'Your DNA Seq', fontsize=10, color='black')
|
246 |
+
ax.set_title(f"DNA Embedding Space Around Your DNA's Embedding")
|
247 |
+
# Reduce unnecessary whitespace
|
248 |
+
ax.set_xlim(X_tsne[:, 0].min() + 0.1, X_tsne[:, 0].max() + 0.1)
|
249 |
+
fig.canvas.draw()
|
250 |
+
|
251 |
+
return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
252 |
+
|
253 |
with gr.Blocks() as demo:
|
254 |
# Header section
|
255 |
gr.Markdown(("""
|
|
|
269 |
|
270 |
with gr.Column():
|
271 |
with gr.Row():
|
272 |
+
inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. 2.009083")
|
273 |
with gr.Row():
|
274 |
+
inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -41.68281")
|
275 |
|
276 |
with gr.Row():
|
277 |
btn_defaults = gr.Button("I'm feeling lucky")
|
|
|
284 |
A demo of predicting the genus of a DNA sequence using multiple
|
285 |
approaches (method dropdown):
|
286 |
|
287 |
+
- **fine_tuned_model**: uses our
|
288 |
+
`LofiAmazon/BarcodeBERT-Finetuned-Amazon` model which predicts the genus
|
289 |
based on the DNA sequence and environmental data.
|
290 |
- **cosine**: computes a cosine similarity between the DNA sequence
|
291 |
embedding generated by our model and the embeddings of known samples
|
292 |
+
that we precomputed and stored. This method DOES NOT use ecological layer data.
|
|
|
293 |
""")
|
294 |
|
295 |
with gr.Row():
|
|
|
302 |
genus_output = gr.Image()
|
303 |
|
304 |
predict_button.click(
|
305 |
+
fn=genus_hist,
|
306 |
inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
|
307 |
outputs=genus_output
|
308 |
)
|
309 |
|
310 |
with gr.Tab("DNA Embedding Space Visualizer"):
|
311 |
gr.Markdown("""
|
312 |
+
## DNA Embedding Space Visualizer
|
313 |
+
|
314 |
+
Use this tool to visualize how our DNA Transformer model
|
315 |
+
learns to cluster similar DNA sequences together.
|
316 |
+
""")
|
317 |
+
|
318 |
+
# with gr.Row():
|
319 |
+
# with gr.Column():
|
320 |
+
# top_k_slider = gr.Slider(
|
321 |
+
# minimum=1, maximum=10, step=1, value=5,
|
322 |
+
# label="Choose **k**, the number of top genera to visualize",
|
323 |
+
# )
|
324 |
+
# visualize_button = gr.Button("Visualize Embedding Space")
|
325 |
+
# with gr.Column():
|
326 |
+
# visualize_output = gr.Image()
|
327 |
+
|
328 |
+
# visualize_button.click(
|
329 |
+
# fn=cluster_dna,
|
330 |
+
# inputs=top_k_slider,
|
331 |
+
# outputs=visualize_output
|
332 |
+
# )
|
333 |
|
334 |
+
with gr.Row():
|
335 |
+
top_k_slider = gr.Slider(
|
336 |
+
minimum=1, maximum=10, step=1, value=5,
|
337 |
+
label="Choose **k**, the number of top genera to visualize",
|
338 |
+
)
|
339 |
+
visualize_button = gr.Button("Visualize Embedding Space")
|
340 |
with gr.Row():
|
341 |
with gr.Column():
|
342 |
+
gr.Markdown("""
|
343 |
+
t-SNE plot of the DNA embedding spaces of the **k** most common
|
344 |
+
genera in our dataset.
|
345 |
+
""")
|
|
|
|
|
346 |
visualize_output = gr.Image()
|
347 |
|
348 |
+
visualize_button.click(
|
349 |
+
fn=cluster_dna,
|
350 |
+
inputs=top_k_slider,
|
351 |
+
outputs=visualize_output
|
352 |
)
|
353 |
+
with gr.Column():
|
354 |
+
gr.Markdown("""
|
355 |
+
t-SNE plot of the DNA embedding spaces of the **k** most likely
|
356 |
+
genera for the DNA sequence you provided.
|
357 |
+
""")
|
358 |
+
visualize_output2 = gr.Image()
|
359 |
+
|
360 |
+
visualize_button.click(
|
361 |
+
fn=cluster_dna2,
|
362 |
+
inputs=[top_k_slider, method_dropdown, inp_dna, inp_lat, inp_lng],
|
363 |
+
outputs=visualize_output2
|
364 |
+
)
|
365 |
+
|
366 |
|
367 |
demo.launch()
|
default_inputs.json
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
{
|
2 |
"dna_sequence": "AACAATGTATTTGATTTTCGCCCTTGTGAATTTATTCGCTGGCGGAACAATGGCATTGTTGATTCGTTTGGAGTTGTTCCAACCTGGCTTGCAATTTTTAAGACCTGAGTTTTTTAATCAGTTAACAACTATGCACGGCCTTATAATGGTTTTCGGTGCAATTATGCCGGCCTTTGTGGGTTTTGCTAACTTGATGATTCCTTTGCAAATTGGTGCCTCTGATATGGCGTTTGCAAGAATGAACAATTTTAGTTTCTGGATTATGCCTGTTGCAGGGATGTTATTATTTGGCTCATTTTTGGCTCCTGGTGGCGCTACTGCAGCTGGTTGGACTTTGTATGCTCCTTTGTCGGTCCAAATGGGGCCTGGTATGGACATGACTATTTTTGCTGTTCACTTGATGGGTGCTTCATCCATTATGGGATCCATTAATATCATTGTGACAATTCTGAATATGCGTGCTCCTGGACTGTCTTTGATGAAGATGCCAATGTTCTGTTGGACATGGTTGATTACTGCATATTTGTTAATTGCGGTTATGCCTGTTTTAGCTGGTGCTATCACTATGGTTCTAACAGACCGTCACTTTGGAACAAGCTTTTTTGCAGCTGCTGGCGGTGGAGACCCTGTAATGTATCAACATATCTTC",
|
3 |
-
"latitude": "
|
4 |
-
"longitude": "-
|
5 |
}
|
|
|
1 |
{
|
2 |
"dna_sequence": "AACAATGTATTTGATTTTCGCCCTTGTGAATTTATTCGCTGGCGGAACAATGGCATTGTTGATTCGTTTGGAGTTGTTCCAACCTGGCTTGCAATTTTTAAGACCTGAGTTTTTTAATCAGTTAACAACTATGCACGGCCTTATAATGGTTTTCGGTGCAATTATGCCGGCCTTTGTGGGTTTTGCTAACTTGATGATTCCTTTGCAAATTGGTGCCTCTGATATGGCGTTTGCAAGAATGAACAATTTTAGTTTCTGGATTATGCCTGTTGCAGGGATGTTATTATTTGGCTCATTTTTGGCTCCTGGTGGCGCTACTGCAGCTGGTTGGACTTTGTATGCTCCTTTGTCGGTCCAAATGGGGCCTGGTATGGACATGACTATTTTTGCTGTTCACTTGATGGGTGCTTCATCCATTATGGGATCCATTAATATCATTGTGACAATTCTGAATATGCGTGCTCCTGGACTGTCTTTGATGAAGATGCCAATGTTCTGTTGGACATGGTTGATTACTGCATATTTGTTAATTGCGGTTATGCCTGTTTTAGCTGGTGCTATCACTATGGTTCTAACAGACCGTCACTTTGGAACAAGCTTTTTTGCAGCTGCTGGCGGTGGAGACCCTGTAATGTATCAACATATCTTC",
|
3 |
+
"latitude": "2.009083",
|
4 |
+
"longitude": "-41.68281"
|
5 |
}
|