File size: 4,010 Bytes
63858e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import * as _ from 'lodash'
import * as x_ from '../etc/_Tools'
import * as tp from '../etc/types'
import * as tf from '@tensorflow/tfjs'

/**
 * Notes:
 * 
 * - Also encapsulate the CLS/SEP info vs. no CLS/SEP info
 * - When layer format changes from list, drop the index into conf.layer
 */

// Token texts treated as "bad" (special) tokens when locating rows/columns to zero
const bpeTokens = ["[CLS]", "[SEP]", "<s>", "</s>", "<|endoftext|>"]

/**
 * Hands the `text` of every token in `x` to `x_.findAllIndexes` together with a
 * predicate that matches any entry of `bpeTokens`. The result is used in this
 * file as the list of indexes to zero (presumably the matching positions --
 * confirm against `x_.findAllIndexes`).
 */
const findBadIndexes = (x: tp.FullSingleTokenInfo[]) => {
    const tokenTexts = x.map(tok => tok.text)
    return x_.findAllIndexes(tokenTexts, (txt) => _.includes(bpeTokens, txt))
}

/**
 * Build an AttentionWrapper from a backend attention response.
 *
 * Reads the pair stored under the 'aa' key of `r`, looks up the special-token
 * indexes of its `left` and `right` token lists with `findBadIndexes` (the same
 * helper `AttentionWrapper.updateFromNormal` uses), and passes them with the
 * pair's `att` to the AttentionWrapper constructor.
 *
 * @param r Response holding the attention pair under key 'aa'
 * @param isZeroed Forwarded unchanged to the AttentionWrapper constructor
 * @returns A new AttentionWrapper for the pair
 */
export function makeFromMetaResponse(r:tp.AttentionResponse, isZeroed){
    const key = 'aa' // Change this if backend response changes to be simpler
    const currPair = r[key]
    const left = <tp.FullSingleTokenInfo[]>currPair.left
    const right = <tp.FullSingleTokenInfo[]>currPair.right

    // One shared helper decides what counts as a special token, so the token
    // list only has to change in one place.
    const leftZero = findBadIndexes(left)
    const rightZero = findBadIndexes(right)
    return new AttentionWrapper(currPair.att, [leftZero, rightZero], isZeroed)
}

/**
 * Wraps one attention array as two tensors -- the raw values, and a variant in
 * which the rows/columns listed in `badToks` are zeroed -- and exposes whichever
 * one `isZeroed` currently selects.
 *
 * The methods below index the first tensor axis by head.
 */
export class AttentionWrapper {
    protected _att:number[][][]
    protected _attTensor:tf.Tensor3D
    protected _zeroedAttTensor:tf.Tensor3D

    badToks:[number[], number[]] // [row indexes, column indexes] to zero (the CLS and SEP tokens)
    isZeroed: boolean
    nLayers = 12;
    nHeads = 12;

    /**
     * @param att Nested attention array, converted with `tf.tensor3d`
     * @param badToks [row indexes, column indexes] to zero in the zeroed variant
     * @param isZeroed Whether the zeroed variant is the one exposed (default true)
     */
    constructor(att:number[][][], badToks:[number[], number[]]=[[],[]], isZeroed=true){
        this.init(att, badToks, isZeroed)
    }

    /**
     * (Re)build both tensors from `att` and store the new settings.
     *
     * Tensors held from a previous `init` are disposed once their replacements
     * exist, so repeated updates do not accumulate tensor memory. Any tensor
     * previously obtained through `attTensor` must not be used after this call.
     */
    init(att:number[][][], badToks:[number[], number[]]=[[],[]], isZeroed) {
        const staleRaw = this._attTensor
        const staleZeroed = this._zeroedAttTensor

        // The zeroed variant gets its own source tensor, created before the raw
        // one, so that buffer edits made by zeroRowCol cannot reach the raw tensor.
        const scratch = tf.tensor3d(att)
        const zeroedTensor = zeroRowCol(scratch, badToks[0], badToks[1])
        if (zeroedTensor !== scratch) scratch.dispose()
        const rawTensor = tf.tensor3d(att)

        this.isZeroed = isZeroed
        this._att = att;
        this._zeroedAttTensor = zeroedTensor
        this._attTensor = rawTensor
        this.badToks = badToks;

        // Release the previous generation only after the new state is in place
        if (staleRaw) staleRaw.dispose()
        if (staleZeroed) staleZeroed.dispose()
    }

