Spaces:
Runtime error
Runtime error
import os | |
import pickle | |
import re | |
import PIL.Image | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
from datasets import load_dataset | |
import matplotlib.pyplot as plt | |
from sklearn.manifold import TSNE | |
from sklearn.preprocessing import LabelEncoder | |
import torch | |
from torch import nn | |
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast | |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download | |
from pinecone import Pinecone | |
import rasterio | |
from rasterio.sample import sample_gen | |
from config import DEFAULT_INPUTS, MODELS, DATASETS, ID_TO_GENUS_MAP, LAYER_NAMES | |
# Fetch each ecological-layer raster from the Hub dataset into the working directory
for image_name in LAYER_NAMES:
    hf_hub_download(
        repo_id="LofiAmazon/Global-Ecolayers",
        repo_type="dataset",
        filename=image_name,
        local_dir=".",
    )

# The eco layers are too big for PIL's default MAX_IMAGE_PIXELS guard, so lift the limit
PIL.Image.MAX_IMAGE_PIXELS = None

# Inference only: turn autograd off.
# NOTE(review): torch grad mode is thread-local, so this may not cover the threads
# that run Gradio event handlers — confirm, or use torch.no_grad() at call sites.
torch.set_grad_enabled(False)

# Pinecone client and the "amazon" index queried by the cosine prediction method
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
pc_index = pc.Index("amazon")
# Load models
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    """Genus classifier over a BERT DNA embedding concatenated with environmental features.

    The last hidden layer of ``bert_model`` is mean-pooled over dim 1, concatenated
    with ``env_data`` along dim 1, and fed through a single linear layer.

    Args:
        bert_model: module called as ``bert_model(**bert_inputs)``; its output must
            expose ``hidden_states`` (e.g. a BERT built with ``output_hidden_states=True``).
        env_dim: number of environmental features per sample.
        num_classes: number of output logits.
    """

    def __init__(self, bert_model, env_dim, num_classes):
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # Width of the pooled DNA embedding: read it from the backbone's config when
        # available, otherwise keep the previously hard-coded 768.
        hidden_size = getattr(getattr(bert_model, "config", None), "hidden_size", 768)
        self.fc = nn.Linear(hidden_size + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        """Return classification logits.

        Args:
            bert_inputs: mapping of keyword arguments for ``self.bert`` (e.g. tokenizer output).
            env_data: tensor concatenated to the pooled embedding along dim 1 —
                presumably float of shape (batch, env_dim); confirm with callers.

        Returns:
            Logits of size ``num_classes`` along the last dimension.
        """
        outputs = self.bert(**bert_inputs)
        # Mean-pool the last hidden layer over dim 1 (presumably the token axis).
        dna_embeddings = outputs.hidden_states[-1].mean(1)
        combined = torch.cat((dna_embeddings, env_data), dim=1)
        return self.fc(combined)
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODELS["embeddings"])
# Embeddings are read from `.hidden_states`, which is only populated when
# output_hidden_states is enabled; request it explicitly instead of relying on
# the hub config.
embeddings_model = BertForMaskedLM.from_pretrained(
    MODELS["embeddings"],
    output_hidden_states=True,
)
# The classifier's BERT backbone is built from a bare config; presumably its
# weights are restored from the classification checkpoint — confirm.
classification_model = DNASeqClassifier.from_pretrained(
    MODELS["classification"],
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True),
    ),
)
# NOTE(security): pickle.load can execute arbitrary code; scaler.pkl must be a
# trusted file shipped with the app, never user-supplied.
with open("scaler.pkl", "rb") as f:
    scaler = pickle.load(f)
embeddings_model.eval()
classification_model.eval()

# Load datasets
amazon_ds = load_dataset(DATASETS["amazon"])["train"].to_pandas()
# Drop rows without a genus label.
amazon_ds = amazon_ds[amazon_ds["genus"].notna()]
def set_default_inputs():
    """Return the example (DNA sequence, latitude, longitude) filled in by "I'm feeling lucky"."""
    fields = ("dna_sequence", "latitude", "longitude")
    return tuple(DEFAULT_INPUTS[field] for field in fields)
