jennzhuge committed on
Commit
3f8dd98
·
1 Parent(s): dccc973

pseudocode for pap

Browse files
Files changed (1) hide show
  1. app.py +171 -7
app.py CHANGED
@@ -1,12 +1,176 @@
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name, intensity):
4
- return "Hello, " + name + "!" * int(intensity)
5
 
6
- demo = gr.Interface(
7
- fn=greet,
8
- inputs=["text", "slider"],
9
- outputs=["text"],
10
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  demo.launch()
 
1
+ import os
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
  import gradio as gr
5
+ import numpy as mp
6
 
 
 
7
 
8
def predict_genus_dna(dnaSeqs, n=5, classifier=None):
    """Rank the most probable genera for each DNA-only sample.

    Args:
        dnaSeqs: Feature table handed unchanged to ``predict_proba``
            (one row per sample).
        n: Number of top-ranked genera to keep per sample. Values larger
            than the number of classes are clamped.
        classifier: Fitted estimator exposing ``predict_proba`` and
            ``classes_``. When omitted, the module-level ``dnamodel`` is used.

    Returns:
        A list with one ``(genera, probabilities)`` tuple per sample, both
        ordered from most to least likely. Empty when there are no samples.
    """
    import numpy as np  # local import: the file-level alias for numpy is `mp`

    if classifier is None:
        # NOTE(review): `dnamodel` is not defined in the visible code; it has
        # to be loaded at module level before this function is called.
        classifier = dnamodel

    probs = np.asarray(classifier.predict_proba(dnaSeqs), dtype=float)
    if probs.ndim != 2 or probs.shape[0] == 0:
        return []

    labels = np.asarray(classifier.classes_)
    keep = max(0, min(int(n), probs.shape[1]))

    # Sort descending; a stable sort keeps ties in class order so the output
    # is deterministic.
    order = np.argsort(-probs, axis=1, kind="stable")[:, :keep]
    top_probs = np.take_along_axis(probs, order, axis=1)
    top_labels = labels[order]

    return [
        (genera.tolist(), scores.tolist())
        for genera, scores in zip(top_labels, top_probs)
    ]
19
+
20
def predict_genus_dna_env(dnaSeqsEnv, n=5, classifier=None):
    """Rank the most probable genera for each DNA + environment sample.

    Args:
        dnaSeqsEnv: Feature table handed unchanged to ``predict_proba``. It
            must provide a ``'nucraw'`` column holding the raw sequence of
            each row, which is used as the key of the result.
        n: Number of top-ranked genera to keep per sample. Values larger
            than the number of classes are clamped.
        classifier: Fitted estimator exposing ``predict_proba`` and
            ``classes_``. When omitted, the module-level ``model`` is used.

    Returns:
        A dict mapping each raw sequence to a ``(genera, probabilities)``
        tuple ordered from most to least likely. If the same sequence occurs
        in several rows, the last row wins.
    """
    import numpy as np  # local import: the file-level alias for numpy is `mp`

    if classifier is None:
        # NOTE(review): `model` is not defined in the visible code; it has to
        # be loaded at module level before this function is called.
        classifier = model

    # Materialise the sequences positionally so a non-default index on the
    # input table cannot break the row <-> sequence pairing.
    sequences = list(dnaSeqsEnv['nucraw'])

    probs = np.asarray(classifier.predict_proba(dnaSeqsEnv), dtype=float)
    if probs.ndim != 2 or probs.shape[0] == 0:
        return {}

    labels = np.asarray(classifier.classes_)
    keep = max(0, min(int(n), probs.shape[1]))

    # Sort descending; a stable sort keeps ties in class order.
    order = np.argsort(-probs, axis=1, kind="stable")[:, :keep]
    top_probs = np.take_along_axis(probs, order, axis=1)
    top_labels = labels[order]

    return {
        sequence: (genera.tolist(), scores.tolist())
        for sequence, genera, scores in zip(sequences, top_labels, top_probs)
    }
35
+
36
+ # def get_genus_image(genus):
37
+ # # return a URL to genus image
38
+ # return f"https://example.com/images/{genus}.jpg"
39
+
40
def get_genuses(dna_file, dnaenv_file):
    """Run both genus predictors on the uploaded CSV files.

    Args:
        dna_file: Uploaded file object for the DNA-only samples; its
            ``.name`` attribute is read as a CSV path.
        dnaenv_file: Uploaded file object for the DNA + environment samples;
            its ``.name`` attribute is read as a CSV path.

    Returns:
        A list of dicts, one per predicted sample, with the keys
        ``"sequence"``, ``"source"`` (``"DNA only"`` or ``"DNA & Env"``) and
        ``"predictions"`` (a list of prediction entries as produced by the
        predictor functions).
    """
    dna_df = pd.read_csv(dna_file.name)
    dnaenv_df = pd.read_csv(dnaenv_file.name)

    envdna_genuses = predict_genus_dna_env(dnaenv_df)
    dna_genuses = predict_genus_dna(dna_df)

    # NOTE(review): assumes the DNA-only CSV also carries a 'nucraw' column;
    # when it does not, rows are labelled by their index instead.
    if "nucraw" in dna_df.columns:
        dna_labels = dna_df["nucraw"].tolist()
    else:
        dna_labels = dna_df.index.tolist()

    # The DNA-only predictor returns a per-row list and the env predictor a
    # dict keyed by sequence, so each is normalised into the same record shape
    # rather than concatenated directly.
    results = [
        {"sequence": label, "source": "DNA only", "predictions": [prediction]}
        for label, prediction in zip(dna_labels, dna_genuses)
    ]
    results.extend(
        {"sequence": sequence, "source": "DNA & Env", "predictions": [prediction]}
        for sequence, prediction in envdna_genuses.items()
    )
    return results
57
+
58
def display_results(results):
    """Flatten prediction records into a table for display.

    Args:
        results: Iterable of dicts with a ``"sequence"`` entry and a
            ``"predictions"`` entry (the capitalised ``"Predictions"`` key is
            accepted too). Each prediction is indexable: item 0 is the genus
            and item 1, when present, is its probability.

    Returns:
        A ``pandas.DataFrame`` with the columns ``"DNA Sequence"``,
        ``"Predicted Genus"`` and ``"Probability"``, one row per prediction.
    """
    columns = ["DNA Sequence", "Predicted Genus", "Probability"]
    rows = []
    for result in results:
        predictions = result.get("predictions")
        if predictions is None:
            predictions = result.get("Predictions", [])
        for prediction in predictions:
            rows.append({
                "DNA Sequence": result["sequence"],
                "Predicted Genus": prediction[0],
                # A repeated dict key used to overwrite itself here, so the
                # probability never reached the table.
                "Probability": prediction[1] if len(prediction) > 1 else None,
            })
    return pd.DataFrame(rows, columns=columns)
70
+
71
def gradio_interface(file, dnaenv_file=None):
    """Predict genera for the uploaded file(s) and return a display table.

    Args:
        file: Uploaded CSV file object with the DNA samples.
        dnaenv_file: Optional uploaded CSV file object with the DNA +
            environment samples. When omitted, ``file`` is used for both
            predictors, because the UI currently offers one upload widget.

    Returns:
        Whatever ``display_results`` builds from the ``get_genuses`` output.
    """
    if dnaenv_file is None:
        # get_genuses needs two files; the previous one-argument call raised
        # a TypeError.
        dnaenv_file = file
    return display_results(get_genuses(file, dnaenv_file))
74
+
75
+ # Gradio interface
76
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Top 5 Most Likely Genus Predictions")
        # Extensions need the leading dot; a bare "csv" is not a recognised
        # file type name.
        file_input = gr.File(label="Upload CSV file", file_types=[".csv"])
        output_table = gr.Dataframe(
            headers=[
                "DNA",
                "Coord",
                "DNA Only Pred Genus",
                "DNA Only Prob",
                "DNA & Env Pred Genus",
                "DNA & Env Prob",
            ]
        )

    def update_output(file):
        """Recompute the prediction table whenever the upload changes."""
        # The change event also fires when the upload is cleared; there is
        # nothing to predict then, so the table is emptied.
        if file is None:
            return None
        return gradio_interface(file)

    # Event listeners have to be registered inside the Blocks context.
    file_input.change(update_output, inputs=file_input, outputs=output_table)

demo.launch()
89
+
90
+
91
+ # with gr.Blocks() as demo:
92
+ # with gr.Row():
93
+ # word = gr.Textbox(label="word")
94
+ # leng = gr.Number(label="leng")
95
+ # output = gr.Textbox(label="Output")
96
+ # with gr.Row():
97
+ # run = gr.Button()
98
+
99
+ # event = run.click(predict_genus,
100
+ # [word, leng],
101
+ # output,
102
+ # batch=True,
103
+ # max_batch_size=20)
104
+
105
+ # demo.launch()
106
+
107
+ # DB_USER = os.getenv("DB_USER")
108
+ # DB_PASSWORD = os.getenv("DB_PASSWORD")
109
+ # DB_HOST = os.getenv("DB_HOST")
110
+ # PORT = 8080
111
+ # DB_NAME = "bikeshare"
112
+
113
+ # connection_string = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}"
114
+
115
+ # def get_count_ride_type():
116
+ # df = pd.read_sql(
117
+ # """
118
+ # SELECT COUNT(ride_id) as n, rideable_type
119
+ # FROM rides
120
+ # GROUP BY rideable_type
121
+ # ORDER BY n DESC
122
+ # """,
123
+ # con=connection_string
124
+ # )
125
+ # fig_m, ax = plt.subplots()
126
+ # ax.bar(x=df['rideable_type'], height=df['n'])
127
+ # ax.set_title("Number of rides by bicycle type")
128
+ # ax.set_ylabel("Number of Rides")
129
+ # ax.set_xlabel("Bicycle Type")
130
+ # return fig_m
131
+
132
+
133
+ # def get_most_popular_stations():
134
+
135
+ # df = pd.read_sql(
136
+ # """
137
+ # SELECT COUNT(ride_id) as n, MAX(start_station_name) as station
138
+ # FROM RIDES
139
+ # WHERE start_station_name is NOT NULL
140
+ # GROUP BY start_station_id
141
+ # ORDER BY n DESC
142
+ # LIMIT 5
143
+ # """,
144
+ # con=connection_string
145
+ # )
146
+ # fig_m, ax = plt.subplots()
147
+ # ax.bar(x=df['station'], height=df['n'])
148
+ # ax.set_title("Most popular stations")
149
+ # ax.set_ylabel("Number of Rides")
150
+ # ax.set_xlabel("Station Name")
151
+ # ax.set_xticklabels(
152
+ # df['station'], rotation=45, ha="right", rotation_mode="anchor"
153
+ # )
154
+ # ax.tick_params(axis="x", labelsize=8)
155
+ # fig_m.tight_layout()
156
+ # return fig_m
157
+
158
+
159
+ # with gr.Blocks() as demo:
160
+ # with gr.Row():
161
+ # bike_type = gr.Plot()
162
+ # station = gr.Plot()
163
+
164
+ # demo.load(get_count_ride_type, inputs=None, outputs=bike_type)
165
+ # demo.load(get_most_popular_stations, inputs=None, outputs=station)
166
+
167
+ # def greet(name, intensity):
168
+ # return "Hello, " + name + "!" * int(intensity)
169
+
170
+ # demo = gr.Interface(
171
+ # fn=greet,
172
+ # inputs=["text", "slider"],
173
+ # outputs=["text"],
174
+ # )
175
 
176
  demo.launch()