// llm-transparency-tool-demo_2/llm_transparency_tool/components/frontend/src/ContributionGraph.tsx
/**
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */
import {
    ComponentProps,
    Streamlit,
    withStreamlitConnection,
} from 'streamlit-component-lib'
import React, { useEffect, useMemo, useRef, useState } from 'react';
import * as d3 from 'd3';
import {
    Label,
    Point,
} from './common';
import './LlmViewer.css';
/**
 * Pixel sizes used when laying out and drawing the contribution graph.
 */
export const renderParams = {
    /** Height of one (layer, token) grid cell. */
    cellH: 32,
    /** Width of one (layer, token) grid cell. */
    cellW: 32,
    /** Diameter of the circles drawn for residual-stream nodes. */
    attnSize: 8,
    /** Size reserved for after-FFN nodes (not read in this file). */
    afterFfnSize: 8,
    /** Side length of the squares drawn for FFN nodes. */
    ffnSize: 6,
    /** Diameter of the circle circumscribing a token-selector triangle. */
    tokenSelectorSize: 16,
    /** Corner radius of the rectangles highlighting alternate layers. */
    layerCornerRadius: 6,
}
/** Address of one grid cell: a (layer, token) pair, both zero-based. */
interface Cell {
    /** Zero-based layer index. */
    layer: number
    /** Zero-based token position. */
    token: number
}
/**
 * Kind of node that can live inside a grid cell. The string values are part
 * of the serialized node keys, so they must not change.
 */
enum CellItem {
    /** Residual stream right after the attention block. */
    AfterAttn = 'after_attn',
    /** Residual stream right after the FFN block. */
    AfterFfn = 'after_ffn',
    /** The FFN block itself. */
    Ffn = 'ffn',
    /** Original (input) representation; only placed below layer 0. */
    Original = 'original',
}
/**
 * A graph node: which cell it belongs to and which item of that cell it is.
 * Both fields are null when the node could not be parsed from its name.
 */
interface Node {
    cell: Cell | null
    item: CellItem | null
}
/** Everything needed to draw one node. */
interface NodeProps {
    /** The node being drawn. */
    node: Node
    /** Position of the node centre in SVG coordinates. */
    pos: Point
    /** True when the node is an endpoint of an edge for the current start token. */
    isActive: boolean
}
/** An edge as received in the component arguments, with endpoints given by name. */
interface EdgeRaw {
    /** Edge weight; drives stroke colour and width. */
    weight: number
    /** Name of the source node, e.g. `A3_7` (type letter, layer, `_`, token). */
    source: string
    /** Name of the target node, in the same format as `source`. */
    target: string
}
/** An edge with parsed endpoints and precomputed drawing information. */
interface Edge {
    /** Edge weight; drives stroke colour and width. */
    weight: number
    /** Parsed source node. */
    from: Node
    /** Parsed target node. */
    to: Node
    /** SVG position of the source node. */
    fromPos: Point
    /** SVG position of the target node. */
    toPos: Point
    /** True when clicking the edge should select it. */
    isSelectable: boolean
    /** True when either endpoint is an FFN node. */
    isFfn: boolean
}
/**
 * What the user currently has selected. Clicking a node sets only `node`;
 * clicking an edge sets `edge` and sets `node` to the edge's target.
 */
interface Selection {
    node: Node | null
    edge: Edge | null
}
/**
 * Builds the SVG `points` attribute for a token-selector triangle.
 *
 * The triangle is equilateral, inscribed in a circle of diameter
 * `renderParams.tokenSelectorSize` centred on `origin`, with one vertex
 * straight below the centre (SVG y grows downwards), so it points down.
 *
 * @param origin Centre of the triangle in SVG coordinates.
 * @returns Comma-separated coordinates `x1,y1,x2,y2,x3,y3`.
 */
function tokenPointerPolygon(origin: Point): string {
    const r = renderParams.tokenSelectorSize / 2
    const dy = r / 2
    const dx = r * Math.sqrt(3.0) / 2
    // Array.prototype.toString flattens the nested pairs into one comma-separated list.
    return [
        [origin.x, origin.y + r],
        [origin.x + dx, origin.y - dy],
        [origin.x - dx, origin.y - dy],
    ].toString()
}
/**
 * Tells whether two cells address the same (layer, token) position.
 *
 * @returns false when either argument is null or undefined, even if both are.
 */
