import * as d3 from 'd3';
import * as _ from "lodash"
import * as R from 'ramda'
import * as tp from '../etc/types';
import * as rsp from '../api/responses';
import '../etc/xd3'
import { API } from '../api/mainApi'
import { UIConfig } from '../uiConfig'
import { TextTokens, LeftTextToken, RightTextToken } from './TextToken'
import { AttentionHeadBox, getAttentionInfo } from './AttentionHeadBox'
import { AttentionGraph } from './AttentionConnector'
import { CorpusInspector } from './CorpusInspector'
import { TokenWrapper, sideToLetter } from '../data/TokenWrapper'
import { AttentionWrapper, makeFromMetaResponse } from '../data/AttentionCapsule'
import { SimpleEventHandler } from '../etc/SimpleEventHandler'
import { CorpusMatManager } from '../vis/CorpusMatManager'
import { CorpusHistogram } from '../vis/CorpusHistogram'
import { FaissSearchResultWrapper } from '../data/FaissSearchWrapper'
import { D3Sel, Sel } from '../etc/Util';
import { from, fromEvent, interval, Subscription } from 'rxjs'
import { switchMap, map, tap } from 'rxjs/operators'
import { BaseType } from "d3";
import { SimpleMeta } from "../etc/types";
import ChangeEvent = JQuery.ChangeEvent;

/**
 * Return true when a token event is missing or carries a "null" side/index.
 *
 * Both real `null`/`undefined` and the literal string "null" count as null.
 * The null check on `tok` itself happens first so that a missing token is
 * reported as null instead of throwing on property access.
 */
function isNullToken(tok: tp.TokenEvent) {
    const isSomeNull = (x: unknown) => (x == null) || (x == "null")

    if (tok == null) return true

    return isSomeNull(tok.side) || isSomeNull(tok.ind)
}

/**
 * Show only the attention curves attached to the token in `e`.
 * Left-side tokens match curves by `src-idx`, right-side ones by `target-idx`.
 * Does nothing when `e` is a null token.
 */
function showBySide(e: tp.TokenEvent) {
    // Check if saved token in uiConf is null
    if (!isNullToken(e)) {
        const classSelector = e.side == "left" ? "src-idx" : "target-idx";

        Sel.setHidden(".atn-curve")
        Sel.setVisible(`.atn-curve[${classSelector}='${e.ind}']`)
    }
}

/** On hover, filter curves to `newEvent` only if no token is currently pinned (`savedEvent`). */
function chooseShowBySide(savedEvent: tp.TokenEvent, newEvent: tp.TokenEvent) {
    if (isNullToken(savedEvent)) {
        showBySide(newEvent)
    }
}

/** On hover-out, show all curves again only if no token is currently pinned. */
function chooseShowAll(savedEvent: tp.TokenEvent) {
    if (isNullToken(savedEvent))
        Sel.setVisible(".atn-curve")
}

/** Mark every `.att-rect` belonging to `head` as unselected. */
function unselectHead(head: number) {
    const affectedHeads = d3.selectAll(`.att-rect[head='${head}']`);
    affectedHeads.classed("unselected", true)
}

/** Clear the unselected mark from every `.att-rect` belonging to `head`. */
function selectHead(head: number) {
    const affectedHeads = d3.selectAll(`.att-rect[head='${head}']`);
    affectedHeads.classed("unselected", false)
}

/** Set (`attr` true) or remove (`attr` false) the `disabled` attribute on `sel`. */
function setSelDisabled(attr: boolean, sel: D3Sel) {
    const val = attr ? true : null
    sel.attr('disabled', val)
}

/**
 * Top-level controller for the attention visualization page.
 *
 * Owns the API client, the UI configuration, the attention/token data
 * wrappers, the d3 selections of static page elements and the vis components
 * built on top of them, and wires all component events together.
 */
export class MainGraphic {
    api: API
    uiConf: UIConfig
    attCapsule: AttentionWrapper
    tokCapsule: TokenWrapper
    sels: any // Contains initial d3 selections of objects
    vizs: any // Contains vis components wrapped around parent sel
    eventHandler: SimpleEventHandler // Orchestrates events raised from components

    // Subscriptions created by initLayers / _initToggle. Both methods run on every
    // mainInit (e.g. after a model change), so the previous subscription is
    // released before a new one is made.
    private layerSub?: Subscription
    private clsToggleSub?: Subscription

    constructor() {
        this.api = new API()
        this.uiConf = new UIConfig()

        this.skeletonInit()
        this.mainInit();
    }

