import * as d3 from "d3";
import { VComponent } from "./VisComponent";
import { SimpleEventHandler } from "../etc/SimpleEventHandler";
import { D3Sel } from "../etc/Util";
import { SVG } from "../etc/SVGplus"
import * as tf from '@tensorflow/tfjs'
import { Tensor3D } from "@tensorflow/tfjs";
// The below two (interface and function) can become a class
export type AttentionHeadBoxI = {
rows: number[][],
labels: number[],
max: number,
}
/**
* From an attention matrix selected by layer, show a summary of the attentions belonging to each head.
*
* @param headMat The matrix representing all the attentions by head (layer already selected)
* @param headList The heads that are selected
* @param side Is this the right or the left display?
* @param tokenInd If not null, select just the information from a single token across heads
* @returns Information needed to label the headbox
*/
export function getAttentionInfo(headMat: number[][][], headList: number[], side: "right" | "left" = "left", token: null | {ind: number, side: "left" | "right"}=null): AttentionHeadBoxI {
// Collect only from headlist, average each head, transpose to ease iteration
if (headList.length == 0) {
return {
rows: [[]],
labels: [],
max: 0,
}
}
let dim = null
// Only change the attention graph opposite selected token
if (token != null && (token.side != side)) {
dim = token.side == "left" ? -2 : -1 // Assign to "from" direction if "left"
}
let axis: number = side == "left" ? 2 : 1;
// average across the axis representing the attentions.
let gatheredMat = tf.tensor3d(headMat)
if (dim != null) {
gatheredMat = gatheredMat.gather([token.ind], dim)
}
let newMat = gatheredMat.gather(headList, 0).mean([axis]).transpose();
const rowInfo = newMat.arraySync();
const out: AttentionHeadBoxI = {
rows: rowInfo,
labels: headList,
max: newMat.max().arraySync(),
}
return out
}
interface CurrentOptions {
headHeight: number
headWidth: number
xPad: number
yPad: number
boxWidth: number
totalWidth: number
totalHeight: number
};
export class AttentionHeadBox extends VComponent{
css_name = '';
rowCssName = 'att-head';
boxCssName = 'att-rect';
static events = {
rowMouseOver: "AttentionHeadBox_RowMouseOver",
rowMouseOut: "AttentionHeadBox_RowMouseOut",
boxMouseOver: "AttentionHeadBox_BoxMouseOver",
boxMouseOut: "AttentionHeadBox_BoxMouseOut",
boxMouseMove: "AttentionHeadBox_BoxMouseMove",
boxClick: "AttentionHeadBox_BoxClick",
};
_data: AttentionHeadBoxI;
_current: Partial = {}
options = {
boxDim: 26,
yscale: 1, // Amount to scale boxheight to get individual heads
xscale: 0.5, // Amount to scale boxwidth to get individual heads
side: "left",
maxWidth: 200, // Maximum width of SVG
offset: 0, // Change to 1 if you desire the offset visualization for Autoregressive models
};
// D3 Components
headRows: D3Sel;
headCells: D3Sel;
opacityScale: d3.ScaleLinear;
constructor(d3Parent: D3Sel, eventHandler?: SimpleEventHandler, options: {} = {}) {
super(d3Parent, eventHandler);
this.superInitSVG(options);
this._init()
}
_init() {
this.headRows = this.base.selectAll(`.${this.rowCssName}`)
this.headCells = this.headRows.selectAll(`${this.boxCssName}`)
this.opacityScale = d3.scaleLinear().range([0, 1]);
}
private updateCurrent(): Partial {
const op = this.options
const cur = this._current
const nHeads = this._data.rows[0].length
const baseHeadWidth = op.boxDim * op.xscale
// Scale headwidth according to maximum width
const getHeadScale = (nH) => (Math.min(op.maxWidth / nH, baseHeadWidth) / baseHeadWidth) * op.xscale;
cur.headHeight = op.boxDim * op.yscale;
cur.headWidth = getHeadScale(nHeads) * op.boxDim;
cur.xPad = cur.headWidth;
cur.yPad = (op.boxDim - cur.headHeight) / 2;
const getBoxWidth = (headWidth) => {
const maxBwidth = 100;
const bwidth = this._data.rows[0].length * cur.headWidth
const scale = d3.scaleLinear
if (bwidth > maxBwidth) {
return
}
}
cur.boxWidth = (this._data.rows[0].length * cur.headWidth);
cur.totalWidth = (2 * cur.xPad) + cur.boxWidth;
cur.totalHeight = (op.boxDim * (this._data.rows.length + op.offset));
return this._current
}
private updateData() {
const op = this.options;
const self = this;
const boxEvent = (i) => { return { ind: i, side: op.side, head: self._data.labels[i] } }
const cur = this.updateCurrent()
const getBaseX = () => (self.base.node()).getBoundingClientRect().left
const getBaseY = () => (self.base.node()).getBoundingClientRect().top
this.base.html('');
this.parent
.attr("width", cur.totalWidth)
.attr("height", cur.totalHeight)
this.headRows = this.base.selectAll(`.${self.rowCssName}`)
.data(self._data.rows)
.join("g")
.attrs({
class: (d, i) => `${self.rowCssName} ${self.rowCssName}-${i}`,
transform: (d, i) => {
return SVG.translate(
{
x: cur.xPad,
y: (op.boxDim * (i + op.offset)) + cur.yPad,
})
},
width: cur.boxWidth,
height: cur.headHeight,
})
.on("mouseover", (d, i) => {
self.eventHandler.trigger(AttentionHeadBox.events.rowMouseOver, { ind: i, side: op.side })
})
.on("mouseout", (d, i) => {
self.eventHandler.trigger(AttentionHeadBox.events.rowMouseOut, { ind: i, side: op.side })
})
this.headCells = this.headRows
.selectAll(`${this.boxCssName}`)
.data(d => d)
.join('rect')
.attrs({
x: (d, i) => i * cur.headWidth,
y: 0,
class: this.boxCssName,
head: (d, i) => self._data.labels[i],
width: cur.headWidth,
height: cur.headHeight,
opacity: (d: number) => this.opacityScale(d),
fill: "blue"
})
.on("mouseover", (d, i) => {
self.eventHandler.trigger(AttentionHeadBox.events.boxMouseOver, boxEvent(i))
})
.on("mouseout", (d, i) => {
self.eventHandler.trigger(AttentionHeadBox.events.boxMouseOut, boxEvent(i))
})
.on("click", (d, i) => {
self.eventHandler.trigger(AttentionHeadBox.events.boxClick, boxEvent(i))
})
.on("mousemove", function(d, i) {
const op = self.options
const mouse = d3.mouse(self.base.node())
self.eventHandler.trigger(AttentionHeadBox.events.boxMouseMove, { ind: i, side: op.side, baseX: getBaseX(), baseY: getBaseY(), mouse: mouse })
})
.append("svg:title")
.text((d, i) => "Head " + (self._data.labels[i] + 1))
}
_wrangle(data: AttentionHeadBoxI) {
this._data = data;
this.opacityScale = this.opacityScale.domain([0, data.max])
return data;
}
_render(data: AttentionHeadBoxI) {
this.updateData();
}
}