def preprocess(dna_sequence: str, latitude: float, longitude: float):
    """Prepare app input for downstream tasks.

    The DNA sequence is stripped and upper-cased, characters other than A/C/G/T
    are replaced by ``N``, trailing ``N``s are dropped, and the result is cut to
    660 characters, split into space-separated 4-mers and embedded with
    ``embeddings_model`` (last hidden layer averaged over dim 1).

    Returns:
        ``(dna_embedding, latitude, longitude)`` with the coordinates as floats.

    Raises:
        ValueError: if a coordinate cannot be converted to float.
    """
    cleaned = re.sub(r"[^ACGT]", "N", dna_sequence.strip().upper())
    # Trim the *cleaned* sequence (not the raw input) so the N substitution is kept.
    cleaned = re.sub(r"N+$", "", cleaned)[:660]
    kmers = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))
    # Grad mode is thread-local in PyTorch; disable it here so inference never
    # builds an autograd graph regardless of the calling thread.
    with torch.no_grad():
        outputs = embeddings_model(**tokenizer(kmers, return_tensors="pt"))
    dna_embedding: torch.Tensor = outputs.hidden_states[-1].mean(1).squeeze()
    return dna_embedding, float(latitude), float(longitude)
def tokenize(dna_sequence: str) -> dict[str, torch.Tensor]:
    """Tokenize a raw DNA sequence for the embedding/classification models.

    The sequence is stripped and upper-cased, every character other than A/C/G/T
    becomes ``N``, trailing ``N``s are removed, and the result is cut to 660
    characters and split into space-separated 4-mers before tokenization.

    Returns:
        The tokenizer output as PyTorch tensors.
    """
    cleaned = re.sub(r"[^ACGT]", "N", dna_sequence.strip().upper())
    # Trim the *cleaned* sequence (not the raw input) so the N substitution is kept.
    cleaned = re.sub(r"N+$", "", cleaned)[:660]
    kmers = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))
    return tokenizer(kmers, return_tensors="pt")
def get_embedding(dna_sequence: str) -> torch.Tensor:
    """Embed a DNA sequence with ``embeddings_model``.

    The sequence is passed through ``tokenize``; the last hidden layer is averaged
    over dim 1 and singleton dimensions are squeezed away.
    """
    inputs = tokenize(dna_sequence)
    # Grad mode is thread-local in PyTorch, so a module-level
    # torch.set_grad_enabled(False) may not reach the thread serving this call.
    with torch.no_grad():
        outputs = embeddings_model(**inputs)
    return outputs.hidden_states[-1].mean(1).squeeze()
def _predict_genus_cosine(dna_sequence: str) -> pd.Series:
    """Score genera by their share of the 10 nearest samples in the Pinecone index."""
    embedding = get_embedding(dna_sequence)
    result = pc_index.query(
        namespace="all",
        vector=embedding.tolist(),
        top_k=10,
        include_metadata=True,
    )
    genera = pd.Series([match["metadata"]["genus"] for match in result["matches"]])
    counts = genera.value_counts()
    return counts / counts.sum()


def _sample_env_layers(latitude: float, longitude: float) -> np.ndarray:
    """Read every layer in LAYER_NAMES at one point and return the scaled feature row.

    Each layer contributes the mean of its band values at the point; the list is
    then passed through ``scaler.transform`` (presumably giving shape
    (1, len(LAYER_NAMES)) — confirm against scaler.pkl).
    """
    # rasterio samples at (x, y) in the raster's CRS; for a lon/lat CRS x is the
    # longitude, so the point is (longitude, latitude).
    # NOTE(review): assumes the layers use a geographic lon/lat CRS — confirm, and
    # confirm the training features were sampled with the same axis order.
    point = (longitude, latitude)
    values = []
    for layer in LAYER_NAMES:
        with rasterio.open(layer) as dataset:
            band_values = next(sample_gen(dataset, [point]))
            values.append(np.mean(band_values))
    return scaler.transform([values])


def _predict_genus_fine_tuned(dna_sequence: str, latitude: str, longitude: str) -> pd.Series:
    """Return up to 10 genera with the highest classifier probability, highest first."""
    env_data = _sample_env_layers(float(latitude), float(longitude))
    env_tensor = torch.from_numpy(env_data).to(torch.float32)
    # Grad mode is thread-local; keep inference graph-free in whatever thread runs this.
    with torch.no_grad():
        logits = classification_model(tokenize(dna_sequence), env_tensor)
    # A temperature below 1 sharpens the softmax distribution.
    temperature = 0.2
    probs = torch.softmax(logits / temperature, dim=1).squeeze(0)
    top = torch.topk(probs, min(10, probs.numel()))
    return pd.Series(
        top.values.numpy(),
        index=[ID_TO_GENUS_MAP[i] for i in top.indices.numpy()],
    )


