notbulubula commited on
Commit
1638e8f
1 Parent(s): 9b37c86

naprawianie all

Browse files
Files changed (2) hide show
  1. app.py +2 -21
  2. utils.py +40 -1
app.py CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
4
  import os
5
  # import matplotlib.pyplot as plt
6
 
7
- from utils import fetch_runs_to_df, fetch_run
8
 
9
  # Access the API key from the environment variable
10
  wandb_api_key = os.getenv('WANDB_API_KEY')
@@ -84,27 +84,8 @@ if option == "Models":
84
 
85
  # Ensure the DataFrame is not empty
86
  if not df.empty:
87
- # Fetch metrics for ranking (e.g., accuracy or loss)
88
- ranking_data = []
89
- for index, row in df.iterrows():
90
- try:
91
- # Fetch the run details
92
- run = api.run(f"{projects[selected_project]['entity']}/{projects[selected_project]['project']}/{row['ID']}")
93
- metrics = run.summary
94
- model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
95
-
96
- ranking_data.append({
97
- "Model Name": model_name, # Add model name to the table
98
- "Run Name": row["Run Name"],
99
- "ID": row["ID"],
100
- "Accuracy": metrics.get("accuracy"), # Example metric
101
- "Loss": metrics.get("loss") # Example metric
102
- })
103
- except wandb.errors.CommError:
104
- continue
105
-
106
  # Convert to DataFrame
107
- ranking_df = pd.DataFrame(ranking_data)
108
 
109
  # Rank by Accuracy (or another metric)
110
  ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
 
4
  import os
5
  # import matplotlib.pyplot as plt
6
 
7
+ from utils import fetch_runs_to_df, fetch_run, fetch_models_to_df
8
 
9
  # Access the API key from the environment variable
10
  wandb_api_key = os.getenv('WANDB_API_KEY')
 
84
 
85
  # Ensure the DataFrame is not empty
86
  if not df.empty:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  # Convert to DataFrame
88
+ ranking_df = fetch_models_to_df(api, projects, selected_project, df)
89
 
90
  # Rank by Accuracy (or another metric)
91
  ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
utils.py CHANGED
@@ -57,4 +57,43 @@ def fetch_run(api, projects, selected_project, selected_run_id):
57
  project = projects[selected_project]["project"]
58
  run = api.run(f"{entity}/{project}/{selected_run_id}")
59
 
60
- return run
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  project = projects[selected_project]["project"]
58
  run = api.run(f"{entity}/{project}/{selected_run_id}")
59
 
60
+ return run
61
+
62
+
63
+ def fetch_models_to_df(api, projects, selected_project, df):
64
+ data = []
65
+ for index, row in df.iterrows():
66
+ try:
67
+ if selected_project == "All":
68
+ # Determine the project for the current run
69
+ for project_name, details in projects.items():
70
+ entity = details["entity"]
71
+ project = details["project"]
72
+ try:
73
+ run = api.run(f"{entity}/{project}/{row['ID']}")
74
+ break
75
+ except wandb.errors.CommError:
76
+ continue
77
+ else:
78
+ st.error(f"Run ID {row['ID']} not found in any project.")
79
+ continue
80
+ else:
81
+ entity = projects[selected_project]["entity"]
82
+ project = projects[selected_project]["project"]
83
+ run = api.run(f"{entity}/{project}/{row['ID']}")
84
+
85
+ metrics = run.summary
86
+ model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
87
+
88
+ data.append({
89
+ "Model Name": model_name, # Add model name to the table
90
+ "Run Name": row["Run Name"],
91
+ "ID": row["ID"],
92
+ "Accuracy": metrics.get("accuracy"), # Example metric
93
+ "Loss": metrics.get("loss") # Example metric
94
+ })
95
+ except wandb.errors.CommError:
96
+ continue
97
+
98
+ data_df = pd.DataFrame(data)
99
+ return data_df