Spaces:
Runtime error
Runtime error
from io import BytesIO | |
import os | |
import re | |
import PIL.Image | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
from datasets import load_dataset | |
import infer | |
import matplotlib.pyplot as plt | |
from sklearn.manifold import TSNE | |
from sklearn.preprocessing import LabelEncoder | |
import torch | |
from torch import nn | |
from transformers import BertConfig, BertForMaskedLM, PreTrainedTokenizerFast | |
from huggingface_hub import PyTorchModelHubMixin | |
from pinecone import Pinecone | |
from config import DEFAULT_INPUTS, MODELS, DATASETS, ID_TO_GENUS_MAP | |
# The eco-layer images are too big for Pillow's default pixel limit, so the
# limit is removed altogether.
PIL.Image.MAX_IMAGE_PIXELS = None

# Turn gradient tracking off globally for everything that follows.
torch.set_grad_enabled(False)

# Pinecone client (API key read from the environment) and the "amazon" index
# that predict_genus queries for its "cosine" method.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
pc_index = pc.Index("amazon")
# Load models | |
class DNASeqClassifier(nn.Module, PyTorchModelHubMixin):
    """Linear classifier over mean-pooled BERT hidden states plus extra features.

    The wrapped BERT model is called with the tokenizer output, the last entry
    of its ``hidden_states`` is averaged over dimension 1, the result is
    concatenated with ``env_data`` along dimension 1, and a single linear
    layer maps that to ``num_classes`` logits.
    """

    def __init__(self, bert_model, env_dim, num_classes):
        """Store the BERT backbone and build the linear head.

        Args:
            bert_model: Model called as ``bert_model(**bert_inputs)``; its
                output must expose ``hidden_states``.
            env_dim: Width of the ``env_data`` tensor passed to ``forward``.
            num_classes: Number of output logits.
        """
        super().__init__()
        self.bert = bert_model
        self.env_dim = env_dim
        self.num_classes = num_classes
        # 768 is presumably the BERT hidden size -- TODO confirm against the
        # checkpoint's config.
        pooled_width = 768
        self.fc = nn.Linear(pooled_width + env_dim, num_classes)

    def forward(self, bert_inputs, env_data):
        """Return classification logits for the given inputs.

        Args:
            bert_inputs: Mapping unpacked as keyword arguments into the BERT
                model.
            env_data: Tensor concatenated to the pooled embedding along
                dimension 1.

        Returns:
            The output of the linear head.
        """
        final_layer = self.bert(**bert_inputs).hidden_states[-1]
        pooled = final_layer.mean(dim=1)
        features = torch.cat([pooled, env_data], dim=1)
        return self.fc(features)
# Tokenizer and embedding model are both loaded from the "embeddings"
# checkpoint; the model is switched to evaluation mode straight away.
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODELS["embeddings"])
embeddings_model = BertForMaskedLM.from_pretrained(MODELS["embeddings"])
embeddings_model.eval()

# The classifier needs a BERT backbone instance at construction time; the
# backbone is configured to return hidden states because
# DNASeqClassifier.forward reads them.
classification_model = DNASeqClassifier.from_pretrained(
    MODELS["classification"],
    bert_model=BertForMaskedLM(
        BertConfig(vocab_size=259, output_hidden_states=True)
    ),
)
classification_model.eval()

# Load datasets
ecolayers_ds = load_dataset(DATASETS["ecolayers"])
def set_default_inputs():
    """Return the default DNA sequence, latitude and longitude, in that order.

    The values are read from ``DEFAULT_INPUTS`` and returned as a 3-tuple.
    """
    field_names = ("dna_sequence", "latitude", "longitude")
    return tuple(DEFAULT_INPUTS[name] for name in field_names)
