Spaces:
Sleeping
Sleeping
notbulubula
commited on
Commit
•
1638e8f
1
Parent(s):
9b37c86
naprawianie all
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import pandas as pd
|
|
4 |
import os
|
5 |
# import matplotlib.pyplot as plt
|
6 |
|
7 |
-
from utils import fetch_runs_to_df, fetch_run
|
8 |
|
9 |
# Access the API key from the environment variable
|
10 |
wandb_api_key = os.getenv('WANDB_API_KEY')
|
@@ -84,27 +84,8 @@ if option == "Models":
|
|
84 |
|
85 |
# Ensure the DataFrame is not empty
|
86 |
if not df.empty:
|
87 |
-
# Fetch metrics for ranking (e.g., accuracy or loss)
|
88 |
-
ranking_data = []
|
89 |
-
for index, row in df.iterrows():
|
90 |
-
try:
|
91 |
-
# Fetch the run details
|
92 |
-
run = api.run(f"{projects[selected_project]['entity']}/{projects[selected_project]['project']}/{row['ID']}")
|
93 |
-
metrics = run.summary
|
94 |
-
model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
|
95 |
-
|
96 |
-
ranking_data.append({
|
97 |
-
"Model Name": model_name, # Add model name to the table
|
98 |
-
"Run Name": row["Run Name"],
|
99 |
-
"ID": row["ID"],
|
100 |
-
"Accuracy": metrics.get("accuracy"), # Example metric
|
101 |
-
"Loss": metrics.get("loss") # Example metric
|
102 |
-
})
|
103 |
-
except wandb.errors.CommError:
|
104 |
-
continue
|
105 |
-
|
106 |
# Convert to DataFrame
|
107 |
-
ranking_df =
|
108 |
|
109 |
# Rank by Accuracy (or another metric)
|
110 |
ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
|
|
|
4 |
import os
|
5 |
# import matplotlib.pyplot as plt
|
6 |
|
7 |
+
from utils import fetch_runs_to_df, fetch_run, fetch_models_to_df
|
8 |
|
9 |
# Access the API key from the environment variable
|
10 |
wandb_api_key = os.getenv('WANDB_API_KEY')
|
|
|
84 |
|
85 |
# Ensure the DataFrame is not empty
|
86 |
if not df.empty:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
# Convert to DataFrame
|
88 |
+
ranking_df = fetch_models_to_df(api, projects, selected_project, df)
|
89 |
|
90 |
# Rank by Accuracy (or another metric)
|
91 |
ranking_df = ranking_df.sort_values(by="Accuracy", ascending=False).reset_index(drop=True)
|
utils.py
CHANGED
@@ -57,4 +57,43 @@ def fetch_run(api, projects, selected_project, selected_run_id):
|
|
57 |
project = projects[selected_project]["project"]
|
58 |
run = api.run(f"{entity}/{project}/{selected_run_id}")
|
59 |
|
60 |
-
return run
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
project = projects[selected_project]["project"]
|
58 |
run = api.run(f"{entity}/{project}/{selected_run_id}")
|
59 |
|
60 |
+
return run
|
61 |
+
|
62 |
+
|
63 |
+
def fetch_models_to_df(api, projects, selected_project, df):
|
64 |
+
data = []
|
65 |
+
for index, row in df.iterrows():
|
66 |
+
try:
|
67 |
+
if selected_project == "All":
|
68 |
+
# Determine the project for the current run
|
69 |
+
for project_name, details in projects.items():
|
70 |
+
entity = details["entity"]
|
71 |
+
project = details["project"]
|
72 |
+
try:
|
73 |
+
run = api.run(f"{entity}/{project}/{row['ID']}")
|
74 |
+
break
|
75 |
+
except wandb.errors.CommError:
|
76 |
+
continue
|
77 |
+
else:
|
78 |
+
st.error(f"Run ID {row['ID']} not found in any project.")
|
79 |
+
continue
|
80 |
+
else:
|
81 |
+
entity = projects[selected_project]["entity"]
|
82 |
+
project = projects[selected_project]["project"]
|
83 |
+
run = api.run(f"{entity}/{project}/{row['ID']}")
|
84 |
+
|
85 |
+
metrics = run.summary
|
86 |
+
model_name = run.config.get("model_name", "Unknown") # Fetch model name from the config, defaulting to "Unknown"
|
87 |
+
|
88 |
+
data.append({
|
89 |
+
"Model Name": model_name, # Add model name to the table
|
90 |
+
"Run Name": row["Run Name"],
|
91 |
+
"ID": row["ID"],
|
92 |
+
"Accuracy": metrics.get("accuracy"), # Example metric
|
93 |
+
"Loss": metrics.get("loss") # Example metric
|
94 |
+
})
|
95 |
+
except wandb.errors.CommError:
|
96 |
+
continue
|
97 |
+
|
98 |
+
data_df = pd.DataFrame(data)
|
99 |
+
return data_df
|