LofiAmazonSpace / app.old.py
jennzhuge
hi
9500bfe
import os
import pandas as pd
# import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
import infer
# def predict_genus_dna(dnaSeqs):
# genuses = []
# # probs = dnamodel.predict_proba(dnaSeqs)
# # preds = dnamodel.predict(dnaSeqs)
# # topProb = np.argsort(probs, axis=1)[:,-3:]
# # topClass = dnamodel.classes_[topProb]
# # pred_df = pd.DataFrame(data=[topClass, topProb], columns= ['Genus', 'Probability'])
# return genuses
# def predict_genus_dna_env(dnaSeqsEnv):
# genuses = {}
# probs = model.predict_proba(dnaSeqsEnv)
# preds = model.predict(dnaSeqsEnv)
# for i in range(len(dnaSeqsEnv)):
# topProb = np.argsort(probs[i], axis=1)[:,-3:]
# topClass = model.classes_[topProb]
# sampleStr = dnaSeqsEnv['nucraw'][i]
# genuses[sampleStr] = (topClass, topProb)
# pred_df = pd.DataFrame(data=[top5class, top5prob], columns= ['Genus', 'Probability'])
# return genuses
# def get_genus_image(genus):
# # return a URL to genus image
# return f"https://example.com/images/{genus}.jpg"
def _csv_path(upload):
    """Return something ``pd.read_csv`` can open for an uploaded file.

    Accepts a plain path (``str`` / ``os.PathLike``) as well as an object that
    exposes its on-disk path as ``.name`` (the form this app originally
    assumed, e.g. a temp-file wrapper).
    """
    if isinstance(upload, (str, os.PathLike)):
        return upload
    return upload.name


def get_genuses(dna_file, dnaenv_file=None):
    """Run genus inference and package it with the uploaded DNA sequences.

    Args:
        dna_file: the uploaded DNA CSV, either as a path or as an object with
            a ``.name`` path attribute. It must contain a ``nucraw`` column.
        dnaenv_file: optional DNA + environment CSV in the same form.
            Defaults to ``None`` so single-upload callers (the Gradio UI)
            work. When given, it is read but not yet used for prediction.

    Returns:
        A one-element list holding a dict with ``"sequence"`` (the ``nucraw``
        column of the DNA CSV) and ``"predictions"`` (the value returned by
        ``infer.infer()``; its exact shape is defined in the ``infer`` module).
    """
    dna_df = pd.read_csv(_csv_path(dna_file))
    if dnaenv_file is not None:
        # The DNA + environment model is not wired up yet (see the
        # commented-out predict_genus_dna_env above). The file is only read,
        # so a missing or unparsable upload still fails early.
        pd.read_csv(_csv_path(dnaenv_file))
    # NOTE(review): infer.infer() is called without the uploaded data;
    # confirm it obtains its input elsewhere.
    genuses = infer.infer()
    return [{
        "sequence": dna_df['nucraw'],
        "predictions": genuses,
    }]
def display_results(results):
    """Flatten prediction results into the table shown in the UI.

    Each entry of ``results`` contributes one row: its ``"sequence"`` value
    and the first element of its ``"predictions"``.
    """
    # TODO: add probability and DNA + environment columns once the model
    # returns them.
    rows = [
        {
            "DNA Sequence": entry["sequence"],
            "DNA Pred Genus": entry['predictions'][0],
        }
        for entry in results
    ]
    return pd.DataFrame(rows)
def gradio_interface(file):
    """Turn an uploaded DNA CSV into the results table shown in the UI.

    Args:
        file: the value delivered by the ``gr.File`` component, or ``None``
            when the upload has been cleared.

    Returns:
        A ``pandas.DataFrame`` built by ``display_results``. When no file is
        present, the result is an empty frame with the same column headers.
    """
    if file is None:
        # gr.File's change event also fires when the upload is cleared,
        # passing None; show an empty table instead of crashing downstream.
        return pd.DataFrame(columns=["DNA Sequence", "DNA Pred Genus"])
    results = get_genuses(file)
    return display_results(results)
# Gradio interface
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# DNA Identifier Tool")
        # Gradio's file_types takes extensions with a leading dot (or category
        # names such as "image"); ".csv" is the documented form.
        file_input = gr.File(label="Upload DNA CSV file", file_types=[".csv"])
        # Headers mirror the columns display_results() currently produces.
        output_table = gr.Dataframe(headers=["DNA Sequence", "DNA Pred Genus"])
    # Recompute the table whenever the uploaded file changes.
    file_input.change(gradio_interface, inputs=file_input, outputs=output_table)

# Start the server only when run as a script, so importing this module
# (e.g. by tooling that looks up `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()
# with gr.Blocks() as demo:
# with gr.Row():
# word = gr.Textbox(label="word")
# leng = gr.Number(label="leng")
# output = gr.Textbox(label="Output")
# with gr.Row():
# run = gr.Button()
# event = run.click(predict_genus,
# [word, leng],
# output,
# batch=True,
# max_batch_size=20)
# demo.launch()
# DB_USER = os.getenv("DB_USER")
# DB_PASSWORD = os.getenv("DB_PASSWORD")
# DB_HOST = os.getenv("DB_HOST")
# PORT = 8080
# DB_NAME = "bikeshare"
# connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}"
# def get_count_ride_type():
# df = pd.read_sql(
# """
# SELECT COUNT(ride_id) as n, rideable_type
# FROM rides
# GROUP BY rideable_type
# ORDER BY n DESC
# """,
# con=connection_string
# )
# fig_m, ax = plt.subplots()
# ax.bar(x=df['rideable_type'], height=df['n'])
# ax.set_title("Number of rides by bycycle type")
# ax.set_ylabel("Number of Rides")
# ax.set_xlabel("Bicycle Type")
# return fig_m
# def get_most_popular_stations():
# df = pd.read_sql(
# """
# SELECT COUNT(ride_id) as n, MAX(start_station_name) as station
# FROM RIDES
# WHERE start_station_name is NOT NULL
# GROUP BY start_station_id
# ORDER BY n DESC
# LIMIT 5
# """,
# con=connection_string
# )
# fig_m, ax = plt.subplots()
# ax.bar(x=df['station'], height=df['n'])
# ax.set_title("Most popular stations")
# ax.set_ylabel("Number of Rides")
# ax.set_xlabel("Station Name")
# ax.set_xticklabels(
# df['station'], rotation=45, ha="right", rotation_mode="anchor"
# )
# ax.tick_params(axis="x", labelsize=8)
# fig_m.tight_layout()
# return fig_m
# with gr.Blocks() as demo:
# with gr.Row():
# bike_type = gr.Plot()
# station = gr.Plot()
# demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
# demo.load(get_most_popular_stations, inputs=None, outputs=station)
# def greet(name, intensity):
# return "Hello, " + name + "!" * int(intensity)
# demo = gr.Interface(
# fn=greet,
# inputs=["text", "slider"],
# outputs=["text"],
# )
# demo.launch() is already called once, right after the Blocks definition.
# launch() blocks until the server is interrupted, so a second call here
# would restart the app after Ctrl+C; it has been removed.