    /**
     * Functions that can be called without any information of a response.
     *
     * This should be called once and only once
     */
    skeletonInit() {
        this.sels = {
            body: d3.select('body'),
            atnContainer: d3.select('#atn-container'),
            atnDisplay: d3.select("#atn-display"),
            modelSelector: d3.select("#model-option-selector"),
            corpusSelector: d3.select("#corpus-select"),
            atnHeads: {
                left: d3.select("#left-att-heads"),
                right: d3.select("#right-att-heads"),
                headInfo: d3.select("#head-info-box")
                    .classed('mat-hover-display', true)
                    .classed('text-center', true)
                    .style('width', String(70) + 'px')
                    .style('height', String(30) + 'px')
                    .style('visibility', 'hidden')
            },
            form: {
                sentenceA: d3.select("#form-sentence-a"),
                button: d3.select("#update-sentence"),
            },
            tokens: {
                left: d3.select("#left-tokens"),
                right: d3.select("#right-tokens"),
            },
            clsToggle: d3.select("#cls-toggle").select(".switch"),
            layerCheckboxes: d3.select("#layer-select"),
            headCheckboxes: d3.select("#head-select"),
            contextQuery: d3.select("#search-contexts"),
            embeddingQuery: d3.select("#search-embeddings"),
            selectedHeads: d3.select("#selected-heads"),
            headSelectAll: d3.select("#select-all-heads"),
            headSelectNone: d3.select("#select-no-heads"),
            testCheckbox: d3.select("#simple-embed-query"),
            threshSlider: d3.select("#my-range"),
            corpusInspector: d3.select("#corpus-similar-sentences-div"),
            corpusMatManager: d3.select("#corpus-mat-container"),
            corpusMsgBox: d3.select("#corpus-msg-box"),
            histograms: {
                matchedWordDescription: d3.select("#match-kind"),
                matchedWord: d3.select("#matched-histogram-container"),
                maxAtt: d3.select("#max-att-histogram-container"),
            },
            buttons: {
                killLeft: d3.select("#kill-left"),
                addLeft: d3.select("#minus-left"),
                addRight: d3.select("#plus-right"),
                killRight: d3.select("#kill-right"),
                refresh: d3.select("#mat-refresh")
            },
            metaSelector: {
                matchedWord: d3.select("#matched-meta-select"),
                maxAtt: d3.select("#max-att-meta-select")
            }
        }

        this.eventHandler = new SimpleEventHandler(this.sels.body.node());

        this.vizs = {
            leftHeads: new AttentionHeadBox(this.sels.atnHeads.left, this.eventHandler, { side: "left", }),
            rightHeads: new AttentionHeadBox(this.sels.atnHeads.right, this.eventHandler, { side: "right" }),
            tokens: {
                left: new LeftTextToken(this.sels.tokens.left, this.eventHandler),
                right: new RightTextToken(this.sels.tokens.right, this.eventHandler),
            },
            attentionSvg: new AttentionGraph(this.sels.atnDisplay, this.eventHandler),
            corpusInspector: new CorpusInspector(this.sels.corpusInspector, this.eventHandler),
            corpusMatManager: new CorpusMatManager(this.sels.corpusMatManager, this.eventHandler, { idxs: this.uiConf.offsetIdxs() }),
            histograms: {
                matchedWord: new CorpusHistogram(this.sels.histograms.matchedWord, this.eventHandler),
                maxAtt: new CorpusHistogram(this.sels.histograms.maxAtt, this.eventHandler),
            },
        }

        this._bindEventHandler()
    }

    /**
     * Fetch model details and the attention for the configured sentence, then
     * (re)initialize all response-dependent UI. Called on construction and
     * whenever the model selection changes.
     */
    private mainInit() {
        const self = this;
        this.sels.body.style("cursor", "progress")
        this.api.getModelDetails(this.uiConf.model()).then(md => {
            const val = md.payload

            this.uiConf.nLayers(val.nlayers).nHeads(val.nheads)
            // May clamp uiConf.layer() to the new model's last layer
            this.initLayers(this.uiConf.nLayers())

            this.api.getMetaAttentions(this.uiConf.model(), this.uiConf.sentence(), this.uiConf.layer()).then(attention => {
                const att = attention.payload;
                this.initFromResponse(att)

                // Wrap postInit into function so asynchronous call does not mess with necessary inits
                const postResponseDisplayCleanup = () => {
                    this._toggleTokenSel()

                    const toDisplay = this.uiConf.displayInspector()
                    this._searchDisabler()

                    if (toDisplay == 'context') {
                        this._queryContext()
                    } else if (toDisplay == 'embeddings') {
                        this._queryEmbeddings()
                    }
                }

                let normBy
                if ((this.uiConf.modelKind() == tp.ModelKind.Autoregressive) && (!this.uiConf.hideClsSep())) {
                    normBy = tp.NormBy.Col
                } else {
                    normBy = tp.NormBy.All
                }
                this.vizs.attentionSvg.normBy = normBy

                if (this.uiConf.maskInds().length > 0) {
                    this.tokCapsule.a.maskInds = this.uiConf.maskInds()

                    this.api.updateMaskedAttentions(this.uiConf.model(), this.tokCapsule.a, this.uiConf.sentence(), this.uiConf.layer()).then(resp => {
                        const r = resp.payload;
                        this.attCapsule.updateFromNormal(r, this.uiConf.hideClsSep());
                        this.tokCapsule.updateTokens(r)
                        this.update()
                        postResponseDisplayCleanup()
                    })
                } else {
                    this.update()
                    postResponseDisplayCleanup()
                }

                if (this.uiConf.modelKind() == tp.ModelKind.Autoregressive) {
                    // Ensure only 1 mask ind is present for autoregressive models
                    if (this.uiConf.hasToken()) {
                        this.grayToggle(this.uiConf.token().ind)
                    }
                    self.vizs.tokens.left.options.divHover.textInfo = "Would predict next..."
                    self.vizs.tokens.right.options.divHover.textInfo = "Would predict next..."
                } else {
                    self.vizs.tokens.left.options.divHover.textInfo = "Would predict here..."
                    self.vizs.tokens.right.options.divHover.textInfo = "Would predict here..."
                }

                this.sels.body.style("cursor", "default")
            });
        })
    }