def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str) -> pd.Series:
    """Predict the genus of a DNA sequence.

    Args:
        method: ``"cosine"`` (vote among nearest neighbours in the Pinecone index;
            DNA only) or ``"fine_tuned_model"`` (classifier on DNA plus ecological
            layer values sampled at the given coordinates).
        dna_sequence: raw DNA sequence text.
        latitude: latitude as a string or number; only used by ``"fine_tuned_model"``.
        longitude: longitude as a string or number; only used by ``"fine_tuned_model"``.

    Returns:
        Series mapping genus -> score, highest first.

    Raises:
        ValueError: for an unknown ``method`` or non-numeric coordinates.
    """
    if method == "cosine":
        return _predict_genus_cosine(dna_sequence)
    if method == "fine_tuned_model":
        return _predict_genus_fine_tuned(dna_sequence, latitude, longitude)
    raise ValueError(f"Unknown prediction method: {method!r}")
def genus_hist(method: str, dna_sequence: str, latitude: str, longitude: str) -> PIL.Image.Image:
    """Render the output of ``predict_genus`` as a bar chart.

    Returns:
        RGB PIL image of genus vs. probability (y axis fixed to [0, 1]).
    """
    top_k = predict_genus(method, dna_sequence, latitude, longitude)
    fig, ax = plt.subplots()
    try:
        ax.bar(top_k.index.astype(str), top_k.values)
        ax.set_ylim(0, 1)
        ax.set_title("Genus Prediction")
        ax.set_xlabel("Genus")
        ax.set_ylabel("Probability")
        ax.tick_params(axis="x", labelrotation=90)
        fig.subplots_adjust(bottom=0.3)
        fig.canvas.draw()
        # canvas.tostring_rgb() is deprecated/removed in recent Matplotlib;
        # buffer_rgba() is the supported way to read the rendered pixels.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        # Without this, pyplot keeps every rendered figure alive.
        plt.close(fig)
def cluster_dna(k: float) -> PIL.Image.Image:
    """Plot a 2-D t-SNE projection of reference embeddings for the k most common genera.

    Args:
        k: number of genera to include; cast to int (the Gradio slider may send a float).

    Returns:
        RGB PIL image of the scatter plot, one colour per genus.
    """
    k = int(k)
    top_genera = amazon_ds["genus"].value_counts().head(k).index
    df = amazon_ds[amazon_ds["genus"].isin(top_genera)]

    X = np.stack(df["embeddings"].tolist())
    tsne = TSNE(
        n_components=2,
        # t-SNE requires perplexity < n_samples; cap it for very small selections.
        perplexity=min(30, len(X) - 1),
        learning_rate=200,
        # Iteration count left at scikit-learn's default (1000): the keyword was
        # renamed n_iter -> max_iter in 1.5, so naming either breaks some versions.
        random_state=0,
    )
    X_tsne = tsne.fit_transform(X)

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(df["genus"].tolist())
    classes = list(label_encoder.classes_)

    fig, ax = plt.subplots()
    try:
        plot = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y_encoded, cmap="tab20", alpha=0.7)
        # num=None yields exactly one handle per encoded genus, matching `classes`;
        # the default "auto" may return fewer handles when there are many values.
        handles, _ = plot.legend_elements(prop="colors", num=None)
        ax.legend(handles, classes)
        ax.set_title(f"DNA Embedding Space (of {k} most common genera)")
        # Reduce unnecessary whitespace
        ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed canvas.tostring_rgb().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        plt.close(fig)