def preprocess(dna_sequence: str, latitude: float, longitude: float):
    """Prepare app input for downstream tasks.

    The DNA sequence is cleaned, chunked and embedded; the coordinates are
    converted to floats.

    Args:
        dna_sequence: Raw DNA string entered by the user.
        latitude: Latitude of the sample; anything ``float()`` accepts.
        longitude: Longitude of the sample; anything ``float()`` accepts.

    Returns:
        A tuple ``(dna_embedding, latitude, longitude)`` where
        ``dna_embedding`` is the last hidden state of ``embeddings_model``
        averaged over dimension 1 and squeezed, and the coordinates are
        floats.

    Raises:
        ValueError: If a coordinate cannot be converted to ``float``.
    """
    # Mask every character other than A, C, G, T with "N", then drop the
    # trailing run of "N". The second substitution has to work on the masked
    # string; feeding it the raw input would throw the masking away.
    cleaned = re.sub(r"[^ACGT]", "N", dna_sequence)
    cleaned = re.sub(r"N+$", "", cleaned)

    # Keep at most 660 characters and split them into space-separated
    # 4-character chunks before tokenizing.
    cleaned = cleaned[:660]
    chunked = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))

    # NOTE(review): relies on the model returning hidden_states, presumably
    # enabled in the checkpoint config -- confirm.
    encoded = tokenizer(chunked, return_tensors="pt")
    dna_embedding: torch.Tensor = (
        embeddings_model(**encoded).hidden_states[-1].mean(1).squeeze()
    )

    # Preprocess the location data
    return dna_embedding, float(latitude), float(longitude)
def tokenize(dna_sequence: str) -> dict[str, torch.Tensor]:
    """Clean a DNA sequence and convert it into tokenizer output.

    Args:
        dna_sequence: Raw DNA string.

    Returns:
        The result of calling ``tokenizer`` with ``return_tensors="pt"`` on
        the cleaned, chunked sequence.
    """
    # Mask every character other than A, C, G, T with "N", then drop the
    # trailing run of "N". The second substitution has to work on the masked
    # string; feeding it the raw input would throw the masking away.
    cleaned = re.sub(r"[^ACGT]", "N", dna_sequence)
    cleaned = re.sub(r"N+$", "", cleaned)

    # Keep at most 660 characters and split them into space-separated
    # 4-character chunks.
    cleaned = cleaned[:660]
    chunked = " ".join(cleaned[i:i + 4] for i in range(0, len(cleaned), 4))
    return tokenizer(chunked, return_tensors="pt")
def get_embedding(dna_sequence: str) -> torch.Tensor:
    """Embed a DNA sequence with ``embeddings_model``.

    The sequence is tokenized with ``tokenize``, passed through the model,
    and the last entry of ``hidden_states`` is averaged over dimension 1 and
    squeezed.
    """
    encoded = tokenize(dna_sequence)
    model_output = embeddings_model(**encoded)
    final_layer = model_output.hidden_states[-1]
    return final_layer.mean(1).squeeze()
def predict_genus(method: str, dna_sequence: str, latitude: str, longitude: str):
    """Predict the genus of a DNA sequence and render the result as a bar chart.

    Args:
        method: ``"cosine"`` queries the Pinecone index (namespace ``"all"``)
            for the 10 nearest stored embeddings and reports the share of
            matches per genus. ``"fine_tuned_model"`` runs
            ``classification_model`` with an all-zero ``(1, 7)`` tensor as
            ``env_data``, applies a softmax with temperature 0.2 and reports
            the highest-probability classes (up to 10).
        dna_sequence: Raw DNA string.
        latitude: Latitude of the sample; must be convertible to ``float``.
        longitude: Longitude of the sample; must be convertible to ``float``.

    Returns:
        An RGB ``PIL.Image.Image`` of the rendered bar chart.

    Raises:
        ValueError: If ``method`` is not one of the supported values, or a
            coordinate cannot be converted to ``float``.
    """
    if method not in ("cosine", "fine_tuned_model"):
        # Without this check an unknown method would surface later as a
        # NameError on the unbound result variable.
        raise ValueError(
            f"Unknown method {method!r}; expected 'cosine' or 'fine_tuned_model'."
        )

    # Parsed so malformed coordinates are still rejected up front; neither
    # method consumes the values yet.
    coords = (float(latitude), float(longitude))

    if method == "cosine":
        embedding = get_embedding(dna_sequence)
        result = pc_index.query(
            namespace="all",
            vector=embedding.tolist(),
            top_k=10,
            include_metadata=True,
        )
        genera = [match["metadata"]["genus"] for match in result["matches"]]
        # Fraction of the returned neighbours that belong to each genus.
        top_k = pd.Series(genera).value_counts(normalize=True)
    else:
        bert_inputs = tokenize(dna_sequence)
        logits = classification_model(bert_inputs, torch.zeros(1, 7))
        temperature = 0.2
        # Drop only the batch dimension so a single-class head still yields a
        # 1-D tensor.
        probs = torch.softmax(logits / temperature, dim=1).squeeze(0)
        # Never request more entries than there are classes.
        best = torch.topk(probs, k=min(10, probs.numel()))
        top_k = pd.Series(
            best.values.detach().numpy(),
            index=[ID_TO_GENUS_MAP[i] for i in best.indices.tolist()],
        )

    fig, ax = plt.subplots()
    try:
        ax.bar(top_k.index.astype(str), top_k.values)
        ax.set_ylim(0, 1)
        ax.set_title("Genus Prediction")
        ax.set_xlabel("Genus")
        ax.set_ylabel("Probability")
        # Rotate the existing tick labels rather than re-setting them, which
        # warns when the ticks are not fixed.
        ax.tick_params(axis="x", labelrotation=90)
        fig.subplots_adjust(bottom=0.3)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); the alpha
        # channel is dropped to keep returning an RGB image.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return PIL.Image.fromarray(rgba).convert("RGB")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