    /**
     * Replace the wrapped attention with the pair stored under key 'aa' of `r`,
     * zeroing the indexes `findBadIndexes` reports for its left/right tokens.
     */
    updateFromNormal(r:tp.AttentionResponse, isZeroed){
        const key = 'aa' // Change this if backend response changes to be simpler
        const currPair = r[key]
        const left = <tp.FullSingleTokenInfo[]>currPair.left
        const right = <tp.FullSingleTokenInfo[]>currPair.right

        const leftZero = findBadIndexes(left)
        const rightZero = findBadIndexes(right)
        this.init(currPair.att, [leftZero, rightZero], isZeroed)
    }

    /** The tensor selected by `isZeroed`. Owned by this wrapper; do not dispose it. */
    get attTensor() {
        const tens = this.isZeroed ? this._zeroedAttTensor : this._attTensor
        return tens
    }

    /** The currently selected tensor as a nested array. */
    get att() {
        return this.attTensor.arraySync()
    }

    /** With no argument returns `isZeroed`; with a value sets it and returns `this`. */
    zeroed(): boolean
    zeroed(val:boolean): this
    zeroed(val?) {
        if (val == null) return this.isZeroed
        this.isZeroed = val
        return this
    }

    /** Flip `isZeroed`. */
    toggleZeroing() {
        this.zeroed(!this.zeroed())
    }

    /**
     * Sum of the selected tensor over the given heads (axis 0).
     * For an empty `heads` list returns zeros shaped like a single head.
     * The caller owns the returned tensor and should dispose it.
     */
    protected _byHeads(heads:number[]):tf.Tensor2D {
        // tidy releases the gather / single-head intermediates, keeping only the result
        return tf.tidy(() => {
            if (heads.length === 0) {
                return tf.zerosLike(this._byHead(0))
            }

            return (<tf.Tensor2D>this.attTensor.gather(heads, 0).sum(0))
        })
    }

    /**
     * The slice of the selected tensor for one head, with the head axis removed.
     * The caller owns the returned tensor and should dispose it.
     */
    protected _byHead(head:number):tf.Tensor2D {
        return tf.tidy(() => (<tf.Tensor2D>this.attTensor.gather([head], 0).squeeze([0])))
    }

    /** Nested-array form of `_byHeads(heads)`. */
    byHeads(heads:number[]):number[][] {
        const summed = this._byHeads(heads)
        try {
            return summed.arraySync()
        } finally {
            summed.dispose()
        }
    }

    /** Nested-array form of `_byHead(head)`. */
    byHead(head:number):number[][] {
        const single = this._byHead(head)
        try {
            return single.arraySync()
        } finally {
            single.dispose()
        }
    }
}

/**
 * Produce a tensor shaped like `tens` in which, for every index along the first
 * axis, each entry whose second-axis index is in `rows` or whose third-axis
 * index is in `cols` is set to 0. All other entries keep their value.
 *
 * NOTE(review): the buffer from `bufferSync()` may share storage with `tens` on
 * some backends, so `tens` should be treated as consumed by this call -- confirm
 * before reusing the input afterwards.
 *
 * @param tens Source tensor
 * @param rows Second-axis indexes to zero; indexes outside the shape are ignored
 * @param cols Third-axis indexes to zero; indexes outside the shape are ignored
 * @returns A tensor built from the edited buffer
 */
function zeroRowCol(tens:tf.Tensor3D, rows:number[], cols:number[]):tf.Tensor3D {
    const atb = tens.bufferSync()
    const [nHeads, nRows, nCols] = atb.shape

    // Column membership is the same for every head and row, so resolve it once
    const colIsBad = _.range(nCols).map((j) => _.includes(cols, j))

    for (let head = 0; head < nHeads; head++) {
        for (let i = 0; i < nRows; i++) {
            const rowIsBad = _.includes(rows, i)
            for (let j = 0; j < nCols; j++) {
                // A cell is cleared when it lies in a bad row OR a bad column
                if (rowIsBad || colIsBad[j]) {
                    atb.set(0, head, i, j)
                }
            }
        }
    }

    // Build the result from the edited buffer instead of assuming the edits
    // are visible through a tensor that was created before them.
    return atb.toTensor()
}