jennzhuge committed on
Commit
8465e44
·
1 Parent(s): 040747a

added tsne graph, changed default coords

Browse files
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  .venv
2
  flagged
3
  *.tif
4
- *.tiff
 
 
1
  .venv
2
  flagged
3
  *.tif
4
+ *.tiff
5
+ .env
__pycache__/app.cpython-39.pyc ADDED
Binary file (10.7 kB). View file
 
__pycache__/config.cpython-39.pyc ADDED
Binary file (1.29 kB). View file
 
app.py CHANGED
@@ -60,7 +60,8 @@ embeddings_model.eval()
60
  classification_model.eval()
61
 
62
  # Load datasets
63
- amazon_ds = load_dataset(DATASETS["amazon"])
 
64
 
65
  def set_default_inputs():
66
  return (DEFAULT_INPUTS["dna_sequence"],
@@ -148,6 +149,22 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
148
  index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
149
  )
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  fig, ax = plt.subplots()
152
  ax.bar(top_k.index.astype(str), top_k.values)
153
  ax.set_ylim(0, 1)
@@ -162,12 +179,12 @@ def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str)
162
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
163
 
164
 
165
- def cluster_dna(top_k: float):
166
- df = amazon_ds["train"].to_pandas()
167
- df = df[df["genus"].notna()]
168
- top_k = int(top_k)
169
  genus_counts = df["genus"].value_counts()
170
- top_genuses = genus_counts.head(top_k).index
171
  df = df[df["genus"].isin(top_genuses)]
172
  tsne = TSNE(
173
  n_components=2, perplexity=30, learning_rate=200,
@@ -180,16 +197,59 @@ def cluster_dna(top_k: float):
180
 
181
  label_encoder = LabelEncoder()
182
  y_encoded = label_encoder.fit_transform(y)
 
183
 
184
  fig, ax = plt.subplots()
185
- ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="viridis", alpha=0.7)
186
- ax.set_title(f"DNA Embedding Space (of {str(top_k)} most common genera)")
 
 
187
  # Reduce unnecessary whitespace
188
  ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
189
  fig.canvas.draw()
190
 
191
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  with gr.Blocks() as demo:
194
  # Header section
