// blogpost-fineweb-v1 / src / plotting.js
// Last commit: "add cluster" (3905983) by hynky (HF staff)
// File size: 8.63 kB
/**
 * Display names for the metric ids that may appear in a plot index.
 * Only metrics listed here are offered in the metric dropdown, and the
 * display name is also used as the y-axis title of the plot.
 */
const TASK_ID_TO_NAME = {
    // Ablation benchmarks
    "agg_score": "Aggregate Score",
    "commonsense_qa/acc_norm": "Commonsense QA Norm",
    "hellaswag/acc_norm": "HellaSwag",
    "openbookqa/acc_norm": "OpenBook QA Norm",
    "piqa/acc_norm": "PIQA",
    "siqa/acc_norm": "Social IQA",
    "winogrande/acc_norm": "WinoGrande",
    "arc/acc_norm": "ARC",
    "mmlu/acc_norm": "MMLU",

    // Dataset statistics
    "ccnet": "CCNet",
};
/**
 * Display names keyed by dataset identifier.
 * NOTE(review): nothing in the visible part of this file reads this map;
 * confirm whether another script depends on it.
 */
const DATASET_ID_TO_NAME = {
    "pii_removed": "Fineweb",
    "allenai_c4_en": "C4",
    "tiiuae_falcon-refinedweb_data": "RefinedWeb",
    "red-pajama-v2_jsonl-deduplicated-extract": "RedPajamaV2",
    "dolma-sample": "Dolma1.6",
    "dedup_minhash_independent_output": "Individual Dedup MinHash",
};
/**
 * Per-plot settings used when a plot's index.json does not override them.
 * `slider` holds the bounds and starting value of the rolling-window slider;
 * `defaultMetric` is the metric id preselected in the dropdown.
 */
const DEFAULT_SETTINGS = {
    slider: { max: 30, min: 0, default: 0 },
    defaultMetric: "agg_score",
};
/**
 * Base Plotly layout shared by every plot. Per-plot values (width, y-axis
 * title, x-axis range, and any layout shipped with the metric data) are
 * merged on top of it when a plot is drawn.
 *
 * Built inside an IIFE so the repeated font and axis-frame options are
 * written once without adding extra names to the enclosing scope.
 */
const DEFAULT_LAYOUT = (() => {
    const fontFamily = "apple-system, Arial, sans-serif";

    // Fresh font object per use so no two layout entries share a reference.
    const fontOf = (size) => ({ size, family: fontFamily });

    // Both axes are drawn as a boxed frame with outside ticks and no grid.
    const axisFrame = {
        showgrid: false,
        mirror: true,
        ticks: "outside",
        showline: true,
    };

    return {
        title: { text: "Plot Title", font: fontOf(19) },
        xaxis: {
            title: { text: "Training tokens (billions)", font: fontOf(15) },
            tickfont: fontOf(14),
            ...axisFrame,
        },
        yaxis: {
            title: { text: "Agg Score", font: fontOf(15), standoff: 10 },
            ...axisFrame,
            tickfont: fontOf(14),
        },
        // Legend is anchored to the bottom-right corner with a transparent background.
        legend: {
            orientation: "v",
            xanchor: "right",
            yanchor: "bottom",
            x: 1,
            y: 0,
            font: fontOf(14),
            bgcolor: "rgba(0,0,0,0)",
        },
        margin: { t: 30, b: 50 },
        height: 400,
    };
})();
/**
 * Computes an x-axis range that covers every trace with 5% padding.
 *
 * @param {Array<{x?: ArrayLike<number>}>} traces - Traces whose `x` values
 *   should be covered. Traces without an array-like `x` are ignored, as are
 *   null and non-finite x values.
 * @returns {number[]|undefined} `[lower, upper]`, or `undefined` when there
 *   is no usable x value (the caller already passes `undefined` as the range
 *   when auto-ranging is disabled).
 */
const getAutoRange = (traces) => {
    let minX = Infinity;
    let maxX = -Infinity;

    // One explicit pass: spreading a large array into Math.min/Math.max can
    // exceed the engine's argument limit and throw a RangeError.
    for (const trace of traces ?? []) {
        const xs = trace?.x;
        if (!Array.isArray(xs) && !ArrayBuffer.isView(xs)) {
            continue;
        }
        for (const rawValue of xs) {
            if (rawValue == null) {
                continue;
            }
            const value = Number(rawValue);
            if (!Number.isFinite(value)) {
                continue;
            }
            if (value < minX) {
                minX = value;
            }
            if (value > maxX) {
                maxX = value;
            }
        }
    }

    // No finite value was seen.
    if (minX > maxX) {
        return undefined;
    }

    // Pad outward on both sides; a negative bound must grow in magnitude,
    // otherwise the padding would move the bound inside the data.
    const lower = minX >= 0 ? minX * 0.95 : minX * 1.05;
    const upper = maxX >= 0 ? maxX * 1.05 : maxX * 0.95;
    return [lower, upper];
};
/**
 * Initialises every ablation plot on the page.
 *
 * For each element whose id starts with "plot-", the remainder of the id is
 * used as the plot name. `data/plots/<name>/index.json` is fetched, its
 * `settings` are merged over DEFAULT_SETTINGS, the controls are created with
 * createAblationPlottingElements, and a Plotly graph is drawn into the new
 * figure. The graph is redrawn when the metric dropdown changes and, debounced
 * by 500 ms, when the rolling-window slider moves.
 *
 * Relies on the globals `Plotly` and `_` (lodash). Failures are logged to the
 * console per plot, so one broken plot does not produce unhandled rejections.
 */