def cluster_dna2(k: float, method: str, dna_sequence: str, latitude: str, longitude: str) -> PIL.Image.Image:
    """Plot a t-SNE projection of the k most likely genera together with the user's sequence.

    Args:
        k: number of predicted genera to include; cast to int.
        method, dna_sequence, latitude, longitude: forwarded to ``predict_genus``.

    Returns:
        RGB PIL image; reference samples are coloured by genus and the user's
        embedding is a labelled red point.

    Raises:
        gr.Error: if none of the predicted genera have samples in ``amazon_ds``.
    """
    k = int(k)
    predicted = predict_genus(method, dna_sequence, latitude, longitude)
    query_embedding = get_embedding(dna_sequence).tolist()

    top_genera = predicted.head(k).index
    df = amazon_ds[amazon_ds["genus"].isin(top_genera)]
    if df.empty:
        raise gr.Error("None of the predicted genera have reference samples to plot.")

    # Reference embeddings first; the user's embedding is the last row.
    X = np.vstack([np.stack(df["embeddings"].tolist()), query_embedding])
    tsne = TSNE(
        n_components=2,
        # t-SNE requires perplexity < n_samples; cap it for very small selections.
        perplexity=min(5, len(X) - 1),
        learning_rate=200,
        # Iteration count left at scikit-learn's default (1000): the keyword was
        # renamed n_iter -> max_iter in 1.5, so naming either breaks some versions.
        random_state=0,
    )
    X_tsne = tsne.fit_transform(X)
    reference_points = X_tsne[:-1]
    query_point = X_tsne[-1]

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(df["genus"].tolist())
    classes = list(label_encoder.classes_)

    fig, ax = plt.subplots()
    try:
        plot = ax.scatter(
            reference_points[:, 0], reference_points[:, 1],
            c=y_encoded, cmap="tab20", alpha=0.7,
        )
        ax.scatter(query_point[0], query_point[1], color="red", edgecolor="black")
        # num=None yields exactly one handle per encoded genus, matching `classes`.
        handles, _ = plot.legend_elements(prop="colors", num=None)
        ax.legend(handles, classes)
        ax.text(query_point[0], query_point[1], "Your DNA Seq", fontsize=10, color="black")
        ax.set_title("DNA Embedding Space Around Your DNA's Embedding")
        # Reduce unnecessary whitespace; pad both ends so extreme points are not clipped.
        ax.set_xlim(X_tsne[:, 0].min() - 0.1, X_tsne[:, 0].max() + 0.1)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated/removed canvas.tostring_rgb().
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        plt.close(fig)
with gr.Blocks() as demo:
    # Header section
    gr.Markdown("""
# DNA Identifier Tool
Welcome to Lofi Amazon Beats' DNA Identifier Tool. Please enter a DNA
sequence and the coordinates at which its sample was taken to get
started. Click 'I'm feeling lucky' to fill in an example sequence and
location. For more information on how to use the tool, check out our
[README](https://huggingface.co/spaces/LofiAmazon/LofiAmazonSpace/blob/main/README.md)
""")

    # Inputs shared by every tab
    with gr.Row():
        with gr.Column():
            inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
        with gr.Column():
            with gr.Row():
                inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. 2.009083")
            with gr.Row():
                inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -41.68281")
            with gr.Row():
                btn_defaults = gr.Button("I'm feeling lucky")
                btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])

    with gr.Tab("Genus Prediction"):
        gr.Markdown("""
## Genus prediction
A demo of predicting the genus of a DNA sequence using multiple
approaches (method dropdown):
- **fine_tuned_model**: uses our
`LofiAmazon/BarcodeBERT-Finetuned-Amazon` model which predicts the genus
based on the DNA sequence and environmental data.
- **cosine**: computes a cosine similarity between the DNA sequence
embedding generated by our model and the embeddings of known samples
that we precomputed and stored. This method DOES NOT use ecological layer data.
""")
        with gr.Row():
            with gr.Column():
                method_dropdown = gr.Dropdown(
                    choices=["cosine", "fine_tuned_model"],
                    value="fine_tuned_model",
                    label="Method",
                )
                predict_button = gr.Button("Predict Genus")
            with gr.Column():
                genus_output = gr.Image()
        predict_button.click(
            fn=genus_hist,
            inputs=[method_dropdown, inp_dna, inp_lat, inp_lng],
            outputs=genus_output,
        )

    with gr.Tab("DNA Embedding Space Visualizer"):
        gr.Markdown("""
## DNA Embedding Space Visualizer
Use this tool to visualize how our DNA Transformer model
learns to cluster similar DNA sequences together.
""")
        with gr.Row():
            top_k_slider = gr.Slider(
                minimum=1, maximum=10, step=1, value=5,
                label="Choose **k**, the number of top genera to visualize",
            )
            visualize_button = gr.Button("Visualize Embedding Space")
        with gr.Row():
            with gr.Column():
                gr.Markdown("""
t-SNE plot of the DNA embedding spaces of the **k** most common
genera in our dataset.
""")
                visualize_output = gr.Image()
                visualize_button.click(
                    fn=cluster_dna,
                    inputs=top_k_slider,
                    outputs=visualize_output,
                )
            with gr.Column():
                gr.Markdown("""
t-SNE plot of the DNA embedding spaces of the **k** most likely
genera for the DNA sequence you provided.
""")
                visualize_output2 = gr.Image()
                visualize_button.click(
                    fn=cluster_dna2,
                    inputs=[top_k_slider, method_dropdown, inp_dna, inp_lat, inp_lng],
                    outputs=visualize_output2,
                )

demo.launch()