195
  gr.Markdown(("""
@@ -209,9 +269,9 @@ with gr.Blocks() as demo:
209
 
210
  with gr.Column():
211
  with gr.Row():
212
- inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
213
  with gr.Row():
214
- inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")
215
 
216
  with gr.Row():
217
  btn_defaults = gr.Button("I'm feeling lucky")
@@ -224,13 +284,12 @@ with gr.Blocks() as demo:
224
  A demo of predicting the genus of a DNA sequence using multiple
225
  approaches (method dropdown):
226
 
227
- - **fine_tuned_model**: using our
228
- `LofiAmazon/BarcodeBERT-Finetuned-Amazon` which predicts the genus
229
  based on the DNA sequence and environmental data.
230
  - **cosine**: computes a cosine similarity between the DNA sequence
231
  embedding generated by our model and the embeddings of known samples
232
- that we precomputed and stored in a Pinecone index. Thie method
233
- DOES NOT examine ecological layer data.
234
  """)
235
 
236
  with gr.Row():
@@ -243,34 +302,66 @@ with gr.Blocks() as demo:
243
  genus_output = gr.Image()
244
 
245
  predict_button.click(
246
- fn=predict_genus,
247
  inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
248
  outputs=genus_output
249
  )
250
 
251
  with gr.Tab("DNA Embedding Space Visualizer"):
252
  gr.Markdown("""
253
- ## DNA Embedding Space Visualizer
254
-
255
- We show a 2D t-SNE plot of the DNA embeddings of the five most common
256
- genera in our dataset. This shows that the DNA Transformer model is
257
- learning to cluster similar DNA sequences together.
258
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
 
 
 
 
 
 
260
  with gr.Row():
261
  with gr.Column():
262
- top_k_slider = gr.Slider(
263
- minimum=1, maximum=10, step=1, value=5,
264
- label="Number of top genera to visualize",
265
- )
266
- visualize_button = gr.Button("Visualize Embedding Space")
267
- with gr.Column():
268
  visualize_output = gr.Image()
269
 
270
- visualize_button.click(
271
- fn=cluster_dna,
272
- inputs=top_k_slider,
273
- outputs=visualize_output
274
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
275
 
276
  demo.launch()
 
60
  classification_model.eval()
61
 
62
  # Load datasets
63
+ amazon_ds = load_dataset(DATASETS["amazon"])['train'].to_pandas()
64
+ amazon_ds = amazon_ds[amazon_ds["genus"].notna()]
65
 
66
  def set_default_inputs():
67
  return (DEFAULT_INPUTS["dna_sequence"],
 
149
  index=[ID_TO_GENUS_MAP[i] for i in top_k.indices.detach().numpy()]
150
  )
151
 
152
+ # fig, ax = plt.subplots()
153
+ # ax.bar(top_k.index.astype(str), top_k.values)
154
+ # ax.set_ylim(0, 1)
155
+ # ax.set_title("Genus Prediction")
156
+ # ax.set_xlabel("Genus")
157
+ # ax.set_ylabel("Probability")
158
+ # ax.set_xticks(range(len(top_k)))
159
+ # ax.set_xticklabels(top_k.index.astype(str), rotation=90)
160
+ # fig.subplots_adjust(bottom=0.3)
161
+ # fig.canvas.draw()
162
+
163
+ # return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
164
+ return top_k
165
+
166
+ def genus_hist(method: str, dna_sequence: str, latitude: str, longitude: str):
167
+ top_k = predict_genus(method, dna_sequence, latitude, longitude)
168
  fig, ax = plt.subplots()
169
  ax.bar(top_k.index.astype(str), top_k.values)
170
  ax.set_ylim(0, 1)
 
179
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
180
 
181
 
182
+ def cluster_dna(k: float):
183
+ df = amazon_ds
184
+ # df = df[df["genus"].notna()]
185
+ k = int(k)
186
  genus_counts = df["genus"].value_counts()
187
+ top_genuses = genus_counts.head(k).index
188
  df = df[df["genus"].isin(top_genuses)]
189
  tsne = TSNE(
190
  n_components=2, perplexity=30, learning_rate=200,
 
197
 
198
  label_encoder = LabelEncoder()
199
  y_encoded = label_encoder.fit_transform(y)
200
+ classes = list(label_encoder.inverse_transform(range(len(df['genus'].unique()))))
201
 
202
  fig, ax = plt.subplots()
203
+ plot = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="tab20", alpha=0.7)
204
+ handles, _ = plot.legend_elements(prop='colors')
205
+ ax.legend(handles, classes)
206
+ ax.set_title(f"DNA Embedding Space (of {str(k)} most common genera)")
207
  # Reduce unnecessary whitespace
208
  ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
209
  fig.canvas.draw()
210
 
211
  return PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
212
 
213
def cluster_dna2(k: float, method: str, dna_sequence: str, latitude: str, longitude: str):
    """Plot a 2D t-SNE view of the embedding space around the user's DNA sequence.

    The genera predicted for ``dna_sequence`` are used to select reference
    samples from ``amazon_ds``; those samples and the query embedding are
    projected together with t-SNE and drawn as a scatter plot, with the
    query highlighted in red.

    Args:
        k: Number of predicted genera to include (cast to ``int``; the UI
            slider delivers a float).
        method: Prediction method forwarded to ``predict_genus``.
        dna_sequence: DNA sequence to embed and classify.
        latitude: Latitude forwarded to ``predict_genus``.
        longitude: Longitude forwarded to ``predict_genus``.

    Returns:
        A ``PIL.Image`` built from the rendered figure's RGB canvas buffer.
    """
    k = int(k)

    # NOTE(review): assumes predict_genus returns a pandas Series indexed by
    # genus name (only .head(k).index is used here) -- confirm against it.
    predictions = predict_genus(method, dna_sequence, latitude, longitude)
    query_embedding = get_embedding(dna_sequence).tolist()
    top_genuses = predictions.head(k).index

    df = amazon_ds[amazon_ds["genus"].isin(top_genuses)]

    tsne = TSNE(
        n_components=2, perplexity=30, learning_rate=200,
        n_iter=1000, random_state=0,
    )
    # The query embedding is stacked as the LAST row so that it is projected
    # jointly with the reference samples and can be split off afterwards.
    X = np.vstack([df["embeddings"].tolist(), query_embedding])
    y = df["genus"].tolist()

    X_tsne = tsne.fit_transform(X)
    tsne_embed_space = X_tsne[:-1]
    tsne_single = X_tsne[-1]

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)
    # classes_[i] is the genus encoded as integer i, i.e. the same labels the
    # original obtained via inverse_transform(range(n_unique)).
    classes = list(label_encoder.classes_)

    fig, ax = plt.subplots()
    plot = ax.scatter(
        tsne_embed_space[:, 0], tsne_embed_space[:, 1],
        c=y_encoded, cmap="tab20", alpha=0.7,
    )
    ax.scatter(tsne_single[0], tsne_single[1], color='red', edgecolor='black')
    # num=None forces one legend handle per unique class; the default "auto"
    # stops doing so above 9 classes, which would misalign handles and labels
    # when the slider is at its maximum of 10.
    handles, _ = plot.legend_elements(prop='colors', num=None)
    ax.legend(handles, classes)
    ax.text(tsne_single[0], tsne_single[1], 'Your DNA Seq', fontsize=10, color='black')
    ax.set_title("DNA Embedding Space Around Your DNA's Embedding")
    # Reduce unnecessary whitespace: pad both sides by 0.1 (the lower bound
    # previously used "+ 0.1", which clipped the leftmost points).
    ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
    fig.canvas.draw()

    image = PIL.Image.frombytes("RGB", fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    # pyplot keeps every figure alive until it is closed; release it so that
    # repeated button clicks do not accumulate figures in memory.
    plt.close(fig)
    return image
252
+
253
  with gr.Blocks() as demo:
254
  # Header section
255
  gr.Markdown(("""
 
269
 
270
  with gr.Column():
271
  with gr.Row():
272
+ inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. 2.009083")
273
  with gr.Row():
274
+ inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -41.68281")
275
 
276
  with gr.Row():
277
  btn_defaults = gr.Button("I'm feeling lucky")
 
284
  A demo of predicting the genus of a DNA sequence using multiple
285
  approaches (method dropdown):
286
 
287
+ - **fine_tuned_model**: uses our
288
+ `LofiAmazon/BarcodeBERT-Finetuned-Amazon` model which predicts the genus
289
  based on the DNA sequence and environmental data.
290
  - **cosine**: computes a cosine similarity between the DNA sequence
291
  embedding generated by our model and the embeddings of known samples
292
+ that we precomputed and stored. This method DOES NOT use ecological layer data.
 
293
  """)
294
 
295
  with gr.Row():
 
302
  genus_output = gr.Image()
303
 
304
  predict_button.click(
305
+ fn=genus_hist,
306
  inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
307
  outputs=genus_output
308
  )
309
 
310
  with gr.Tab("DNA Embedding Space Visualizer"):
311
  gr.Markdown("""
312
+ ## DNA Embedding Space Visualizer
313
+
314
+ Use this tool to visualize how our DNA Transformer model
315
+ learns to cluster similar DNA sequences together.
316
+ """)
317
+
318
+ # with gr.Row():
319
+ # with gr.Column():
320
+ # top_k_slider = gr.Slider(
321
+ # minimum=1, maximum=10, step=1, value=5,
322
+ # label="Choose **k**, the number of top genera to visualize",
323
+ # )
324
+ # visualize_button = gr.Button("Visualize Embedding Space")
325
+ # with gr.Column():
326
+ # visualize_output = gr.Image()
327
+
328
+ # visualize_button.click(
329
+ # fn=cluster_dna,
330
+ # inputs=top_k_slider,
331
+ # outputs=visualize_output
332
+ # )
333
 
334
+ with gr.Row():
335
+ top_k_slider = gr.Slider(
336
+ minimum=1, maximum=10, step=1, value=5,
337
+ label="Choose **k**, the number of top genera to visualize",
338
+ )
339
+ visualize_button = gr.Button("Visualize Embedding Space")
340
  with gr.Row():
341
  with gr.Column():
342
+ gr.Markdown("""
343
+ t-SNE plot of the DNA embedding spaces of the **k** most common
344
+ genera in our dataset.
345
+ """)
 
 
346
  visualize_output = gr.Image()
347
 
348
+ visualize_button.click(
349
+ fn=cluster_dna,
350
+ inputs=top_k_slider,
351
+ outputs=visualize_output
352
  )
353
+ with gr.Column():
354
+ gr.Markdown("""
355
+ t-SNE plot of the DNA embedding spaces of the **k** most likely
356
+ genera for the DNA sequence you provided.
357
+ """)
358
+ visualize_output2 = gr.Image()
359
+
360
+ visualize_button.click(
361
+ fn=cluster_dna2,
362
+ inputs=[top_k_slider, method_dropdown, inp_dna, inp_lat, inp_lng],
363
+ outputs=visualize_output2
364
+ )
365
+
366
 
367
  demo.launch()
default_inputs.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
  "dna_sequence": "AACAATGTATTTGATTTTCGCCCTTGTGAATTTATTCGCTGGCGGAACAATGGCATTGTTGATTCGTTTGGAGTTGTTCCAACCTGGCTTGCAATTTTTAAGACCTGAGTTTTTTAATCAGTTAACAACTATGCACGGCCTTATAATGGTTTTCGGTGCAATTATGCCGGCCTTTGTGGGTTTTGCTAACTTGATGATTCCTTTGCAAATTGGTGCCTCTGATATGGCGTTTGCAAGAATGAACAATTTTAGTTTCTGGATTATGCCTGTTGCAGGGATGTTATTATTTGGCTCATTTTTGGCTCCTGGTGGCGCTACTGCAGCTGGTTGGACTTTGTATGCTCCTTTGTCGGTCCAAATGGGGCCTGGTATGGACATGACTATTTTTGCTGTTCACTTGATGGGTGCTTCATCCATTATGGGATCCATTAATATCATTGTGACAATTCTGAATATGCGTGCTCCTGGACTGTCTTTGATGAAGATGCCAATGTTCTGTTGGACATGGTTGATTACTGCATATTTGTTAATTGCGGTTATGCCTGTTTTAGCTGGTGCTATCACTATGGTTCTAACAGACCGTCACTTTGGAACAAGCTTTTTTGCAGCTGCTGGCGGTGGAGACCCTGTAATGTATCAACATATCTTC",
3
- "latitude": "-3.009083",
4
- "longitude": "-58.68281"
5
  }
 
1
  {
2
  "dna_sequence": "AACAATGTATTTGATTTTCGCCCTTGTGAATTTATTCGCTGGCGGAACAATGGCATTGTTGATTCGTTTGGAGTTGTTCCAACCTGGCTTGCAATTTTTAAGACCTGAGTTTTTTAATCAGTTAACAACTATGCACGGCCTTATAATGGTTTTCGGTGCAATTATGCCGGCCTTTGTGGGTTTTGCTAACTTGATGATTCCTTTGCAAATTGGTGCCTCTGATATGGCGTTTGCAAGAATGAACAATTTTAGTTTCTGGATTATGCCTGTTGCAGGGATGTTATTATTTGGCTCATTTTTGGCTCCTGGTGGCGCTACTGCAGCTGGTTGGACTTTGTATGCTCCTTTGTCGGTCCAAATGGGGCCTGGTATGGACATGACTATTTTTGCTGTTCACTTGATGGGTGCTTCATCCATTATGGGATCCATTAATATCATTGTGACAATTCTGAATATGCGTGCTCCTGGACTGTCTTTGATGAAGATGCCAATGTTCTGTTGGACATGGTTGATTACTGCATATTTGTTAATTGCGGTTATGCCTGTTTTAGCTGGTGCTATCACTATGGTTCTAACAGACCGTCACTTTGGAACAAGCTTTTTTGCAGCTGCTGGCGGTGGAGACCCTGTAATGTATCAACATATCTTC",
3
+ "latitude": "2.009083",
4
+ "longitude": "-41.68281"
5
  }