function isSameCell(cell1: Cell | null, cell2: Cell | null): boolean {
    // `== null` deliberately matches both null and undefined.
    if (cell1 == null || cell2 == null) {
        return false
    }
    return cell1.layer === cell2.layer && cell1.token === cell2.token
}
/**
 * Tells whether two nodes refer to the same item of the same cell.
 *
 * @returns false when either node is null/undefined, or when either node has
 *     no cell (unparsed nodes never compare equal).
 */
function isSameNode(node1: Node | null, node2: Node | null): boolean {
    // `== null` (as in isSameCell) so an undefined node cannot crash the
    // property accesses below.
    if (node1 == null || node2 == null) {
        return false
    }
    return isSameCell(node1.cell, node2.cell)
        && node1.item === node2.item
}
/**
 * Tells whether two edges connect the same pair of nodes in the same
 * direction. Weights and positions are not compared.
 *
 * @returns false when either edge is null/undefined.
 */
function isSameEdge(edge1: Edge | null, edge2: Edge | null): boolean {
    // `== null` so an undefined edge cannot crash the property accesses below.
    if (edge1 == null || edge2 == null) {
        return false
    }
    return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to)
}
/**
 * Parses a node name of the form `<type><layer>_<token>` (e.g. `A3_7`).
 *
 * Type letters: `A` = after attention, `I` = after FFN, `M` = FFN,
 * `X` = original representation. The pattern is not anchored, so the first
 * matching substring of `name` is used.
 *
 * The returned object always lists `cell` before `item` (and `layer` before
 * `token`): callers serialize nodes with JSON.stringify to build lookup keys,
 * so the property order is significant.
 *
 * @param name Node name to parse.
 * @returns The parsed node, or `{ cell: null, item: null }` when `name` does
 *     not contain a valid node name.
 */
function nodeFromString(name: string): Node {
    const match = name.match(/([AIMX])(\d+)_(\d+)/)
    if (match == null) {
        return {
            cell: null,
            item: null,
        }
    }
    const [, type, layerStr, tokenStr] = match

    // A switch avoids rebuilding a lookup table on every call.
    let item: CellItem | null
    switch (type) {
        case 'A':
            item = CellItem.AfterAttn
            break
        case 'I':
            item = CellItem.AfterFfn
            break
        case 'M':
            item = CellItem.Ffn
            break
        case 'X':
            item = CellItem.Original
            break
        default:
            item = null
    }

    return {
        cell: {
            layer: Number(layerStr),
            token: Number(tokenStr),
        },
        item: item,
    }
}
/**
 * Checks that a node lies inside a grid of `nLayers` x `nTokens` cells.
 *
 * @param node Node to check.
 * @param nLayers Number of layers in the grid.
 * @param nTokens Number of tokens in the grid.
 * @returns true when the node has no cell (nothing to validate) or when its
 *     layer and token indices are within `[0, nLayers)` and `[0, nTokens)`.
 */
function isValidNode(node: Node, nLayers: number, nTokens: number): boolean {
    const cell = node.cell
    if (cell == null) {
        return true
    }
    return cell.layer >= 0 && cell.layer < nLayers
        && cell.token >= 0 && cell.token < nTokens
}
/**
 * Checks that everything referenced by a selection still fits in a grid of
 * `nLayers` x `nTokens` cells.
 *
 * Both the selected node and the selected edge are checked. An edge click
 * stores the edge together with its target node, so validating only the node
 * would leave the edge's source unchecked.
 *
 * @param selection Selection to check.
 * @param nLayers Number of layers in the grid.
 * @param nTokens Number of tokens in the grid.
 * @returns true when the selection is empty or all of its nodes are valid.
 */
function isValidSelection(selection: Selection, nLayers: number, nTokens: number): boolean {
    const { node, edge } = selection
    if (node != null && !isValidNode(node, nLayers, nTokens)) {
        return false
    }
    if (edge != null) {
        return isValidNode(edge.from, nLayers, nTokens)
            && isValidNode(edge.to, nLayers, nTokens)
    }
    return true
}
/** Shared "nothing selected" value; its stable identity avoids needless redraws. */
const emptySelection: Selection = { node: null, edge: null }

/** Stable empty list used when the current start token has no edge list. */
const noEdges: Edge[] = []

