# LofiAmazonSpace / app.py
# jennzhuge
# pseudocode for pap
# 3f8dd98
# raw
# history blame
# 5.05 kB
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as mp  # NOTE(review): alias looks like a typo for `np`; kept so nothing that uses `mp` breaks
import numpy as np
import pandas as pd
def predict_genus_dna(dnaSeqs, n=5, classifier=None):
    """Return the ``n`` most probable genera for each DNA-only sample.

    Args:
        dnaSeqs: Feature rows passed unchanged to the classifier's
            ``predict_proba`` (one row per sample).
        n: Number of top-ranked classes to keep per sample. It is capped at
            the number of classes the classifier reports. Defaults to 5.
        classifier: Object exposing ``predict_proba`` and ``classes_``.
            When omitted, the global name ``dnamodel`` is used; that name
            must be provided elsewhere in the module.

    Returns:
        A list with one ``(genera, probabilities)`` tuple per sample. Both
        elements are arrays ordered from most to least probable. Empty input
        yields an empty list.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    # len() rather than truthiness: a DataFrame has no unambiguous truth value.
    if len(dnaSeqs) == 0:
        return []

    clf = dnamodel if classifier is None else classifier
    probs = np.asarray(clf.predict_proba(dnaSeqs))
    k = min(n, probs.shape[1])

    # argsort sorts ascending, so reverse every row before keeping k columns.
    top_idx = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    top_classes = np.asarray(clf.classes_)[top_idx]
    # Look the probabilities up explicitly; the sorted indices are not
    # probabilities themselves.
    top_probs = np.take_along_axis(probs, top_idx, axis=1)

    return list(zip(top_classes, top_probs))
def predict_genus_dna_env(dnaSeqsEnv, n=5, classifier=None):
    """Return the ``n`` most probable genera per sample for DNA + environment data.

    Args:
        dnaSeqsEnv: Table with a ``'nucraw'`` column holding the raw sequence
            of each sample. The whole table is passed unchanged to the
            classifier's ``predict_proba``.
        n: Number of top-ranked classes to keep per sample. It is capped at
            the number of classes the classifier reports. Defaults to 5.
        classifier: Object exposing ``predict_proba`` and ``classes_``.
            When omitted, the global name ``model`` is used; that name must
            be provided elsewhere in the module.

    Returns:
        A dict mapping each ``'nucraw'`` value to a
        ``(genera, probabilities)`` tuple, both arrays ordered from most to
        least probable. If a sequence occurs more than once, the last row
        wins. Empty input yields an empty dict.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")
    # len() rather than truthiness: a DataFrame has no unambiguous truth value.
    if len(dnaSeqsEnv) == 0:
        return {}

    clf = model if classifier is None else classifier
    probs = np.asarray(clf.predict_proba(dnaSeqsEnv))
    k = min(n, probs.shape[1])

    # Rank all rows at once; argsort sorts ascending, so reverse each row.
    top_idx = np.argsort(probs, axis=1)[:, ::-1][:, :k]
    top_classes = np.asarray(clf.classes_)[top_idx]
    # Look the probabilities up explicitly; the sorted indices are not
    # probabilities themselves.
    top_probs = np.take_along_axis(probs, top_idx, axis=1)

    # Iterating the column walks rows positionally, so a non-default index
    # on the table cannot misalign sequences and predictions.
    sequences = dnaSeqsEnv['nucraw']
    return {
        sequence: (classes, row_probs)
        for sequence, classes, row_probs in zip(sequences, top_classes, top_probs)
    }
# def get_genus_image(genus):
# # return a URL to genus image
# return f"https://example.com/images/{genus}.jpg"
def get_genuses(dna_file, dnaenv_file):
    """Run both genus predictors on the uploaded CSVs and group results by sequence.

    Args:
        dna_file: Uploaded file object; its ``.name`` is read as the path of
            the CSV given to ``predict_genus_dna``.
        dnaenv_file: Uploaded file object; its ``.name`` is read as the path
            of the CSV given to ``predict_genus_dna_env``.

    Returns:
        A list of dicts, one per distinct sequence, each with the keys
        ``"sequence"`` and ``"predictions"``. ``"predictions"`` lists the
        DNA + environment prediction first, followed by the DNA-only one.
    """
    dna_df = pd.read_csv(dna_file.name)
    dnaenv_df = pd.read_csv(dnaenv_file.name)

    envdna_genuses = predict_genus_dna_env(dnaenv_df)
    dna_genuses = predict_genus_dna(dna_df)

    # The two predictors return different containers (dict vs. list), so they
    # cannot simply be concatenated; collect everything per sequence instead.
    by_sequence = {}
    for sequence, prediction in envdna_genuses.items():
        by_sequence.setdefault(sequence, []).append(prediction)

    # NOTE(review): assumes the DNA-only CSV also has a 'nucraw' column and
    # that predict_genus_dna returns one entry per row, in row order - confirm.
    for sequence, prediction in zip(dna_df["nucraw"], dna_genuses):
        by_sequence.setdefault(sequence, []).append(prediction)

    # TODO: attach genus images once an image lookup helper exists.
    # Lower-case "predictions" is the key display_results reads.
    return [
        {"sequence": sequence, "predictions": predictions}
        for sequence, predictions in by_sequence.items()
    ]