const init_ablation_plot = function () {
    // Fetches a URL and parses the body as JSON, rejecting on HTTP error statuses.
    const fetchJson = async (url) => {
        const response = await fetch(url);
        if (!response.ok) {
            throw new Error(
                `Request for ${url} failed with status ${response.status}`
            );
        }
        return response.json();
    };

    const plotElements = document.querySelectorAll('[id^="plot-"]');
    plotElements.forEach(async (plotElement) => {
        const plotName = plotElement.id.replace("plot-", "");
        try {
            const indexData = await fetchJson(
                `data/plots/${plotName}/index.json`
            );
            const settings = _.merge({}, DEFAULT_SETTINGS, indexData.settings);
            const indexMapping = indexData.files;
            const { dropdown, slider, plot } = createAblationPlottingElements(
                plotElement,
                indexMapping,
                settings
            );
            plot.id = `graph-${plotName}`;

            // Incremented per redraw so a slow, older fetch cannot overwrite
            // the result of a newer one.
            let latestRequest = 0;

            // Fetches the data of the selected metric and redraws the graph.
            const updatePlot = async () => {
                const requestId = ++latestRequest;
                const metricName = dropdown.value;
                const parsedWindow = Number.parseInt(slider?.value ?? 0, 10);
                const sliderValue = Number.isNaN(parsedWindow)
                    ? 0
                    : parsedWindow;

                const fileEntry = indexMapping?.[metricName];
                if (fileEntry === undefined) {
                    throw new Error(
                        `No data file listed for metric "${metricName}"`
                    );
                }
                const metricData = await fetchJson(
                    `data/plots/${plotName}/${fileEntry["file"]}`
                );
                if (requestId !== latestRequest) {
                    // A newer redraw was requested while this one was loading.
                    return;
                }

                // Copy so the parsed response is not mutated by the pushes below.
                const traces = [...(metricData?.traces?.[metricName] ?? [])];
                for (const traceData of Object.values(metricData?.data ?? {})) {
                    const y = rollingWindow(traceData.y, sliderValue);
                    // Smoothing can shorten y, so x is cut to the same length.
                    const x = traceData.x.slice(0, y.length);
                    traces.push({
                        x: x,
                        y: y,
                        type: "scatter",
                        mode: "lines",
                        line: {
                            width: 2.5,
                        },
                        name: traceData.label,
                    });
                }

                const width = plot.parentElement.offsetWidth;
                const layout = _.merge(
                    {},
                    DEFAULT_LAYOUT,
                    {
                        width: width,
                        yaxis: { title: { text: TASK_ID_TO_NAME[metricName] } },
                        xaxis: {
                            range: settings.autoSetXRange
                                ? getAutoRange(traces)
                                : undefined,
                        },
                    },
                    metricData?.layout
                );
                Plotly.react(plot, traces, layout);
            };

            // Event-driven redraws have no caller to await them, so errors
            // are reported here.
            const refreshPlot = () => {
                updatePlot().catch((error) => {
                    console.error(
                        `Failed to update plot "${plotName}"`,
                        error
                    );
                });
            };

            dropdown.addEventListener("change", refreshPlot);

            let timeoutId;
            // Debounce the slider
            if (slider !== undefined) {
                slider.addEventListener("input", () => {
                    clearTimeout(timeoutId);
                    timeoutId = setTimeout(refreshPlot, 500);
                });
            }

            // Shared plot
            Plotly.newPlot(plot, []);

            // Registered once per plot. Adding it inside updatePlot would
            // stack one more listener on every redraw.
            window.addEventListener("resize", () => {
                // If the window size is smaller than 768, we don't care as it's not shown
                if (window.innerWidth < 768) {
                    return;
                }
                Plotly.relayout(plot, {
                    width: plot.parentElement.offsetWidth,
                });
            });

            // Initial plot
            await updatePlot();
        } catch (error) {
            console.error(`Failed to initialise plot "${plotName}"`, error);
        }
    });
};
// Build the plots as soon as the DOM is available. When this script is loaded
// after parsing has finished, DOMContentLoaded has already fired and a
// listener would never run, so readyState is checked first.
// In the immediate branch the helpers declared further down are not yet
// initialised, which is fine because init_ablation_plot only uses them after
// awaiting its first fetch.
if (document.readyState === "loading") {
    document.addEventListener("DOMContentLoaded", () => {
        init_ablation_plot();
    });
} else {
    init_ablation_plot();
}
/**
 * Creates the figure and the controls (metric dropdown and optional
 * rolling-window slider) of one plot and appends them to `plotElement`.
 *
 * @param {HTMLElement} plotElement - Container receiving the figure and controls.
 * @param {Object} indexMapping - Entries of the plot index keyed by metric id;
 *   only keys that TASK_ID_TO_NAME knows become dropdown options.
 * @param {Object} settings - Plot settings. `defaultMetric` is preselected
 *   when it is among the options, otherwise the first option is used.
 *   `slider` ({min, max, default}) configures the slider; a null or missing
 *   value means no slider is created.
 * @returns {{dropdown: HTMLSelectElement, slider: (HTMLInputElement|undefined), plot: HTMLElement}}
 */