    /** Build the attention and token wrappers from a response, then run the static inits. */
    private initFromResponse(attention: tp.AttentionResponse) {
        this.attCapsule = makeFromMetaResponse(attention, this.uiConf.hideClsSep())
        this.tokCapsule = new TokenWrapper(attention);
        this._staticInits()
    }

    /** Hide the corpus inspector and matrix views and show `msg` in the corpus message box instead. */
    private leaveCorpusMsg(msg: string) {
        this.vizs.corpusInspector.hideView()
        this.vizs.corpusMatManager.hideView()
        console.log("Running leave msg");
        Sel.unhideElement(this.sels.corpusMsgBox)
        this.sels.corpusMsgBox.text(msg)
    }

    /** Register handlers for every event the vis components raise on the shared event handler. */
    private _bindEventHandler() {
        const self = this;
        this.eventHandler.bind(TextTokens.events.tokenDblClick, (e) => {
            switch (self.uiConf.modelKind()) {
                case tp.ModelKind.Bidirectional: {
                    e.sel.classed("masked-token", !e.sel.classed("masked-token"));
                    const letter = sideToLetter(e.side, this.uiConf.attType)
                    self.tokCapsule[letter].toggle(e.ind)
                    self.sels.body.style("cursor", "progress")

                    self.api.updateMaskedAttentions(this.uiConf.model(), this.tokCapsule.a, this.uiConf.sentence(), this.uiConf.layer()).then((resp: rsp.AttentionDetailsResponse) => {
                        const r = resp.payload;
                        self.attCapsule.updateFromNormal(r, this.uiConf.hideClsSep());
                        self.tokCapsule.updateTokens(r);
                        self.uiConf.maskInds(this.tokCapsule.a.maskInds)
                        self.update();
                        self.sels.body.style("cursor", "default")
                    })
                    break;
                }
                case tp.ModelKind.Autoregressive: {
                    console.log("Autoregressive model doesn't do masking");
                    break;
                }
                default: {
                    console.log("What kind of model is this?");
                    break;
                }
            }
        })

        this.eventHandler.bind(TextTokens.events.tokenMouseOver, (e: tp.TokenEvent) => {
            chooseShowBySide(this.uiConf.token(), e)
        })

        this.eventHandler.bind(TextTokens.events.tokenMouseOut, (e) => {
            chooseShowAll(this.uiConf.token())
        })

        this.eventHandler.bind(TextTokens.events.tokenClick, (e: tp.TokenEvent) => {
            const tokToggle = () => {
                this.uiConf.toggleToken(e)
                this._toggleTokenSel()
                showBySide(e)
            }
            tokToggle()
            this.renderAttHead()
        })

        this.eventHandler.bind(AttentionHeadBox.events.rowMouseOver, (e: tp.HeadBoxEvent) => {
            self.sels.atnHeads.headInfo.style('visibility', 'visible')
        })

        this.eventHandler.bind(AttentionHeadBox.events.rowMouseOut, () => {
            self.sels.atnHeads.headInfo.style('visibility', 'hidden')
            // Don't do anything special on row mouse out
        })

        this.eventHandler.bind(AttentionHeadBox.events.boxMouseOver, (e: tp.HeadBoxEvent) => {
            // Temporarily show only the hovered head's attention
            const updateMat = this.attCapsule.byHead(e.head)
            this.vizs.attentionSvg.data(updateMat)
            this.vizs.attentionSvg.update(updateMat)

            showBySide(this.uiConf.token())
        })

        this.eventHandler.bind(AttentionHeadBox.events.boxMouseOut, () => {
            // Restore the attention of the selected heads
            const att = this.attCapsule.byHeads(this.uiConf.heads())
            this.vizs.attentionSvg.data(att)
            this.vizs.attentionSvg.update(att)
            showBySide(this.uiConf.token())
        })

        this.eventHandler.bind(AttentionHeadBox.events.boxMouseMove, (e) => {
            const headInfo = self.sels.atnHeads.headInfo
            let left, top, borderRadius

            // Place the tooltip above the cursor: to its left for the left box, to its right otherwise
            if (e.side == "left") {
                const divOffset = [12, 3]
                left = e.mouse[0] + e.baseX - (+headInfo.style('width').replace('px', '') + divOffset[0])
                top = e.mouse[1] + e.baseY - (+headInfo.style('height').replace('px', '') + divOffset[1])
                borderRadius = "8px 8px 1px 8px"
            } else {
                const divOffset = [-13, 3]
                left = e.mouse[0] + e.baseX + divOffset[0]
                top = e.mouse[1] + e.baseY - (+headInfo.style('height').replace('px', '') + divOffset[1])
                borderRadius = "8px 8px 8px 1px"
            }

            headInfo
                .style('visibility', 'visible')
                .style('left', String(left) + 'px')
                .style('top', String(top) + 'px')
                .style('border-radius', borderRadius)
                .text(`Head: ${e.ind + 1}`)

            // Don't do anything special on row mouse over
        })

        this.eventHandler.bind(AttentionHeadBox.events.boxClick, (e: { head: number }) => {
            const result = this.uiConf.toggleHead(e.head)
            if (result == tp.Toggled.ADDED) {
                selectHead(e.head)
            } else if (result == tp.Toggled.REMOVED) {
                unselectHead(e.head)
            }

            this._searchDisabler()
            this._renderHeadSummary();
            this.renderSvg();
        })

        this.eventHandler.bind(CorpusMatManager.events.mouseOver, (e: { val: "pos" | "dep" | "is_ent", offset: number }) => {
            // Uncomment the below if you want to modify the whole column
            // const selector = `.inspector-cell[index-offset='${e.offset}']`
            // d3.selectAll(selector).classed("hovered-col", true)
        })

        this.eventHandler.bind(CorpusMatManager.events.mouseOut, (e: { offset: number, idx: number }) => {
            // Uncomment the below if you want to modify the whole column
            // const selector = `.inspector-cell[index-offset='${e.offset}']`
            // d3.selectAll(selector).classed("hovered-col", false)
        })

        this.eventHandler.bind(CorpusMatManager.events.rectMouseOver, (e: { offset: number, idx: number }) => {
            const row = d3.select(`.inspector-row[rownum='${e.idx}']`)
            const word = row.select(`.inspector-cell[index-offset='${e.offset}']`)
            word.classed("hovered-col", true)
        })

        this.eventHandler.bind(CorpusMatManager.events.rectMouseOut, (e: { offset: number, idx: number }) => {
            const row = d3.select(`.inspector-row[rownum='${e.idx}']`)
            const word = row.select(`.inspector-cell[index-offset='${e.offset}']`)
            word.classed("hovered-col", false)
        })
    }