/** Stroke colour for regular edges as a function of edge weight. */
const edgeColorScale = d3.scaleLinear(
    [0.0, 0.5, 1.0],
    ['#9eba66', 'darkolivegreen', 'darkolivegreen']
)

/** Stroke colour for edges touching an FFN node as a function of edge weight. */
const ffnEdgeColorScale = d3.scaleLinear(
    [0.0, 0.5, 1.0],
    ['orchid', 'purple', 'purple']
)

/** Stroke width as a function of edge weight. */
const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])

/**
 * Serializes a node into the string used as a key in the coordinate map and
 * the active-node set. Keys only match when objects list `cell` before `item`
 * and `layer` before `token`, which is how every node in this file is built.
 */
const nodeKey = (node: Node): string => JSON.stringify(node)

/**
 * Streamlit component that draws the contribution graph as an SVG.
 *
 * Reads three component arguments:
 *  - `model_info`: object with an `n_layers` field (may be null);
 *  - `tokens`: array of token strings (may be null);
 *  - `edges_per_token`: one list of raw edges per start token (may be null).
 *
 * The grid has one column per token and one row per layer. Clicking a node or
 * a selectable edge stores the selection and reports it to Streamlit through
 * `Streamlit.setComponentValue`; clicking a token selector chooses which
 * token's edge list is displayed.
 */