def display_results(results):
    """Flatten prediction results into a table with one row per prediction.

    Args:
        results: Iterable of dicts with a ``"sequence"`` entry and a
            ``"predictions"`` sequence; the first element of each prediction
            is shown as the predicted genus.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"DNA Sequence"`` and
        ``"Predicted Genus"``. The columns are present even when ``results``
        is empty, so the table keeps stable headers.
    """
    columns = ["DNA Sequence", "Predicted Genus"]
    rows = []
    for result in results:
        sequence = result["sequence"]
        # A dict literal keeps only the last value of a repeated key, so each
        # row carries exactly one "Predicted Genus" entry.
        rows.extend(
            {"DNA Sequence": sequence, "Predicted Genus": prediction[0]}
            for prediction in result["predictions"]
        )
    return pd.DataFrame(rows, columns=columns)
def gradio_interface(file):
    """Turn an uploaded CSV into the prediction table shown in the UI.

    Args:
        file: The uploaded file object handed over by the Gradio file
            widget, or ``None`` when no file is selected.

    Returns:
        The ``pandas.DataFrame`` produced by ``display_results``.
    """
    if file is None:
        # No upload to process (e.g. the widget was cleared): show an empty
        # table instead of failing on a missing file.
        return display_results([])

    # get_genuses takes two inputs but the UI offers a single upload.
    # NOTE(review): assumes the same CSV feeds both the DNA-only and the
    # DNA + environment predictor - confirm.
    results = get_genuses(file, file)
    return display_results(results)
# Gradio interface: upload one CSV and show the predicted genera as a table.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Top 5 Most Likely Genus Predictions")
        # File extensions need a leading dot; a bare "csv" is not treated as
        # an extension filter by the upload widget.
        file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
        # NOTE(review): these headers differ from the columns of the frame
        # that display_results returns - confirm which set is intended.
        output_table = gr.Dataframe(headers=["DNA", "Coord", "DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])

    def update_output(file):
        """Upload callback: return the prediction table for ``file``."""
        return gradio_interface(file)

    # Recompute the table whenever the uploaded file changes.
    file_input.change(update_output, inputs=file_input, outputs=output_table)

demo.launch()
# with gr.Blocks() as demo:
# with gr.Row():
# word = gr.Textbox(label="word")
# leng = gr.Number(label="leng")
# output = gr.Textbox(label="Output")
# with gr.Row():
# run = gr.Button()
# event = run.click(predict_genus,
# [word, leng],
# output,
# batch=True,
# max_batch_size=20)
# demo.launch()
# DB_USER = os.getenv("DB_USER")
# DB_PASSWORD = os.getenv("DB_PASSWORD")
# DB_HOST = os.getenv("DB_HOST")
# PORT = 8080
# DB_NAME = "bikeshare"
# connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}"
# def get_count_ride_type():
# df = pd.read_sql(
# """
# SELECT COUNT(ride_id) as n, rideable_type
# FROM rides
# GROUP BY rideable_type
# ORDER BY n DESC
# """,
# con=connection_string
# )
# fig_m, ax = plt.subplots()
# ax.bar(x=df['rideable_type'], height=df['n'])
# ax.set_title("Number of rides by bicycle type")
# ax.set_ylabel("Number of Rides")
# ax.set_xlabel("Bicycle Type")
# return fig_m
# def get_most_popular_stations():
# df = pd.read_sql(
# """
# SELECT COUNT(ride_id) as n, MAX(start_station_name) as station
# FROM RIDES
# WHERE start_station_name is NOT NULL
# GROUP BY start_station_id
# ORDER BY n DESC
# LIMIT 5
# """,
# con=connection_string
# )
# fig_m, ax = plt.subplots()
# ax.bar(x=df['station'], height=df['n'])
# ax.set_title("Most popular stations")
# ax.set_ylabel("Number of Rides")
# ax.set_xlabel("Station Name")
# ax.set_xticklabels(
# df['station'], rotation=45, ha="right", rotation_mode="anchor"
# )
# ax.tick_params(axis="x", labelsize=8)
# fig_m.tight_layout()
# return fig_m
# with gr.Blocks() as demo:
# with gr.Row():
# bike_type = gr.Plot()
# station = gr.Plot()
# demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
# demo.load(get_most_popular_stations, inputs=None, outputs=station)
# def greet(name, intensity):
# return "Hello, " + name + "!" * int(intensity)
# demo = gr.Interface(
# fn=greet,
# inputs=["text", "slider"],
# outputs=["text"],
# )
# The app is already launched right after the Blocks layout is built above.
# A second launch() here would only run once that server stops and would
# then start it again, so the duplicate call is removed.