const createAblationPlottingElements = (
    plotElement,
    indexMapping,
    settings
) => {
    const plot = document.createElement("figure");
    const controls = document.createElement("div");
    plot.classList.add("plotly");
    controls.classList.add("plotly_controls");
    plotElement.appendChild(plot);
    plotElement.appendChild(controls);

    // Own-property check: `in` would also accept inherited keys such as
    // "constructor" and produce a bogus option.
    const metricOptions = Object.keys(indexMapping ?? {}).filter((metric) =>
        Object.prototype.hasOwnProperty.call(TASK_ID_TO_NAME, metric)
    );

    // Dropdown
    const dropdownLabel = document.createElement("label");
    dropdownLabel.textContent = "Metric:";
    const dropdown = document.createElement("select");
    // Options are built as elements so ids and names are never parsed as HTML.
    for (const metric of metricOptions) {
        const option = document.createElement("option");
        option.value = metric;
        option.textContent = TASK_ID_TO_NAME[metric];
        dropdown.appendChild(option);
    }
    // Assigning a value that matches no option leaves a select with an empty
    // value, so fall back to the first available metric.
    const initialMetric = metricOptions.includes(settings.defaultMetric)
        ? settings.defaultMetric
        : metricOptions[0];
    if (initialMetric !== undefined) {
        dropdown.value = initialMetric;
    }
    const dropdownContainer = document.createElement("div");
    dropdownContainer.classList.add("plotly_input_container");
    dropdownContainer.appendChild(dropdownLabel);
    dropdownContainer.appendChild(dropdown);
    controls.appendChild(dropdownContainer);

    let slider = undefined;
    if (settings.slider != null) {
        const sliderLabel = document.createElement("label");
        sliderLabel.textContent = "Rolling window:";
        slider = document.createElement("input");
        slider.type = "range";
        slider.min = settings.slider.min;
        slider.max = settings.slider.max;
        slider.value = settings.slider.default;

        // Shows the current slider value next to the slider.
        const sliderValue = document.createElement("span");
        sliderValue.textContent = slider.value;
        slider.addEventListener("input", () => {
            sliderValue.textContent = slider.value;
        });

        const sliderInputContainer = document.createElement("div");
        sliderInputContainer.classList.add("plotly_slider");
        sliderInputContainer.appendChild(slider);
        sliderInputContainer.appendChild(sliderValue);

        const sliderContainer = document.createElement("div");
        sliderContainer.classList.add("plotly_input_container");
        sliderContainer.appendChild(sliderLabel);
        sliderContainer.appendChild(sliderInputContainer);
        controls.appendChild(sliderContainer);
    }

    return { dropdown, slider, plot };
};
/**
 * Smooths a numeric series with a simple moving average.
 *
 * @param {number[]} data - Series to smooth.
 * @param {number} windowSize - Number of consecutive points averaged per
 *   output value. Zero, negative or non-numeric sizes disable smoothing.
 * @returns {number[]} `data` itself when smoothing is disabled; otherwise a
 *   new array in which element k is the mean of `data[k] .. data[k + windowSize - 1]`.
 *   It has `data.length - windowSize + 1` elements, or none when the window
 *   is longer than the data.
 */
const rollingWindow = function (data, windowSize) {
    const size = Math.trunc(windowSize);
    if (!Number.isFinite(size) || size <= 0) {
        return data;
    }
    const rollingData = [];
    // `windowEnd` is exclusive and runs up to data.length inclusive, so the
    // last full window is emitted too; a window of 1 therefore returns every
    // point instead of dropping the final one.
    for (let windowEnd = size; windowEnd <= data.length; windowEnd++) {
        let windowSum = 0;
        for (let i = windowEnd - size; i < windowEnd; i++) {
            windowSum += data[i];
        }
        rollingData.push(windowSum / size);
    }
    return rollingData;
};