jennzhuge committed on
Commit
6d06448
·
1 Parent(s): 3f8dd98

skeleton code

Browse files
Files changed (3) hide show
  1. app.py +42 -36
  2. requirements.txt +8 -0
  3. xgboost_infer.py +66 -0
app.py CHANGED
@@ -2,55 +2,56 @@ import os
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
5
- import numpy as mp
 
6
 
 
 
7
 
8
- def predict_genus_dna(dnaSeqs):
9
- genuses = []
 
 
10
 
11
- probs = dnamodel.predict_proba(dnaSeqs)
12
- preds = dnamodel.predict(dnaSeqs)
13
- top5prob = np.argsort(probs, axis=1)[:,-n:]
14
- top5class = dnamodel.classes_[top5prob]
15
 
16
- pred_df = pd.DataFrame(data=[top5class, top5prob], columns= ['Genus', 'Probability'])
17
 
18
- return genuses
19
-
20
- def predict_genus_dna_env(dnaSeqsEnv):
21
- genuses = {}
22
- probs = model.predict_proba(dnaSeqsEnv)
23
- preds = model.predict(dnaSeqsEnv)
24
 
25
- for i in range(len(dnaSeqsEnv)):
26
- top5prob = np.argsort(probs[i], axis=1)[:,-5:]
27
- top5class = model.classes_[top5prob]
28
 
29
- sampleStr = dnaSeqsEnv['nucraw'][i]
30
- genuses[sampleStr] = (top5class, top5prob)
31
 
32
  # pred_df = pd.DataFrame(data=[top5class, top5prob], columns= ['Genus', 'Probability'])
33
 
34
- return genuses
35
 
36
  # def get_genus_image(genus):
37
  # # return a URL to genus image
38
  # return f"https://example.com/images/{genus}.jpg"
39
 
40
  def get_genuses(dna_file, dnaenv_file):
41
- dna_df = pd.read_csv(dna_file.name)
42
- dnaenv_df = pd.read_csv(dnaenv_file.name)
43
 
44
  results = []
45
 
46
- envdna_genuses = predict_genus_dna_env(dnaenv_df)
47
- dna_genuses = predict_genus_dna(dna_df)
48
  # images = [get_genus_image(genus) for genus in top_5_genuses]
 
 
49
 
50
  results.append({
51
- "sequence": dna_sequence,
52
- "Predictions": envdna_genuses + dna_genuses,
53
- # "images": images
54
  })
55
 
56
  return results
@@ -58,14 +59,19 @@ def get_genuses(dna_file, dnaenv_file):
58
  def display_results(results):
59
  display = []
60
  for result in results:
61
- for i in range(len(result["predictions"])):
62
- display.append({
63
- "DNA Sequence": result["sequence"],
64
- "Predicted Genus": result['predictions'][i][0],
65
- "Predicted Genus": result['predictions'][i][0],
66
- "Predicted Genus": result['predictions'][i][0],
67
- # "Image": result["images"][i]
68
- })
 
 
 
 
 
69
  return pd.DataFrame(display)
70
 
71
  def gradio_interface(file):
@@ -76,7 +82,7 @@ def gradio_interface(file):
76
  with gr.Blocks() as demo:
77
  with gr.Column():
78
  gr.Markdown("# Top 5 Most Likely Genus Predictions")