    /**
     * Sync the `selected-token` class with `uiConf.token()`.
     *
     * The previous selection is cleared BEFORE the new one is applied. Otherwise,
     * re-applying the already-selected token (e.g. after a re-render) would
     * add and then immediately remove its class. For autoregressive models,
     * the later tokens are also grayed out and the next token is marked.
     */
    private _toggleTokenSel() {
        const e = this.uiConf.token()

        // Remove previous token selection, if any
        d3.selectAll('.selected-token').classed('selected-token', false)

        // Select the indicated token, if there is one
        if (this.uiConf.hasToken()) {
            const token2String = (t: tp.TokenEvent) => `#${t.side}-token-${t.ind}`
            const newSel = d3.select(token2String(e))
            // Check that selection exists
            if (!newSel.empty()) newSel.classed('selected-token', true)
        }

        if (this.uiConf.modelKind() == tp.ModelKind.Autoregressive) {
            // Both helpers ignore `ind` when no token is selected; guard against a null token
            const ind = e == null ? -1 : +e.ind
            this.grayToggle(ind)
            this.markNextToggle(ind, this.tokCapsule.a.length())
        }

        this._searchDisabler()
    }

    /** Gray all tokens that have index greater than ind */
    private grayBadToks(ind: number) {
        if (this.uiConf.modelKind() == tp.ModelKind.Autoregressive) {
            const grayToks = function (d, i) {
                const s = d3.select(this)
                s.classed("masked-token", i > ind)
            }
            d3.selectAll('.right-token').each(grayToks)
            d3.selectAll('.left-token').each(grayToks)
        }
    }

    /** Gray tokens after `ind` when a token is selected; otherwise clear all gray marks. */
    private grayToggle(ind: number) {
        if (this.uiConf.hasToken())
            this.grayBadToks(ind)
        else
            d3.selectAll('.token').classed('masked-token', false)
    }

    /** Mark the token at index `min(ind + 1, N)` on both sides with `next-token`. */
    private markNextWordToks(ind: number, N: number) {
        const markToks = function (d, i) {
            const s = d3.select(this)
            s.classed("next-token", i == Math.min(ind + 1, N))
        }
        d3.selectAll('.right-token').each(markToks)
        d3.selectAll('.left-token').each(markToks)
    }

    /** Mark the next token when a token is selected; otherwise clear all next-token marks. */
    private markNextToggle(ind: number, N: number) {
        if (this.uiConf.hasToken())
            this.markNextWordToks(ind, N)
        else
            d3.selectAll('.token').classed('next-token', false)
    }

