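# NOTE: the next three lines are the Hugging Face Spaces status banner that was
# captured together with this file; they are not part of the program.
# Spaces:
# Runtime error
# Runtime error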
import os | |
import pandas as pd | |
# import matplotlib.pyplot as plt | |
import gradio as gr | |
import numpy as np | |
import infer | |
# def predict_genus_dna(dnaSeqs): | |
# genuses = [] | |
# # probs = dnamodel.predict_proba(dnaSeqs) | |
# # preds = dnamodel.predict(dnaSeqs) | |
# # topProb = np.argsort(probs, axis=1)[:,-3:] | |
# # topClass = dnamodel.classes_[topProb] | |
# # pred_df = pd.DataFrame(data=[topClass, topProb], columns= ['Genus', 'Probability']) | |
# return genuses | |
# def predict_genus_dna_env(dnaSeqsEnv): | |
# genuses = {} | |
# probs = model.predict_proba(dnaSeqsEnv) | |
# preds = model.predict(dnaSeqsEnv) | |
# for i in range(len(dnaSeqsEnv)): | |
# topProb = np.argsort(probs[i], axis=1)[:,-3:] | |
# topClass = model.classes_[topProb] | |
# sampleStr = dnaSeqsEnv['nucraw'][i] | |
# genuses[sampleStr] = (topClass, topProb) | |
# pred_df = pd.DataFrame(data=[top5class, top5prob], columns= ['Genus', 'Probability']) | |
# return genuses | |
# def get_genus_image(genus): | |
# # return a URL to genus image | |
# return f"https://example.com/images/{genus}.jpg" | |
def get_genuses(dna_file, dnaenv_file=None):
    """Load the uploaded CSV(s) and pair the DNA sequences with predictions.

    Args:
        dna_file: Uploaded DNA CSV. Either an object exposing the file path
            as ``.name`` (what the original code required) or anything
            ``pandas.read_csv`` accepts directly, such as a path string.
        dnaenv_file: Optional DNA + environment CSV, same accepted forms.
            Defaults to ``None`` because the UI in this file only supplies a
            single upload; when the parameter was required, the one-argument
            call in ``gradio_interface`` raised ``TypeError``.

    Returns:
        A one-element list holding a dict with ``"sequence"`` (the ``nucraw``
        column of the DNA CSV) and ``"predictions"`` (whatever
        ``infer.infer()`` returns).

    Raises:
        KeyError: If the DNA CSV has no ``nucraw`` column.
    """
    # Fall back to the object itself when it has no ``.name`` so plain paths
    # work as well as upload wrappers.
    dna_df = pd.read_csv(getattr(dna_file, "name", dna_file))

    if dnaenv_file is not None:
        # Parsed only so that a malformed file still fails here, as it did
        # before; infer.infer() takes no arguments, so the frame itself is
        # not consumed yet.
        pd.read_csv(getattr(dnaenv_file, "name", dnaenv_file))

    # NOTE(review): infer.infer() is called without the uploaded data, so the
    # predictions presumably come from inputs configured inside `infer` --
    # confirm this is intended.
    genuses = infer.infer()

    return [
        {
            "sequence": dna_df['nucraw'],
            'predictions': genuses,
        }
    ]
def display_results(results):
    """Turn prediction records into a table for display.

    Args:
        results: Iterable of dicts, each with a ``"sequence"`` entry and an
            indexable ``"predictions"`` entry.

    Returns:
        A ``pandas.DataFrame`` with one row per record and the columns
        ``"DNA Sequence"`` and ``"DNA Pred Genus"``.
    """
    # Only the first prediction of each record is surfaced.
    rows = [
        {
            "DNA Sequence": record["sequence"],
            "DNA Pred Genus": record['predictions'][0],
        }
        for record in results
    ]
    return pd.DataFrame(rows)
def gradio_interface(file):
    """Run the prediction pipeline for one uploaded file.

    Args:
        file: The value delivered by the upload component, or ``None`` when
            the upload has been cleared.

    Returns:
        The ``pandas.DataFrame`` built by ``display_results``; an empty frame
        with the same two columns when ``file`` is ``None``.
    """
    if file is None:
        # The change event also fires when the user removes the file; show an
        # empty table instead of failing on a missing upload.
        return pd.DataFrame(columns=["DNA Sequence", "DNA Pred Genus"])

    # Only the DNA file is passed on: this UI collects a single upload and no
    # separate environmental CSV.
    results = get_genuses(file)
    return display_results(results)
# Gradio interface: one CSV upload feeding a results table.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DNA Identifier Tool")
        # Extension filters need the leading dot (".csv"); a bare "csv" is not
        # one of the named file categories and does not restrict the picker
        # to CSV files.
        file_input = gr.File(label="Upload DNA CSV file", file_types=['.csv'])
        # NOTE(review): these headers do not match the two columns produced by
        # display_results ("DNA Sequence", "DNA Pred Genus") -- confirm which
        # layout is intended.
        output_table = gr.Dataframe(
            headers=[
                "DNA",
                "Coord",
                "DNA Only Pred Genus",
                "DNA Only Prob",
                "DNA & Env Pred Genus",
                "DNA & Env Prob",
            ]
        )

        def update_output(file):
            """Recompute the results table whenever the upload changes."""
            return gradio_interface(file)

        file_input.change(update_output, inputs=file_input, outputs=output_table)

demo.launch()
# with gr.Blocks() as demo: | |
# with gr.Row(): | |
# word = gr.Textbox(label="word") | |
# leng = gr.Number(label="leng") | |
# output = gr.Textbox(label="Output") | |
# with gr.Row(): | |
# run = gr.Button() | |
# event = run.click(predict_genus, | |
# [word, leng], | |
# output, | |
# batch=True, | |
# max_batch_size=20) | |
# demo.launch() | |
# DB_USER = os.getenv("DB_USER") | |
# DB_PASSWORD = os.getenv("DB_PASSWORD") | |
# DB_HOST = os.getenv("DB_HOST") | |
# PORT = 8080 | |
# DB_NAME = "bikeshare" | |
# connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}" | |
# def get_count_ride_type(): | |
# df = pd.read_sql( | |
# """ | |
# SELECT COUNT(ride_id) as n, rideable_type | |
# FROM rides | |
# GROUP BY rideable_type | |
# ORDER BY n DESC | |
# """, | |
# con=connection_string | |
# ) | |
# fig_m, ax = plt.subplots() | |
# ax.bar(x=df['rideable_type'], height=df['n']) | |
# ax.set_title("Number of rides by bicycle type")
# ax.set_ylabel("Number of Rides") | |
# ax.set_xlabel("Bicycle Type") | |
# return fig_m | |
# def get_most_popular_stations(): | |
# df = pd.read_sql( | |
# """ | |
# SELECT COUNT(ride_id) as n, MAX(start_station_name) as station | |
# FROM RIDES | |
# WHERE start_station_name is NOT NULL | |
# GROUP BY start_station_id | |
# ORDER BY n DESC | |
# LIMIT 5 | |
# """, | |
# con=connection_string | |
# ) | |
# fig_m, ax = plt.subplots() | |
# ax.bar(x=df['station'], height=df['n']) | |
# ax.set_title("Most popular stations") | |
# ax.set_ylabel("Number of Rides") | |
# ax.set_xlabel("Station Name") | |
# ax.set_xticklabels( | |
# df['station'], rotation=45, ha="right", rotation_mode="anchor" | |
# ) | |
# ax.tick_params(axis="x", labelsize=8) | |
# fig_m.tight_layout() | |
# return fig_m | |
# with gr.Blocks() as demo: | |
# with gr.Row(): | |
# bike_type = gr.Plot() | |
# station = gr.Plot() | |
# demo.load(get_count_ride_type, inputs=None, outputs=bike_type) | |
# demo.load(get_most_popular_stations, inputs=None, outputs=station) | |
# def greet(name, intensity): | |
# return "Hello, " + name + "!" * int(intensity) | |
# demo = gr.Interface( | |
# fn=greet, | |
# inputs=["text", "slider"], | |
# outputs=["text"], | |
# ) | |
# NOTE(review): demo.launch() is also called immediately after the Blocks
# definition earlier in this file; this second call looks redundant --
# confirm and keep only one.
demo.launch()