File size: 5,374 Bytes
308f73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160

from datetime import datetime, timedelta
import numpy as np
import pandas as pd
import plotly.express as px
from plotly.graph_objs import Figure

# Dummy data creation


def dummy_data_for_plot(metrics: list[str], num_days: int = 30) -> pd.DataFrame:
    """
    Build a DataFrame of random placeholder scores for plotting.

    One row is produced per (metric, day) pair, covering ``num_days``
    consecutive days that end at the current local time and step
    backwards exactly one day at a time. Every metric shares the same
    set of timestamps.

    :param metrics: Names of the metrics to generate rows for.
    :param num_days: Number of daily data points per metric.
    :return: A DataFrame with columns "date", "task", "score" and "model",
             where "task" is the metric name, "score" is drawn uniformly
             from [50, 55) and "model" is ``"Model_<metric>"``.
    """
    # Capture the reference time once so consecutive dates are exactly one
    # day apart; calling datetime.now() per offset lets the clock drift by
    # a few microseconds between entries.
    now = datetime.now()
    dates = [now - timedelta(days=offset) for offset in range(num_days)]

    # Metric-major ordering (all dates of the first metric, then the next)
    # keeps the row order and the sequence of random draws stable.
    rows = [
        [date, metric, np.random.uniform(50, 55), f"Model_{metric}"]
        for metric in metrics
        for date in dates
    ]

    return pd.DataFrame(rows, columns=["date", "task", "score", "model"])


def create_metric_plot_obj_1(
    df: pd.DataFrame, metrics: list[str], title: str
) -> Figure:
    """
    Create a Plotly line figure with one line per selected metric.

    Rows whose "task" value is not listed in ``metrics`` are dropped. The
    remaining rows are plotted as "score" over "date", coloured by "task",
    with a marker on every data point. The y-axis is fixed to [0, 100].

    No human-baseline lines are drawn: this function does not add any
    horizontal reference lines to the figure.

    :param df: The DataFrame containing the metric values, names, and dates.
               It must provide the columns "date", "task", "score" and
               "model".
    :param metrics: A list of strings representing the names of the metrics
                    to be included in the plot.
    :param title: A string representing the title of the plot.
    :return: A Plotly figure object with one line per metric.
    """

    # Keep only the rows that belong to the requested metrics
    df = df[df["task"].isin(metrics)]

    # custom_data carries the columns referenced by the hover template below
    fig = px.line(
        df,
        x="date",
        y="score",
        color="task",
        markers=True,
        custom_data=["task", "score", "model"],
        title=title,
    )

    # Replace the default hover text with a labelled, line-per-field layout
    fig.update_traces(
        hovertemplate="<br>".join(
            [
                "Model Name: %{customdata[2]}",
                "Metric Name: %{customdata[0]}",
                "Date: %{x}",
                "Metric Value: %{y}",
            ]
        )
    )

    # Scores are shown on a fixed 0-100 scale regardless of the data range
    fig.update_layout(yaxis_range=[0, 100])

    # NOTE: human-baseline horizontal lines are intentionally not drawn here.
    # If they are reintroduced, each trace's colour is available as
    # `trace.line.color` (keyed by `trace.name`) so that a baseline added
    # with `fig.add_hline` can match its metric's line.

    return fig


def dummydf():
    """
    Load the results table from ``./assets/gtbench_results.csv``.

    The path is relative, so it is resolved against the current working
    directory of the running process.

    NOTE(review): an earlier inline placeholder for this table used the
    columns "Model", "Agent", "Opponent Model" and "Opponent Agent" followed
    by one numeric column per game (e.g. "Breakthrough", "Connect Four",
    "Tic-Tac-Toe"). The CSV presumably follows the same layout -- confirm
    against the file.

    :return: The CSV contents as a pandas DataFrame.
    """
    return pd.read_csv('./assets/gtbench_results.csv')