79
- file_input = gr.File(label="Upload CSV file", file_types=['csv'])
80
  output_table = gr.Dataframe(headers=["DNA", "Coord", "DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
81
 
82
  def update_output(file):
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import gradio as gr
5
+ import numpy as np
6
+ import xgboost_infer
7
 
8
+ # def predict_genus_dna(dnaSeqs):
9
+ # genuses = []
10
 
11
+ # # probs = dnamodel.predict_proba(dnaSeqs)
12
+ # # preds = dnamodel.predict(dnaSeqs)
13
+ # # topProb = np.argsort(probs, axis=1)[:,-3:]
14
+ # # topClass = dnamodel.classes_[topProb]
15
 
16
+ # # pred_df = pd.DataFrame(data=[topClass, topProb], columns= ['Genus', 'Probability'])
 
 
 
17
 
18
+ # return genuses
19
 
20
+ # def predict_genus_dna_env(dnaSeqsEnv):
21
+ # genuses = {}
22
+ # probs = model.predict_proba(dnaSeqsEnv)
23
+ # preds = model.predict(dnaSeqsEnv)
 
 
24
 
25
+ # for i in range(len(dnaSeqsEnv)):
26
+ # topProb = np.argsort(probs[i], axis=1)[:,-3:]
27
+ # topClass = model.classes_[topProb]
28
 
29
+ # sampleStr = dnaSeqsEnv['nucraw'][i]
30
+ # genuses[sampleStr] = (topClass, topProb)
31
 
32
  # pred_df = pd.DataFrame(data=[top5class, top5prob], columns= ['Genus', 'Probability'])
33
 
34
+ # return genuses
35
 
36
  # def get_genus_image(genus):
37
  # # return a URL to genus image
38
  # return f"https://example.com/images/{genus}.jpg"
39
 
40
  def get_genuses(dna_file, dnaenv_file):
41
+ # dna_df = pd.read_csv(dna_file.name)
42
+ # dnaenv_df = pd.read_csv(dnaenv_file.name)
43
 
44
  results = []
45
 
46
+ # envdna_genuses = predict_genus_dna_env(dnaenv_df)
47
+ # dna_genuses = predict_genus_dna(dna_df)
48
  # images = [get_genus_image(genus) for genus in top_5_genuses]
49
+
50
+ genuses = xgboost_infer.infer()
51
 
52
  results.append({
53
+ "sequence": dna_df['nucraw']
54
+ "predictions": pd.concat([dna_genuses, envdna_genuses], axis=0)
 
55
  })
56
 
57
  return results
 
59
  def display_results(results):
60
  display = []
61
  for result in results:
62
+ # for i in range(len(result["predictions"])):
63
+ # display.append({
64
+ # "DNA Sequence": result["sequence"],
65
+ # "DNA Pred Genus": result['predictions'][i][0],
66
+ # "DNA Only Prob": result['predictions'][i][1],
67
+ # "DNA Env Pred Genus": result['predictions'][i][2],
68
+ # "DNA Env Prob": result['predictions'][i][3],
69
+ # # "Image": result["images"][i]
70
+ # })
71
+ display.append({
72
+ "DNA Sequence": result["sequence"],
73
+ "DNA Pred Genus": result['predictions'][0]
74
+ })
75
  return pd.DataFrame(display)
76
 
77
  def gradio_interface(file):
 
82
  with gr.Blocks() as demo:
83
  with gr.Column():
84
  gr.Markdown("# Top 5 Most Likely Genus Predictions")
85
+ file_input = gr.File(label="Upload DNA CSV file", file_types=['csv'])
86
  output_table = gr.Dataframe(headers=["DNA", "Coord", "DNA Only Pred Genus", "DNA Only Prob", "DNA & Env Pred Genus", "DNA & Env Prob"])
87
 
88
  def update_output(file):
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ huggingface-hub==0.23.2
2
+ pandas==2.2.2
3
+ torch==2.3.0
4
+ tqdm==4.66.4
5
+ transformers==4.41.2
6
+ faiss
7
+ gradio
8
+ datasets
xgboost_infer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #PSEUDOCODE UNTIL WE GET DATA
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sklearn.metrics import accuracy_score
5
+ from sklearn.preprocessing import LabelEncoder
6
+ from datasets import load_dataset
7
+ import pickle
8
+
9
+ def infer_dna(args):
10
+ ecoDf = pd.read_csv(args['input_path'], sep='\t')
11
+ dnaEmbeds = load_dataset("LofiAmazon/BOLD-Embeddings", split='train')
12
+
13
+ modelDNA = load_checkpoint()
14
+ modelDNAEnv = load_checkpoint()
15
+
16
+ ecoDF = ecoDf[ecoDf['marker_code' == 'COI-5P']]
17
+ ecoDf = ecoDf[['processid','nucraw','coord','country','depth',
18
+ 'WorldClim2_BIO_Temperature_Seasonality',
19
+ 'WorldClim2_BIO_Precipitation_Seasonality','WorldClim2_BIO_Annual_Precipitation', 'EarthEnvTopoMed_Elevation',
20
+ 'EsaWorldCover_TreeCover', 'CHELSA_exBIO_GrowingSeasonLength',
21
+ 'WCS_Human_Footprint_2009', 'GHS_Population_Density',
22
+ 'CHELSA_BIO_Annual_Mean_Temperature']]
23
+
24
+ # grab DNA embeddings and merge them onto ecoDf by processid
25
+ X_eco = pd.merge(ecoDf, dnaEmbeds, on='processid', how='left')
26
+
27
+
28
+ # split data into X and y
29
+ # X = df.drop(columns=['genus'])
30
+ Y_eco = ecoDf['genus']
31
+
32
+ # do inference with the model trained on DNA and Env data
33
+ y_eco_probs = modelDNA.predict_proba(X_eco)
34
+ # topProb = np.argsort(y_probs, axis=1)[:,-3:]
35
+ # topClass = dnamodel.classes_[topProb]
36
+
37
+ DNAGenuses = {}
38
+ for i in range(len(X_eco)):
39
+ topProbs = np.argsort(y_probs[i], axis=1)[:,-3:]
40
+ topClasses = modelDNA.classes_[topProbs]
41
+
42
+ sampleStr = X_eco['nucraw'][i]
43
+ DNAGenuses[sampleStr] = (topClasses, topProbs)
44
+
45
+
46
+ X_dna = dnaEmbeds.drop(columns='genus')
47
+ Y_dna = dnaEmbeds['genus']
48
+ # do inferences with the model only trained on DNA
49
+ y_dna_probs = modelDNAEnv.predict_proba(X_dna)
50
+ DNAEnvGenuses = {}
51
+ for i in range(len()):
52
+ topProbs = np.argsort(y_dna_probs[i], axis=1)[:,-3:]
53
+ topClasses = modelDNA.classes_[topProbs]
54
+
55
+ sampleStr = X_eco['nucraw'][i]
56
+ DNAGenuses[sampleStr] = (topClasses, topProbs)
57
+
58
+ return DNAGenuses, DNAEnvGenuses
59
+
60
+
61
+ # if __name__ == '__main__':
62
+ # parser = argparse.ArgumentParser()
63
+ # parser.add_argument('--input_path', action='store', type=str)
64
+ # # parser.add_argument('--checkpt', action='store', type=bool, default=False)
65
+
66
+ # args = vars(parser.parse_args())