const ContributionGraph = ({ args }: ComponentProps) => {
    const modelInfo = args['model_info']
    const tokens = args['tokens']
    const edgesRaw: EdgeRaw[][] | null | undefined = args['edges_per_token']

    const nLayers: number = modelInfo?.n_layers ?? 0
    const nTokens: number = tokens?.length ?? 0

    const [selection, setSelection] = useState<Selection>(emptySelection)

    // A selection can outlive the data it points into (e.g. after the prompt
    // gets shorter). Render with an empty selection right away, and reset the
    // stored state / notify Streamlit from an effect so rendering stays pure.
    const isSelectionValid = isValidSelection(selection, nLayers, nTokens)
    const curSelection = isSelectionValid ? selection : emptySelection
    useEffect(() => {
        if (!isSelectionValid) {
            setSelection(emptySelection)
            Streamlit.setComponentValue(emptySelection)
        }
    }, [isSelectionValid])

    const [startToken, setStartToken] = useState<number>(nTokens - 1)

    // The stored start token can be out of range in both directions: too large
    // after the token list shrinks, or -1 when the component was first
    // rendered without tokens. Fall back to the last token in either case and
    // use the corrected value for this render; the state is synced afterwards.
    const isStartTokenInRange = startToken >= 0 && startToken < nTokens
    const curStartToken = isStartTokenInRange ? startToken : nTokens - 1
    useEffect(() => {
        if (curStartToken !== startToken) {
            setStartToken(curStartToken)
        }
    }, [curStartToken, startToken])

    // x maps token positions and y maps layer indices to SVG pixels; the y
    // range is reversed so that layer 0 is drawn at the bottom.
    const [xScale, yScale] = useMemo(() => {
        const x = d3.scaleLinear()
            .domain([-2, nTokens - 1])
            .range([0, renderParams.cellW * (nTokens + 2)])
        const y = d3.scaleLinear()
            .domain([-1, nLayers])
            .range([renderParams.cellH * (nLayers + 2), 0])
        return [x, y]
    }, [nLayers, nTokens])

    // Position of every drawable node, keyed by its serialized form. The
    // insertion order (layer-major, then token, then item) is also the order
    // in which nodes are later drawn.
    const nodeCoords = useMemo(() => {
        const result = new Map<string, Point>()
        const w = renderParams.cellW
        const h = renderParams.cellH
        for (let layer = 0; layer < nLayers; layer++) {
            for (let token = 0; token < nTokens; token++) {
                const cell: Cell = { layer: layer, token: token }
                const cx = xScale(token + 0.5)
                const cy = yScale(layer - 0.5)
                result.set(
                    nodeKey({ cell: cell, item: CellItem.AfterAttn }),
                    { x: cx, y: cy + h / 4 },
                )
                result.set(
                    nodeKey({ cell: cell, item: CellItem.AfterFfn }),
                    { x: cx, y: cy - h / 4 },
                )
                result.set(
                    nodeKey({ cell: cell, item: CellItem.Ffn }),
                    { x: cx + 5 * w / 16, y: cy },
                )
            }
        }
        // Original representations sit in an extra row below layer 0.
        for (let token = 0; token < nTokens; token++) {
            const cell: Cell = { layer: 0, token: token }
            const cx = xScale(token + 0.5)
            const cy = yScale(-1.0)
            result.set(
                nodeKey({ cell: cell, item: CellItem.Original }),
                { x: cx, y: cy + h / 4 },
            )
        }
        return result
    }, [nLayers, nTokens, xScale, yScale])

    // Parsed edges, one list per start token. Endpoints without known
    // coordinates are placed at the SVG origin.
    const edges: Edge[][] = useMemo(() => {
        return (edgesRaw ?? []).map((edgeList) => edgeList.map((edge): Edge => {
            const u = nodeFromString(edge.source)
            const v = nodeFromString(edge.target)
            const bothParsed = u.cell !== null && v.cell !== null
            return {
                weight: edge.weight,
                from: u,
                to: v,
                fromPos: nodeCoords.get(nodeKey(u)) ?? { x: 0, y: 0 },
                toPos: nodeCoords.get(nodeKey(v)) ?? { x: 0, y: 0 },
                isSelectable: bothParsed && v.item === CellItem.AfterAttn,
                isFfn: bothParsed
                    && (u.item === CellItem.Ffn || v.item === CellItem.Ffn),
            }
        }))
    }, [edgesRaw, nodeCoords])

    // Edges shown for the current start token; empty when there is no list
    // for it (no tokens at all, or fewer edge lists than tokens).
    const activeEdges = edges[curStartToken] ?? noEdges

    const activeNodes = useMemo(() => {
        const result = new Set<string>()
        for (const edge of activeEdges) {
            result.add(nodeKey(edge.from))
            result.add(nodeKey(edge.to))
        }
        return result
    }, [activeEdges])

    const nodeProps = useMemo(() => {
        const result: Array<NodeProps> = []
        nodeCoords.forEach((p: Point, key: string) => {
            result.push({
                node: JSON.parse(key),
                pos: p,
                isActive: activeNodes.has(key),
            })
        })
        return result
    }, [nodeCoords, activeNodes])

    const tokenLabels: Label[] = useMemo(() => {
        if (!tokens) {
            return []
        }
        return tokens.map((s: string, i: number) => ({
            // Show spaces as a middle dot (U+00B7) so they stay visible.
            text: s.replace(/ /g, '\u00b7'),
            pos: {
                x: xScale(i + 0.5),
                y: yScale(-1.5),
            },
        }))
    }, [tokens, xScale, yScale])

    const layerLabels: Label[] = useMemo(() => {
        return Array.from({ length: nLayers }, (_, i): Label => ({
            text: 'L' + i,
            pos: {
                x: xScale(-0.25),
                y: yScale(i - 0.5),
            },
        }))
    }, [nLayers, xScale, yScale])

    // One (token index, position) pair per token, placed above the top layer.
    const tokenSelectors: Array<[number, Point]> = useMemo(() => {
        return Array.from({ length: nTokens }, (_, i): [number, Point] => [
            i,
            {
                x: xScale(i + 0.5),
                y: yScale(nLayers - 0.5),
            },
        ])
    }, [nTokens, nLayers, xScale, yScale])

    const totalW = xScale(nTokens + 2)
    const totalH = yScale(-4)
    useEffect(() => {
        Streamlit.setFrameHeight(totalH)
    }, [totalH])

    const svgRef = useRef<SVGSVGElement>(null)

    // Redraws the whole SVG from scratch whenever anything it shows changes.
    useEffect(() => {
        // Click handlers live inside the effect: they are only used here and
        // depend only on stable state setters.
        const handleRepresentationClick = (node: Node) => {
            const newSelection: Selection = {
                node: node,
                edge: null,
            }
            setSelection(newSelection)
            Streamlit.setComponentValue(newSelection)
        }

        const handleEdgeClick = (edge: Edge) => {
            if (!edge.isSelectable) {
                return
            }
            const newSelection: Selection = {
                node: edge.to,
                edge: edge,
            }
            setSelection(newSelection)
            Streamlit.setComponentValue(newSelection)
        }

        const handleTokenClick = (t: number) => {
            setStartToken(t)
        }

        const getNodeStyle = (p: NodeProps, type: string) => {
            if (isSameNode(p.node, curSelection.node)) {
                return 'selectable-item selection'
            }
            if (p.isActive) {
                return 'selectable-item active-' + type + '-node'
            }
            return 'selectable-item inactive-node'
        }

        const svg = d3.select(svgRef.current)
        svg.selectAll('*').remove()

        // Background stripes behind every odd layer.
        svg
            .selectAll('layers')
            .data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
            .enter()
            .append('rect')
            .attr('class', 'layer-highlight')
            .attr('x', xScale(-1.0))
            .attr('y', (layer) => yScale(layer))
            .attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
            .attr('height', (layer) => yScale(layer) - yScale(layer + 1))
            .attr('rx', renderParams.layerCornerRadius)

        // Edges for the current start token; the selected edge is orange.
        svg
            .selectAll('edges')
            .data(activeEdges)
            .enter()
            .append('line')
            .style('stroke', (edge: Edge) => {
                if (isSameEdge(edge, curSelection.edge)) {
                    return 'orange'
                }
                if (edge.isFfn) {
                    return ffnEdgeColorScale(edge.weight)
                }
                return edgeColorScale(edge.weight)
            })
            .attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
            .style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
            .attr('x1', (edge: Edge) => edge.fromPos.x)
            .attr('y1', (edge: Edge) => edge.fromPos.y)
            .attr('x2', (edge: Edge) => edge.toPos.x)
            .attr('y2', (edge: Edge) => edge.toPos.y)
            .on('click', (event: PointerEvent, edge) => {
                handleEdgeClick(edge)
            })

        // Residual-stream nodes (after attention / after FFN) as circles.
        svg
            .selectAll('residual')
            .data(nodeProps)
            .enter()
            .filter((p) => {
                return p.node.item === CellItem.AfterAttn
                    || p.node.item === CellItem.AfterFfn
            })
            .append('circle')
            .attr('class', (p) => getNodeStyle(p, 'residual'))
            .attr('cx', (p) => p.pos.x)
            .attr('cy', (p) => p.pos.y)
            .attr('r', renderParams.attnSize / 2)
            .on('click', (event: PointerEvent, p) => {
                handleRepresentationClick(p.node)
            })

        // FFN nodes as squares; only active ones are drawn.
        svg
            .selectAll('ffn')
            .data(nodeProps)
            .enter()
            .filter((p) => p.node.item === CellItem.Ffn && p.isActive)
            .append('rect')
            .attr('class', (p) => getNodeStyle(p, 'ffn'))
            .attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
            .attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
            .attr('width', renderParams.ffnSize)
            .attr('height', renderParams.ffnSize)
            .on('click', (event: PointerEvent, p) => {
                handleRepresentationClick(p.node)
            })

        // Token texts, rotated so that neighbouring labels do not overlap.
        svg
            .selectAll('token_labels')
            .data(tokenLabels)
            .enter()
            .append('text')
            .attr('x', (label: Label) => label.pos.x)
            .attr('y', (label: Label) => label.pos.y)
            .attr('text-anchor', 'end')
            .attr('dominant-baseline', 'middle')
            .attr('alignment-baseline', 'top')
            .attr('transform', (label: Label) =>
                'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
            .text((label: Label) => label.text)

        svg
            .selectAll('layer_labels')
            .data(layerLabels)
            .enter()
            .append('text')
            .attr('x', (label: Label) => label.pos.x)
            .attr('y', (label: Label) => label.pos.y)
            .attr('text-anchor', 'middle')
            .attr('alignment-baseline', 'middle')
            .text((label: Label) => label.text)

        // Token selectors: triangles above the top layer. A polygon is fully
        // described by `points`, so no radius attribute is set.
        svg
            .selectAll('token_selectors')
            .data(tokenSelectors)
            .enter()
            .append('polygon')
            .attr('class', ([i,]) => (
                curStartToken === i
                    ? 'selectable-item selection'
                    : 'selectable-item token-selector'
            ))
            .attr('points', ([, p]) => tokenPointerPolygon(p))
            .on('click', (event: PointerEvent, [i,]) => {
                handleTokenClick(i)
            })
    }, [
        activeEdges,
        nodeProps,
        tokenLabels,
        layerLabels,
        tokenSelectors,
        curStartToken,
        curSelection,
        nLayers,
        nTokens,
        xScale,
        yScale
    ])

    return <svg ref={svgRef} width={totalW} height={totalH}></svg>
}

export default withStreamlitConnection(ContributionGraph)