from functools import partial
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import gradio as gr
from typing import Dict, List

from src.logic.data_processing import PARTITION_OPTIONS, prepare_for_non_grouped_plotting, prepare_for_group_plotting
from src.logic.graph_settings import Grouping
from src.logic.utils import set_alpha
from datatrove.utils.stats import MetricStatsDict


def _trace_color(index: int, alpha: float = 0.5):
    """Return the Plotly qualitative palette colour for trace ``index`` with ``alpha`` applied.

    The palette is cycled, so any non-negative index is valid.
    """
    palette = px.colors.qualitative.Plotly
    return set_alpha(palette[index % len(palette)], alpha)


def plot_scatter(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    log_scale_x: bool,
    log_scale_y: bool,
    normalization: bool,
    rounding: int,
    cumsum: bool,
    perc: bool,
    progress: gr.Progress,
):
    """Draw one line trace per run, with x = sorted histogram keys and y = their values.

    Args:
        data: Mapping of run name -> per-run statistics. Runs are plotted in
            name order, so trace colours are stable across calls.
        metric_name: Used in the title and as the x-axis label.
        log_scale_x: Request a log x-axis. It is only applied if some trace has
            more than one point.
        log_scale_y: Request a log y-axis. The same single-point guard applies.
        normalization: Forwarded to ``prepare_for_non_grouped_plotting``. It also
            selects the y-axis title ("Frequency" vs "Total").
        rounding: Forwarded to ``prepare_for_non_grouped_plotting``.
        cumsum: Plot the cumulative sum of y instead of the raw values.
        perc: Multiply y by 100. This is applied after ``cumsum``.
        progress: Gradio progress tracker; its ``tqdm`` wraps the run loop.

    Returns:
        The populated ``go.Figure``. An empty ``data`` yields an empty figure
        rather than an error.
    """
    fig = go.Figure()
    # Sort by run name so trace order (and hence colour assignment) is deterministic.
    data = dict(sorted(data.items()))

    # Longest x / y series across all traces. Tracking these (instead of reading
    # the last loop's x / y) keeps the log-axis guard independent of which run
    # sorts last, and avoids an undefined name when ``data`` is empty.
    max_x_points = 0
    max_y_points = 0

    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
        x = sorted(histogram_prepared.keys())
        y = [histogram_prepared[k] for k in x]
        if cumsum:
            y = np.cumsum(y).tolist()
        if perc:
            y = (np.array(y) * 100).tolist()

        max_x_points = max(max_x_points, len(x))
        max_y_points = max(max_y_points, len(y))

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                marker=dict(color=_trace_color(i)),
            )
        )

    yaxis_title = "Frequency" if normalization else "Total"
    fig.update_layout(
        title=f"Line Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title=yaxis_title,
        # A log axis is only requested when there is more than one point to spread out.
        xaxis_type="log" if log_scale_x and max_x_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_y_points > 1 else None,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


def plot_bars(
    data: Dict[str, MetricStatsDict],
    metric_name: str,
    top_k: int,
    direction: PARTITION_OPTIONS,
    regex: str | None,
    rounding: int,
    log_scale_x: bool,
    log_scale_y: bool,
    show_stds: bool,
    progress: gr.Progress,
):
    """Draw one bar trace per run, with optional error bars.

    Each run's statistics go through ``prepare_for_group_plotting``, which
    returns ``(x, y, stds)``. ``stds`` is used as the error-bar array.

    Args:
        data: Mapping of run name -> per-run statistics. Runs are plotted in
            the mapping's iteration order.
        metric_name: Used in the title and as the x-axis label.
        top_k: Forwarded to ``prepare_for_group_plotting``.
        direction: Forwarded to ``prepare_for_group_plotting``.
        regex: Forwarded to ``prepare_for_group_plotting``.
        rounding: Forwarded to ``prepare_for_group_plotting``.
        log_scale_x: Request a log x-axis. It is only applied if some trace has
            more than one point.
        log_scale_y: Request a log y-axis. The same single-point guard applies.
        show_stds: Whether the error bars are visible.
        progress: Gradio progress tracker; its ``tqdm`` wraps the run loop.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()

    # Longest series across all traces. This makes the log-axis guard
    # independent of which run happens to be processed last.
    max_x_points = 0
    max_y_points = 0

    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
        x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
        max_x_points = max(max_x_points, len(x))
        max_y_points = max(max_y_points, len(y))

        fig.add_trace(
            go.Bar(
                x=x,
                y=y,
                name=f"{name} Mean",
                marker=dict(color=_trace_color(i)),
                error_y=dict(type='data', array=stds, visible=show_stds),
            )
        )

    fig.update_layout(
        title=f"Bar Plots for {metric_name}",
        xaxis_title=metric_name,
        yaxis_title="Avg. value",
        xaxis_type="log" if log_scale_x and max_x_points > 1 else None,
        yaxis_type="log" if log_scale_y and max_y_points > 1 else None,
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


# Add any other necessary functions
def plot_data(
    metric_data: Dict[str, MetricStatsDict],
    metric_name: str,
    normalize: bool,
    rounding: int,
    grouping: Grouping,
    top_n: int,
    direction: PARTITION_OPTIONS,
    group_regex: str,
    log_scale_x: bool,
    log_scale_y: bool,
    cdf: bool,
    perc: bool,
    show_stds: bool,
) -> tuple[go.Figure, gr.Row, str]:
    """Dispatch to the line or bar plot depending on ``grouping``.

    When ``grouping == "histogram"``, runs ``plot_scatter`` (``cdf`` maps to its
    ``cumsum`` flag). The third element of the result is then a Markdown
    min/max table from ``generate_min_max_hist_data``.

    For any other grouping, runs ``plot_bars``, and the third element is ``""``.

    Returns:
        ``(figure, row_update, markdown)``. ``row_update`` is always
        ``gr.Row.update(visible=True)``.
        NOTE(review): ``gr.Row.update`` is the Gradio 3.x update API. Confirm it
        against the pinned Gradio version.
    """
    if grouping == "histogram":
        fig = plot_scatter(
            metric_data,
            metric_name,
            log_scale_x,
            log_scale_y,
            normalize,
            rounding,
            cdf,
            perc,
            gr.Progress(),
        )
        min_max_hist_data = generate_min_max_hist_data(metric_data)
        return fig, gr.Row.update(visible=True), min_max_hist_data

    fig = plot_bars(
        metric_data,
        metric_name,
        top_n,
        direction,
        group_regex,
        rounding,
        log_scale_x,
        log_scale_y,
        show_stds,
        gr.Progress(),
    )
    return fig, gr.Row.update(visible=True), ""


def generate_min_max_hist_data(data: Dict[str, MetricStatsDict]) -> str:
    """Return a Markdown table of the smallest and largest key for each run.

    Keys are converted with ``float`` before comparison. A run with no keys is
    rendered as ``n/a``; ``min`` / ``max`` on an empty sequence would otherwise
    raise ``ValueError`` and abort the whole table.
    """
    header = "| Run | Min | Max |\n|-----|-----|-----|\n"
    runs_rows = []
    for run, dato in data.items():
        # Convert once and reuse for both min and max.
        keys = [float(k) for k in dato.keys()]
        if not keys:
            runs_rows.append(f"| {run} | n/a | n/a |")
            continue
        runs_rows.append(f"| {run} | {min(keys):.4f} | {max(keys):.4f} |")
    return header + "\n".join(runs_rows)