File size: 3,664 Bytes
f7832ad
 
6322afe
f7832ad
 
dc1bc99
6322afe
f7832ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6322afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1638e8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import pandas as pd
import wandb

def fetch_runs_to_df(api, projects, selected_project):
    """Collect basic metadata for W&B runs into a DataFrame.

    Args:
        api: W&B API client; its ``runs(path)`` method is called with
            ``"<entity>/<project>"`` and the result is iterated.
        projects: Mapping of display name -> dict with ``"entity"`` and
            ``"project"`` keys.
        selected_project: A key of ``projects``, or ``"All"`` to collect runs
            from every configured project (in the mapping's iteration order).

    Returns:
        DataFrame with columns ``Run Name``, ``ID``, ``Created At``, ``State``
        and ``Tags`` (tags joined with ``", "``). The columns are present even
        when no runs are found, so callers can rely on the schema.

    Raises:
        KeyError: if ``selected_project`` is neither ``"All"`` nor a key of
            ``projects``.
    """
    columns = ["Run Name", "ID", "Created At", "State", "Tags"]

    # One code path for both modes: "All" just means "every configured project".
    if selected_project == "All":
        targets = list(projects.values())
    else:
        targets = [projects[selected_project]]

    data = [
        {
            "Run Name": run.name,
            "ID": run.id,
            "Created At": run.created_at,
            "State": run.state,
            "Tags": ", ".join(run.tags),  # Join tags into a single string
        }
        for details in targets
        for run in api.runs(f"{details['entity']}/{details['project']}")
    ]

    # Explicit columns keep the schema stable when ``data`` is empty.
    return pd.DataFrame(data, columns=columns)

def fetch_run(api, projects, selected_project, selected_run_id):
    """Fetch a single W&B run by ID.

    Args:
        api: W&B API client; its ``run(path)`` method is called with
            ``"<entity>/<project>/<run_id>"``.
        projects: Mapping of display name -> dict with ``"entity"`` and
            ``"project"`` keys.
        selected_project: A key of ``projects``, or ``"All"`` to search every
            configured project in order until the run is found.
        selected_run_id: The run ID to look up.

    Returns:
        The run object returned by ``api.run``. In ``"All"`` mode, if no
        project contains the run, a Streamlit error is displayed and ``None``
        is returned (previously this path crashed with ``UnboundLocalError``).

    Raises:
        KeyError: if ``selected_project`` is neither ``"All"`` nor a key of
            ``projects``. Errors from ``api.run`` for a specific project
            propagate to the caller.
    """
    if selected_project != "All":
        details = projects[selected_project]
        return api.run(f"{details['entity']}/{details['project']}/{selected_run_id}")

    # Search each project; a CommError means "not in this project".
    for details in projects.values():
        try:
            return api.run(f"{details['entity']}/{details['project']}/{selected_run_id}")
        except wandb.errors.CommError:
            continue

    st.error(f"Run ID {selected_run_id} not found in any project.")
    return None


def fetch_models_to_df(api, projects, selected_project, df):
    """Build a table of model names and summary metrics for the given runs.

    Args:
        api: W&B API client; its ``run(path)`` method is called with
            ``"<entity>/<project>/<run_id>"``.
        projects: Mapping of display name -> dict with ``"entity"`` and
            ``"project"`` keys.
        selected_project: A key of ``projects``, or ``"All"`` to search every
            configured project for each run.
        df: DataFrame with at least ``"ID"`` and ``"Run Name"`` columns
            (as produced by ``fetch_runs_to_df``).

    Returns:
        DataFrame with columns ``Model Name``, ``Run Name``, ``ID``,
        ``Accuracy`` and ``Loss``. The columns are present even when no rows
        are produced. Runs that cannot be fetched (``CommError``) are skipped;
        in ``"All"`` mode a Streamlit error is shown for each run not found in
        any project.
    """
    columns = ["Model Name", "Run Name", "ID", "Accuracy", "Loss"]
    data = []

    for _, row in df.iterrows():
        run_id = row["ID"]
        try:
            if selected_project == "All":
                # Determine the project for the current run
                for details in projects.values():
                    try:
                        run = api.run(f"{details['entity']}/{details['project']}/{run_id}")
                        break
                    except wandb.errors.CommError:
                        continue
                else:
                    st.error(f"Run ID {run_id} not found in any project.")
                    continue
            else:
                details = projects[selected_project]
                run = api.run(f"{details['entity']}/{details['project']}/{run_id}")

            metrics = run.summary
            # Fetch model name from the config, defaulting to "Unknown"
            model_name = run.config.get("model_name", "Unknown")

            data.append({
                "Model Name": model_name,
                "Run Name": row["Run Name"],
                "ID": run_id,
                "Accuracy": metrics.get("accuracy"),  # Example metric
                "Loss": metrics.get("loss"),  # Example metric
            })
        except wandb.errors.CommError:
            # Best-effort: a run that fails to load is left out of the table.
            continue

    # Explicit columns keep the schema stable when ``data`` is empty.
    return pd.DataFrame(data, columns=columns)