    /** Populate the model dropdown and re-run mainInit when the selection changes. */
    private _initModelSelection() {
        const self = this

        // Below are the available models. Will need to choose 3 to be available ONLY
        const data = [
            { name: "bert-base-cased", kind: tp.ModelKind.Bidirectional },
            { name: "bert-base-uncased", kind: tp.ModelKind.Bidirectional },
            { name: "distilbert-base-uncased", kind: tp.ModelKind.Bidirectional },
            { name: "distilroberta-base", kind: tp.ModelKind.Bidirectional },
            // { name: "roberta-base", kind: tp.ModelKind.Bidirectional },
            { name: "gpt2", kind: tp.ModelKind.Autoregressive },
            // { name: "gpt2-medium", kind: tp.ModelKind.Autoregressive },
            // { name: "distilgpt2", kind: tp.ModelKind.Autoregressive },
        ]

        const names = R.map(R.prop('name'))(data)
        const kinds = R.map(R.prop('kind'))(data)
        const kindmap = R.zipObj(names, kinds)

        this.sels.modelSelector.selectAll('.model-option')
            .data(data)
            .join('option')
            .classed('model-option', true)
            .property('value', d => d.name)
            .attr("modelkind", d => d.kind)
            .text(d => d.name)

        this.sels.modelSelector.property('value', this.uiConf.model());

        this.sels.modelSelector.on('change', function () {
            const me = d3.select(this)
            const mname = me.property('value')
            self.uiConf.model(mname);
            self.uiConf.modelKind(kindmap[mname]);
            if (kindmap[mname] == tp.ModelKind.Autoregressive) {
                console.log("RESETTING MASK INDS");
                self.uiConf.maskInds([])
            }
            self.mainInit();
        })
    }

    /** Populate the corpus dropdown and store the selection in uiConf on change. */
    private _initCorpusSelection() {
        const data = [
            { code: "woz", display: "Wizard of Oz" },
            { code: "wiki", display: "Wikipedia" },
        ]

        const self = this
        self.sels.corpusSelector.selectAll('option')
            .data(data)
            .join('option')
            .property('value', d => d.code)
            .text(d => d.display)

        this.sels.corpusSelector.on('change', function () {
            const me = d3.select(this)
            self.uiConf.corpus(me.property('value'))
            console.log(self.uiConf.corpus());
        })
    }

    /** Run every init that needs the first response's data. */
    private _staticInits() {
        this._initSentenceForm();
        this._initModelSelection();
        this._initCorpusSelection();
        this._initQueryForm();
        this._initAdder();
        this._renderHeadSummary();
        this._initMetaSelectors();
        this._initToggle();
        this.renderAttHead();
        this.renderTokens();
    }

    /** Wire the corpus-matrix add/kill/refresh buttons and the window resize handler. */
    private _initAdder() {
        const updateUrlOffsetIdxs = () => {
            this.uiConf.offsetIdxs(this.vizs.corpusMatManager.idxs)
        }

        // Re-measure inspector row heights and redraw the matrix with them
        const fixCorpusMatHeights = () => {
            const newWrapped = this._wrapResults(this.vizs.corpusMatManager.data())
            this.vizs.corpusMatManager.data(newWrapped.data)
            updateUrlOffsetIdxs()
        }

        this.sels.buttons.addRight.on('click', () => {
            this.vizs.corpusMatManager.addRight()
            updateUrlOffsetIdxs()
        })

        this.sels.buttons.addLeft.on('click', () => {
            this.vizs.corpusMatManager.addLeft()
            updateUrlOffsetIdxs()
        })

        this.sels.buttons.killRight.on('click', () => {
            this.vizs.corpusMatManager.killRight()
            updateUrlOffsetIdxs()
        })

        this.sels.buttons.killLeft.on('click', () => {
            this.vizs.corpusMatManager.killLeft()
            updateUrlOffsetIdxs()
        })

        this.sels.buttons.refresh.on('click', () => {
            fixCorpusMatHeights();
        })

        const onresize = () => {
            if (this.sels.corpusInspector.text() != '') fixCorpusMatHeights();
        }

        window.onresize = onresize
    }

    private _initMetaSelectors() {
        this._initMatchedWordSelector(this.sels.metaSelector.matchedWord)
        this._initMaxAttSelector(this.sels.metaSelector.maxAtt)
    }

    /** Activate the label for uiConf.metaMax() and update the max-attention histogram on click. */
    private _initMaxAttSelector(sel: D3Sel) {
        const self = this;

        const chooseSelected = (value) => {
            const ms = sel.selectAll('label')
            ms.classed('active', false)
            // Quote the value so non-identifier values still form a valid selector
            const el = sel.selectAll(`label[value='${value}']`)
            el.classed('active', true)
        }

        chooseSelected(this.uiConf.metaMax())

        const el = sel.selectAll('label')
        el.on('click', function () {
            const val = d3.select(this).attr('value');

            // Do toggle
            sel.selectAll('.active').classed('active', false)
            d3.select(this).classed('active', true)
            self.uiConf.metaMax(val)
            self.vizs.histograms.maxAtt.meta(val)
        })
    }

    /** Activate the label for uiConf.metaMatch() and refresh the corpus inspector on click. */
    private _initMatchedWordSelector(sel: D3Sel) {
        const self = this;

        const chooseSelected = (value) => {
            const ms = sel.selectAll('label')
            ms.classed('active', false)
            // Quote the value so non-identifier values still form a valid selector
            const el = sel.selectAll(`label[value='${value}']`)
            el.classed('active', true)
        }

        chooseSelected(this.uiConf.metaMatch())

        const el = sel.selectAll('label')
        el.on('click', function () {
            const val = d3.select(this).attr('value')

            // Do toggle
            sel.selectAll('.active').classed('active', false)
            d3.select(this).classed('active', true)
            self.uiConf.metaMatch(val)
            self._updateCorpusInspectorFromMeta(val)
        })
    }

