# Hugging Face Space status when this file was captured: Runtime error
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as mp  # NOTE(review): misspelled alias; no visible code uses `mp`
import numpy as np
import pandas as pd
def predict_genus_dna(dnaSeqs, n=5, estimator=None):
    """Return the ``n`` most likely genera for each row of ``dnaSeqs``.

    Args:
        dnaSeqs: feature rows accepted by the classifier's ``predict_proba``
            (presumably the DataFrame read from the DNA-only CSV — confirm).
        n: number of candidate genera to keep per row (default 5).
        estimator: fitted classifier exposing ``predict_proba`` and
            ``classes_``. Defaults to the module-level ``dnamodel``.
            NOTE(review): ``dnamodel`` is not defined anywhere in this file;
            it must be loaded before calling this without ``estimator``.

    Returns:
        list with one ``(genera, probabilities)`` tuple per input row, in
        input order. Both are numpy arrays of length ``min(n, n_classes)``,
        sorted from most to least likely.
    """
    clf = dnamodel if estimator is None else estimator
    probs = np.asarray(clf.predict_proba(dnaSeqs))
    # argsort is ascending: reverse the columns so the most likely class comes
    # first, then keep the first n (this slicing also behaves for n == 0,
    # unlike [:, -n:]).
    top_idx = np.argsort(probs, axis=1)[:, ::-1][:, :n]
    top_genera = np.asarray(clf.classes_)[top_idx]
    # Look up the actual probabilities; top_idx only holds column indices.
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    return list(zip(top_genera, top_probs))
def predict_genus_dna_env(dnaSeqsEnv, n=5, estimator=None):
    """Return the ``n`` most likely genera per sequence, keyed by the sequence.

    Args:
        dnaSeqsEnv: DataFrame with a ``nucraw`` column. The whole frame is
            passed to ``predict_proba`` unchanged (assumes the classifier
            accepts it as-is — confirm).
        n: number of candidate genera to keep per row (default 5).
        estimator: fitted classifier exposing ``predict_proba`` and
            ``classes_``. Defaults to the module-level ``model``.
            NOTE(review): ``model`` is not defined anywhere in this file;
            it must be loaded before calling this without ``estimator``.

    Returns:
        dict mapping each ``nucraw`` value to a ``(genera, probabilities)``
        tuple of numpy arrays sorted from most to least likely. Rows sharing
        the same ``nucraw`` value collapse to the last one.
    """
    clf = model if estimator is None else estimator
    probs = np.asarray(clf.predict_proba(dnaSeqsEnv))
    # Rank every row at once: descending column order, first n columns.
    top_idx = np.argsort(probs, axis=1)[:, ::-1][:, :n]
    top_genera = np.asarray(clf.classes_)[top_idx]
    top_probs = np.take_along_axis(probs, top_idx, axis=1)
    # Iterate the column's values (not index labels) so a non-default index
    # still lines up with the positional rows of ``probs``.
    return {
        seq: (genera, p)
        for seq, genera, p in zip(dnaSeqsEnv["nucraw"], top_genera, top_probs)
    }
# def get_genus_image(genus): | |
# # return a URL to genus image | |
# return f"https://example.com/images/{genus}.jpg" | |
def get_genuses(dna_file, dnaenv_file=None):
    """Run both genus classifiers on uploaded CSV files and pair their outputs.

    Args:
        dna_file: uploaded file object whose ``.name`` is the path of the
            DNA-only CSV.
        dnaenv_file: uploaded file object for the DNA + environment CSV.
            When omitted, ``dna_file`` is used for both models, so a
            single-upload caller can use ``get_genuses(file)``.

    Returns:
        list of dicts, one per row of the DNA + environment CSV, with keys
        ``"sequence"`` (that row's ``nucraw`` value) and ``"predictions"``
        (``[dna_only_prediction, dna_env_prediction]``).
    """
    if dnaenv_file is None:
        dnaenv_file = dna_file
    dna_df = pd.read_csv(dna_file.name)
    dnaenv_df = pd.read_csv(dnaenv_file.name)

    envdna_genuses = predict_genus_dna_env(dnaenv_df)
    dna_genuses = predict_genus_dna(dna_df)

    # NOTE(review): pairs the two models' outputs by row position, i.e. it
    # assumes both CSVs list the same sequences in the same order — confirm.
    # The key is lower-case "predictions" because display_results reads
    # result["predictions"].
    return [
        {"sequence": sequence, "predictions": [dna_prediction, envdna_genuses[sequence]]}
        for sequence, dna_prediction in zip(dnaenv_df["nucraw"], dna_genuses)
    ]