# Gradio UI definition. Components are laid out in the order they are created.
# NOTE(review): the nesting below is reconstructed from syntax alone; in
# particular confirm whether both buttons belong to the same row.
with gr.Blocks() as demo:
    # Header section
    gr.Markdown("# DNA Identifier Tool")
    gr.Markdown((
        "Welcome to Lofi Amazon Beats' DNA Identifier Tool. "
        "Please enter a DNA sequence and the coordinates at which its sample "
        "was taken to get started. Click 'I'm feeling lucky' to see use a "
        "random sequence."
    ))

    # Input section: DNA sequence on one side, coordinates on the other.
    with gr.Row():
        with gr.Column():
            inp_dna = gr.Textbox(label="DNA", placeholder="e.g. AACAATGTA... (min 200 and max 660 characters)")
        with gr.Column():
            with gr.Row():
                inp_lat = gr.Textbox(label="Latitude", placeholder="e.g. -3.009083")
            with gr.Row():
                inp_lng = gr.Textbox(label="Longitude", placeholder="e.g. -58.68281")

    with gr.Row():
        btn_run = gr.Button("Predict")
        # NOTE(review): no outputs are given, so the three values returned by
        # preprocess are not shown anywhere -- confirm this is intended.
        btn_run.click(
            fn=preprocess,
            inputs=[inp_dna, inp_lat, inp_lng],
        )
        # Fills the three input boxes with the configured default values.
        btn_defaults = gr.Button("I'm feeling lucky")
        btn_defaults.click(fn=set_default_inputs, outputs=[inp_dna, inp_lat, inp_lng])

    with gr.Tab("Genus Prediction"):
        # The interface takes the shared input boxes plus a method dropdown
        # and displays the image returned by predict_genus.
        gr.Interface(
            fn=predict_genus,
            inputs=[
                gr.Dropdown(choices=["cosine", "fine_tuned_model"], value="fine_tuned_model"),
                inp_dna,
                inp_lat,
                inp_lng,
            ],
            outputs=["image"],
        )
        # Disabled draft of a tabular output for this tab:
        # with gr.Row():
        #     gr.Markdown("Make plot or table for Top 5 species")
        # with gr.Row():
        #     genus_out = gr.Dataframe(headers=["DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
        #     # btn_run.click(fn=predict_genus, inputs=[inp_dna, inp_lat, inp_lng], outputs=genus_out)

    with gr.Tab('DNA Embedding Space Visualizer'):
        gr.Markdown("If the highest genus probability is very low for your DNA sequence, we can still examine the DNA embedding of the sequence in relation to known samples for clues.")
        # Both columns currently hold only descriptive text; the plot wiring
        # below is commented out.
        with gr.Row() as row:
            with gr.Column():
                gr.Markdown("Plot of your DNA sequence among other known species clusters.")
                # plot = gr.Plot("")
                # btn_run.click(fn=tsne_DNA, inputs=[inp_dna, genus_out])
            with gr.Column():
                gr.Markdown("Plot of the five most common species at your sample coordinate.")

# Start the Gradio server for the app defined above.
demo.launch()