    /** Enable or disable both corpus search buttons. */
    private _disableSearching(attr: boolean) {
        setSelDisabled(attr, this.sels.contextQuery)
        setSelDisabled(attr, this.sels.embeddingQuery)
    }

    /** Push the chosen meta attribute to the inspector, matrix and matched-word histogram. */
    private _updateCorpusInspectorFromMeta(val: tp.SimpleMeta) {
        this.vizs.corpusInspector.showNext(this.uiConf.showNext)
        this.vizs.corpusMatManager.pick(val)
        this.vizs.histograms.matchedWord.meta(val)
    }

    /** Set up the sentence input so that button click or Enter fetches attention for a new sentence. */
    private _initSentenceForm() {
        const self = this;

        this.sels.form.sentenceA.attr('placeholder', "Enter new sentence to analyze")
        this.sels.form.sentenceA.attr('value', this.uiConf.sentence())

        const clearInspector = () => {
            self.vizs.corpusMatManager.clear();
            self.vizs.corpusInspector.clear();
            self.vizs.histograms.matchedWord.clear();
            self.vizs.histograms.maxAtt.clear();
        }

        const submitNewSentence = () => {
            // replace all occurences of '#' in sentence as this causes the API to break
            const sentence_a: string = this.sels.form.sentenceA.property("value").replace(/\#/g, '')

            // Only update if the form is filled correctly
            if (sentence_a.length) {
                this.sels.body.style("cursor", "progress")
                this.api.getMetaAttentions(this.uiConf.model(), sentence_a, this.uiConf.layer())
                    .then((resp: rsp.AttentionDetailsResponse) => {
                        const r = resp.payload
                        this.uiConf.sentence(sentence_a)
                        this.uiConf.rmToken();
                        this.attCapsule.updateFromNormal(r, this.uiConf.hideClsSep());
                        this.tokCapsule.updateFromResponse(r);
                        this._toggleTokenSel();
                        this.update();
                        clearInspector();
                        this.sels.body.style("cursor", "default")
                    })
            }
        }

        const onEnter = R.curry((keyCode, f, event) => {
            const e = event || window.event;
            if (e.keyCode !== keyCode) return;
            e.preventDefault();
            f();
        })

        const onEnterSubmit = onEnter(13, submitNewSentence)

        const btn = this.sels.form.button;
        const inputBox = this.sels.form.sentenceA;

        btn.on("click", submitNewSentence)
        inputBox.on('keypress', onEnterSubmit)
    }

    /** Embeddings of the currently selected token, read from its token view. */
    private _getSearchEmbeds() {
        const savedToken = this.uiConf.token();
        const out = this.vizs.tokens[savedToken.side].getEmbedding(savedToken.ind)
        return out.embeddings
    }

    /** Contexts of the currently selected token, read from its token view. */
    private _getSearchContext() {
        const savedToken = this.uiConf.token();
        const out = this.vizs.tokens[savedToken.side].getEmbedding(savedToken.ind)
        return out.contexts
    }

    /**
     * Query the corpus for the 50 nearest embeddings of the selected token and
     * show them. On a 406 response, show an "unavailable" message instead.
     */
    private _searchEmbeddings() {
        const self = this;
        console.log("SEARCHING EMBEDDINGS");
        const embed = this._getSearchEmbeds()
        const layer = self.uiConf.layer()
        const heads = self.uiConf.heads()
        const k = 50
        self.vizs.corpusInspector.showNext(self.uiConf.showNext)
        this.sels.body.style("cursor", "progress")

        self.api.getNearestEmbeddings(self.uiConf.model(), self.uiConf.corpus(), embed, layer, heads, k)
            .then((val: rsp.NearestNeighborResponse) => {
                if (val.status == 406) {
                    self.leaveCorpusMsg(`Embeddings are not available for model '${self.uiConf.model()}' and corpus '${self.uiConf.corpus()}' at this time.`)
                } else {
                    const v = val.payload

                    self.vizs.corpusInspector.unhideView()
                    self.vizs.corpusMatManager.unhideView()

                    // Get heights of corpus inspector rows.
                    self.vizs.corpusInspector.update(v)
                    // Hide any earlier "not available" message, as _searchContext does
                    Sel.hideElement(self.sels.corpusMsgBox)
                    const wrappedVals = self._wrapResults(v)
                    const countedVals = wrappedVals.getMatchedHistogram()
                    const offsetVals = wrappedVals.getMaxAttHistogram()

                    self.vizs.corpusMatManager.update(wrappedVals.data)

                    self.sels.histograms.matchedWordDescription.text(this.uiConf.matchHistogramDescription)

                    console.log("MATCHER: ", self.sels.histograms.matchedWord);
                    self.vizs.histograms.matchedWord.update(countedVals)
                    self.vizs.histograms.maxAtt.update(offsetVals)

                    self.uiConf.displayInspector('embeddings')
                    this._updateCorpusInspectorFromMeta(this.uiConf.metaMatch())
                }
                this.sels.body.style("cursor", "default")
            })
    }

    /**
     * Query the corpus for the 50 nearest contexts of the selected token and
     * show them. On a 406 response, show an "unavailable" message instead.
     */
    private _searchContext() {
        const self = this;
        console.log("SEARCHING CONTEXTS");
        const context = this._getSearchContext()
        const layer = self.uiConf.layer()
        const heads = self.uiConf.heads()
        const k = 50
        self.vizs.corpusInspector.showNext(self.uiConf.showNext)
        this.sels.body.style("cursor", "progress")

        self.api.getNearestContexts(self.uiConf.model(), self.uiConf.corpus(), context, layer, heads, k)
            .then((val: rsp.NearestNeighborResponse) => {
                // Get heights of corpus inspector rows.
                if (val.status == 406) {
                    console.log("Contexts are not available!");
                    self.leaveCorpusMsg(`Contexts are not available for model '${self.uiConf.model()}' and corpus '${self.uiConf.corpus()}' at this time.`)
                } else {
                    const v = val.payload;
                    console.log("HIDING");

                    self.vizs.corpusInspector.update(v)

                    Sel.hideElement(self.sels.corpusMsgBox)
                    self.vizs.corpusInspector.unhideView()
                    self.vizs.corpusMatManager.unhideView()

                    const wrappedVals = self._wrapResults(v)
                    const countedVals = wrappedVals.getMatchedHistogram()
                    const offsetVals = wrappedVals.getMaxAttHistogram()

                    self.vizs.corpusMatManager.update(wrappedVals.data)

                    self.vizs.histograms.matchedWord.update(countedVals)
                    self.vizs.histograms.maxAtt.update(offsetVals)

                    self.uiConf.displayInspector('context')
                    this._updateCorpusInspectorFromMeta(this.uiConf.metaMatch())
                    self.vizs.histograms.maxAtt.meta(self.uiConf.metaMax())
                }
                this.sels.body.style("cursor", "default")
            })
    }

    /** Run a context search if a token is selected; otherwise log and do nothing. */
    private _queryContext() {
        if (this.uiConf.hasToken()) {
            this._searchContext();
        } else {
            console.log("Was told to show inspector but was not given a selected token embedding")
        }
    }

    /** Run an embedding search if a token is selected; otherwise log and do nothing. */
    private _queryEmbeddings() {
        if (this.uiConf.hasToken()) {
            console.log("token: ", this.uiConf.token());
            this._searchEmbeddings();
        } else {
            console.log("Was told to show inspector but was not given a selected token embedding")
        }
    }

    /** Searching needs at least one selected head and a selected token. */
    private _searchingDisabled() {
        return (this.uiConf.heads().length == 0) || (!this.uiConf.hasToken())
    }

    private _searchDisabler() {
        this._disableSearching(this._searchingDisabled())
    }

    /** Wire the context/embedding search buttons. */
    private _initQueryForm() {
        const self = this;
        this._searchDisabler()

        this.sels.contextQuery.on("click", () => {
            self._queryContext()
        })

        this.sels.embeddingQuery.on("click", () => {
            self._queryEmbeddings()
        })
    }

    /** Show the selected heads as a 1-based, comma-separated list. */
    private _renderHeadSummary() {
        this.sels.selectedHeads
            .html(R.join(', ', this.uiConf.heads().map(h => h + 1)))
    }

    // Modify faiss results with corresponding heights
    private _wrapResults(returnedFaissResults: tp.FaissSearchResults[]) {
        const rows = d3.selectAll('.inspector-row')

        // Don't just use offsetHeight since that rounds to the nearest integer
        const heights = rows.nodes().map((n: HTMLElement) => n.getBoundingClientRect().height)

        const newVals = returnedFaissResults.map((v, i) => {
            return R.assoc('height', heights[i], v)
        })

        const wrappedVals = new FaissSearchResultWrapper(newVals, this.uiConf.showNext)
        return wrappedVals
    }

    /**
     * Build the layer radio buttons, the threshold slider and the head
     * select-all / select-none buttons.
     *
     * If uiConf.layer() is outside `[0, nLayers)`, it is clamped to the last
     * layer (`nLayers - 1`).
     */
    private initLayers(nLayers: number) {
        const self = this;
        let hasActive = false;

        const checkboxes = self.sels.layerCheckboxes.selectAll(".layerCheckbox")
            .data(_.range(0, nLayers))
            .join("label")
            .attr("class", "btn button layerCheckbox")
            .classed('active', (d, i) => {
                // Assign to largest layer available if uiConf.layer() > new nLayers
                if (d == self.uiConf.layer()) { // Javascript is 0 indexed!
                    hasActive = true;
                    return true
                }

                // The largest layer index is nLayers - 1 (d never reaches nLayers)
                if (!hasActive && d == nLayers - 1) {
                    self.uiConf.layer(d)
                    hasActive = true
                    return true
                }

                return false
            })
            .text((d) => d + 1)
            .append("input")
            .attr("type", "radio")
            .attr("class", "checkbox-inline")
            .attr("name", "layerbox")
            // .attr("head", d => d)
            .attr("id", (d, i) => "layerCheckbox" + i)
        // .text((d, i) => d + " ")

        // Release the subscription from a previous call (e.g. before a model change)
        if (this.layerSub) this.layerSub.unsubscribe()

        this.layerSub = fromEvent(checkboxes.nodes(), 'change').pipe(
            tap((e: Event) => {
                const myData = d3.select(e.target).datum();
                console.log(myData, "--- myData");
                this.sels.layerCheckboxes.selectAll(".layerCheckbox")
                    .classed('active', d => d === myData)
            }),
            map((v: Event) => +d3.select(v.target).datum()),
            tap(v => {
                console.log("New layer: ", v);
                self.uiConf.layer(v);
                self.sels.body.style("cursor", "progress");
            }),
            // switchMap drops responses for layers the user has already moved past
            switchMap((v) => from(self.api.updateMaskedAttentions(self.uiConf.model(), self.tokCapsule.a, self.uiConf.sentence(), v)))
        ).subscribe({
            next: (resp: rsp.AttentionDetailsResponse) => {
                const r = resp.payload;
                self.attCapsule.updateFromNormal(r, this.uiConf.hideClsSep());
                self.tokCapsule.updateTokens(r);
                self.uiConf.maskInds(self.tokCapsule.a.maskInds)
                self.update();
                self.sels.body.style("cursor", "default")
                self._toggleTokenSel();
            }
        })

        const layerId = `#layerCheckbox${this.uiConf.layer()}`
        console.log("Layer ID: ", layerId);
        d3.select(layerId).attr("checked", "checked")

        // Init threshold stuff
        const dispThresh = (thresh) => Math.round(thresh * 100)
        d3.select('#my-range-value').text(dispThresh(self.uiConf.threshold()))

        this.sels.threshSlider.on("input", _.throttle(function () {
            const node = this;
            self.uiConf.threshold(+node.value / 100);
            d3.select('#my-range-value').text(dispThresh(self.uiConf.threshold()))
            self.vizs.attentionSvg.threshold(self.uiConf.threshold())
        }, 100))

        this.sels.headSelectAll.on("click", function () {
            self.uiConf.selectAllHeads();
            self._searchDisabler()
            self.renderSvg()
            self.renderAttHead()
        })

        this.sels.headSelectNone.on("click", function () {
            self.uiConf.selectNoHeads();
            self._searchDisabler();
            self.renderSvg()
            self.renderAttHead()
            Sel.setHidden(".atn-curve")
        })
    }

    /**
     * Subscribe to the CLS/SEP toggle. Its checked state is stored in
     * uiConf.hideClsSep, applied to the attention data, and the views re-render.
     * Any subscription from an earlier call is released first.
     */
    _initToggle() {
        if (this.clsToggleSub) this.clsToggleSub.unsubscribe()

        this.clsToggleSub = fromEvent(this.sels.clsToggle.node(), 'input').pipe(
            map((e: Event) => (e.target as HTMLInputElement).checked),
        ).subscribe({
            next: v => {
                this.uiConf.hideClsSep(v)
                this.attCapsule.zeroed(v)
                this.renderSvg();
                this.renderAttHead();
            }
        })
    }

    /** Redraw both attention-head boxes for the selected token and sync head selection styling. */
    renderAttHead() {
        const heads = _.range(0, this.uiConf._nHeads)
        const focusAtt = this.attCapsule.att
        const token = this.uiConf.hasToken() ? this.uiConf.token() : null
        //@ts-ignore
        const leftAttInfo = getAttentionInfo(focusAtt, heads, "left", token);
        //@ts-ignore
        const rightAttInfo = getAttentionInfo(focusAtt, heads, "right", token);
        this.vizs.leftHeads.options.offset = this.uiConf.offset
        this.vizs.leftHeads.update(leftAttInfo)
        this.vizs.rightHeads.update(rightAttInfo)
        this._renderHeadSummary();

        // Make sure each head's rect styling matches uiConf.headSet()
        heads.forEach((h) => {
            if (this.uiConf.headSet().has(h)) {
                selectHead(h)
            } else {
                unselectHead(h)
            }
        })
    };

    /** Redraw the left/right token lists from the token wrapper selected by uiConf.attType. */
    renderTokens() {
        const left = this.tokCapsule[this.uiConf.attType[0]]
        const right = this.tokCapsule[this.uiConf.attType[1]]

        console.log("now: ", this.uiConf.offset);
        this.vizs.tokens.left.options.offset = this.uiConf.offset
        this.vizs.tokens.left.update(left.tokenData);
        this.vizs.tokens.left.mask(left.maskInds);
        this.vizs.tokens.right.update(right.tokenData);
        this.vizs.tokens.right.mask(right.maskInds);
        // displaySelectedToken
    }

    /** Redraw the attention curves for the selected heads and resize the svg to the token count. */
    renderSvg() {
        const att = this.attCapsule.byHeads(this.uiConf.heads())
        this.vizs.attentionSvg.options.offset = this.uiConf.offset
        const svg = this.vizs.attentionSvg.data(att);
        svg.update(att)
        const maxTokens = _.max([this.tokCapsule.a.length()])
        const newHeight = svg.options.boxheight * maxTokens
        svg.height(newHeight)

        // Don't redisplay everything if one token is selected
        showBySide(this.uiConf.token())
    };

    render() {
        this.renderTokens();
        this.renderSvg();
        this.renderAttHead();
    }

    update() {
        this.render();
    }
}