def display_results(results):
    """Flatten prediction results into a table for the Gradio Dataframe.

    Args:
        results: iterable of dicts, each with a ``"sequence"`` key and a
            ``"predictions"`` list whose items are ``(genera, probabilities)``
            pairs.

    Returns:
        DataFrame with one row per prediction and columns ``"DNA Sequence"``,
        ``"Predicted Genus"`` and ``"Probability"``. The columns are present
        even when ``results`` is empty.
    """
    # The original dict literal repeated the "Predicted Genus" key three
    # times, so only one column survived and the probability was dropped.
    columns = ["DNA Sequence", "Predicted Genus", "Probability"]
    rows = [
        {
            "DNA Sequence": result["sequence"],
            "Predicted Genus": prediction[0],
            "Probability": prediction[1],
        }
        for result in results
        for prediction in result["predictions"]
    ]
    return pd.DataFrame(rows, columns=columns)
def gradio_interface(file):
    """Build the results table for one uploaded CSV.

    Args:
        file: uploaded file object from ``gr.File``, or ``None`` when the
            upload has been cleared (the ``change`` event fires then too).

    Returns:
        DataFrame produced by ``display_results``, or an empty DataFrame
        when no file is present.
    """
    if file is None:
        return pd.DataFrame()
    # The UI offers a single upload, so the same CSV feeds both models.
    # get_genuses requires both file arguments.
    results = get_genuses(file, file)
    return display_results(results)
# Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Top 5 Most Likely Genus Predictions")
        # Gradio's file_types takes extensions with a leading dot (".csv") or
        # category names such as "image"; a bare "csv" is neither.
        file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
        # NOTE(review): these placeholder headers (including "Coord") do not
        # match the columns the callback returns; Gradio is expected to take
        # the headers from the returned DataFrame — confirm.
        output_table = gr.Dataframe(headers=["DNA", "Coord", "DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])

    def update_output(file):
        """Gradio callback: rebuild the results table when the upload changes."""
        return gradio_interface(file)

    file_input.change(update_output, inputs=file_input, outputs=output_table)

# Launching is left to the single `demo.launch()` at the end of the file;
# calling it here as well started the app twice.
# with gr.Blocks() as demo: | |
# with gr.Row(): | |
# word = gr.Textbox(label="word") | |
# leng = gr.Number(label="leng") | |
# output = gr.Textbox(label="Output") | |
# with gr.Row(): | |
# run = gr.Button() | |
# event = run.click(predict_genus, | |
# [word, leng], | |
# output, | |
# batch=True, | |
# max_batch_size=20) | |
# demo.launch() | |
# DB_USER = os.getenv("DB_USER") | |
# DB_PASSWORD = os.getenv("DB_PASSWORD") | |
# DB_HOST = os.getenv("DB_HOST") | |
# PORT = 8080 | |
# DB_NAME = "bikeshare" | |
# connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}" | |
# def get_count_ride_type(): | |
# df = pd.read_sql( | |
# """ | |
# SELECT COUNT(ride_id) as n, rideable_type | |
# FROM rides | |
# GROUP BY rideable_type | |
# ORDER BY n DESC | |
# """, | |
# con=connection_string | |
# ) | |
# fig_m, ax = plt.subplots() | |
# ax.bar(x=df['rideable_type'], height=df['n']) | |
# ax.set_title("Number of rides by bycycle type") | |
# ax.set_ylabel("Number of Rides") | |
# ax.set_xlabel("Bicycle Type") | |
# return fig_m | |
# def get_most_popular_stations(): | |
# df = pd.read_sql( | |
# """ | |
# SELECT COUNT(ride_id) as n, MAX(start_station_name) as station | |
# FROM RIDES | |
# WHERE start_station_name is NOT NULL | |
# GROUP BY start_station_id | |
# ORDER BY n DESC | |
# LIMIT 5 | |
# """, | |
# con=connection_string | |
# ) | |
# fig_m, ax = plt.subplots() | |
# ax.bar(x=df['station'], height=df['n']) | |
# ax.set_title("Most popular stations") | |
# ax.set_ylabel("Number of Rides") | |
# ax.set_xlabel("Station Name") | |
# ax.set_xticklabels( | |
# df['station'], rotation=45, ha="right", rotation_mode="anchor" | |
# ) | |
# ax.tick_params(axis="x", labelsize=8) | |
# fig_m.tight_layout() | |
# return fig_m | |
# with gr.Blocks() as demo: | |
# with gr.Row(): | |
# bike_type = gr.Plot() | |
# station = gr.Plot() | |
# demo.load(get_count_ride_type, inputs=None, outputs=bike_type) | |
# demo.load(get_most_popular_stations, inputs=None, outputs=station) | |
# def greet(name, intensity): | |
# return "Hello, " + name + "!" * int(intensity) | |
# demo = gr.Interface( | |
# fn=greet, | |
# inputs=["text", "slider"], | |
# outputs=["text"], | |
# ) | |
if __name__ == "__main__":
    # Start the server only when run as a script (as Spaces does with
    # `python app.py`), not as a side effect of importing this module.
    demo.launch()