Initial Commit
@@ -0,0 +1,80 @@
1 |
# Code of Conduct
2 |
3 |
## Our Pledge
4 |
5 |
In the interest of fostering an open and welcoming environment, we as
6 |
contributors and maintainers pledge to make participation in our project and
7 |
our community a harassment-free experience for everyone, regardless of age, body
8 |
size, disability, ethnicity, sex characteristics, gender identity and expression,
9 |
level of experience, education, socio-economic status, nationality, personal
10 |
appearance, race, religion, or sexual identity and orientation.
11 |
12 |
## Our Standards
13 |
14 |
Examples of behavior that contributes to creating a positive environment include:
15 |
16 |
17 |
* Using welcoming and inclusive language
18 |
* Being respectful of differing viewpoints and experiences
19 |
* Gracefully accepting constructive criticism
20 |
* Focusing on what is best for the community
21 |
* Showing empathy towards other community members
22 |
23 |
Examples of unacceptable behavior by participants include:
24 |
25 |
* The use of sexualized language or imagery and unwelcome sexual attention or advances
26 |
27 |
* Trolling, insulting/derogatory comments, and personal or political attacks
28 |
* Public or private harassment
29 |
* Publishing others' private information, such as a physical or electronic
30 |
address, without explicit permission
31 |
* Other conduct which could reasonably be considered inappropriate in a
32 |
professional setting
33 |
34 |
## Our Responsibilities
35 |
36 |
Project maintainers are responsible for clarifying the standards of acceptable
37 |
behavior and are expected to take appropriate and fair corrective action in
38 |
response to any instances of unacceptable behavior.
39 |
40 |
Project maintainers have the right and responsibility to remove, edit, or
41 |
reject comments, commits, code, wiki edits, issues, and other contributions
42 |
that are not aligned to this Code of Conduct, or to ban temporarily or
43 |
permanently any contributor for other behaviors that they deem inappropriate,
44 |
threatening, offensive, or harmful.
45 |
46 |
## Scope
47 |
48 |
This Code of Conduct applies within all project spaces, and it also applies when
49 |
an individual is representing the project or its community in public spaces.
50 |
Examples of representing a project or community include using an official
51 |
project e-mail address, posting via an official social media account, or acting
52 |
as an appointed representative at an online or offline event. Representation of
53 |
a project may be further defined and clarified by project maintainers.
54 |
55 |
This Code of Conduct also applies outside the project spaces when there is a
56 |
reasonable belief that an individual's behavior may have a negative impact on
57 |
the project or its community.
58 |
59 |
## Enforcement
60 |
61 |
Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 |
reported by contacting the project team at <[email protected]>. All
63 |
complaints will be reviewed and investigated and will result in a response that
64 |
is deemed necessary and appropriate to the circumstances. The project team is
65 |
obligated to maintain confidentiality with regard to the reporter of an incident.
66 |
Further details of specific enforcement policies may be posted separately.
67 |
68 |
Project maintainers who do not follow or enforce the Code of Conduct in good
69 |
faith may face temporary or permanent repercussions as determined by other
70 |
members of the project's leadership.
71 |
72 |
## Attribution
73 |
74 |
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 |
available at
76 |
77 |
78 |
79 |
For answers to common questions about this code of conduct, see
80 |
@@ -0,0 +1,399 @@
1 |
Attribution-NonCommercial 4.0 International
2 |
3 |
4 |
5 |
Creative Commons Corporation ("Creative Commons") is not a law firm and
6 |
does not provide legal services or legal advice. Distribution of
7 |
Creative Commons public licenses does not create a lawyer-client or
8 |
other relationship. Creative Commons makes its licenses and related
9 |
information available on an "as-is" basis. Creative Commons gives no
10 |
warranties regarding its licenses, any material licensed under their
11 |
terms and conditions, or any related information. Creative Commons
12 |
disclaims all liability for damages resulting from their use to the
13 |
fullest extent possible.
14 |
15 |
Using Creative Commons Public Licenses
16 |
17 |
Creative Commons public licenses provide a standard set of terms and
18 |
conditions that creators and other rights holders may use to share
19 |
original works of authorship and other material subject to copyright
20 |
and certain other rights specified in the public license below. The
21 |
following considerations are for informational purposes only, are not
22 |
exhaustive, and do not form part of our licenses.
23 |
24 |
Considerations for licensors: Our public licenses are
25 |
intended for use by those authorized to give the public
26 |
permission to use material in ways otherwise restricted by
27 |
copyright and certain other rights. Our licenses are
28 |
irrevocable. Licensors should read and understand the terms
29 |
and conditions of the license they choose before applying it.
30 |
Licensors should also secure all rights necessary before
31 |
applying our licenses so that the public can reuse the
32 |
material as expected. Licensors should clearly mark any
33 |
material not subject to the license. This includes other CC-
34 |
licensed material, or material used under an exception or
35 |
limitation to copyright. More considerations for licensors:
36 |
37 |
38 |
Considerations for the public: By using one of our public
39 |
licenses, a licensor grants the public permission to use the
40 |
licensed material under specified terms and conditions. If
41 |
the licensor's permission is not necessary for any reason--for
42 |
example, because of any applicable exception or limitation to
43 |
copyright--then that use is not regulated by the license. Our
44 |
licenses grant only permissions under copyright and certain
45 |
other rights that a licensor has authority to grant. Use of
46 |
the licensed material may still be restricted for other
47 |
reasons, including because others have copyright or other
48 |
rights in the material. A licensor may make special requests,
49 |
such as asking that all changes be marked or described.
50 |
Although not required by our licenses, you are encouraged to
51 |
respect those requests where reasonable. More_considerations
52 |
for the public:
53 |
54 |
55 |
56 |
57 |
Creative Commons Attribution-NonCommercial 4.0 International Public License
58 |
59 |
60 |
By exercising the Licensed Rights (defined below), You accept and agree
61 |
to be bound by the terms and conditions of this Creative Commons
62 |
Attribution-NonCommercial 4.0 International Public License ("Public
63 |
License"). To the extent this Public License may be interpreted as a
64 |
contract, You are granted the Licensed Rights in consideration of Your
65 |
acceptance of these terms and conditions, and the Licensor grants You
66 |
such rights in consideration of benefits the Licensor receives from
67 |
making the Licensed Material available under these terms and conditions.
68 |
69 |
70 |
Section 1 -- Definitions.
71 |
72 |
a. Adapted Material means material subject to Copyright and Similar
73 |
Rights that is derived from or based upon the Licensed Material
74 |
and in which the Licensed Material is translated, altered,
75 |
arranged, transformed, or otherwise modified in a manner requiring
76 |
permission under the Copyright and Similar Rights held by the
77 |
Licensor. For purposes of this Public License, where the Licensed
78 |
Material is a musical work, performance, or sound recording,
79 |
Adapted Material is always produced where the Licensed Material is
80 |
synched in timed relation with a moving image.
81 |
82 |
b. Adapter's License means the license You apply to Your Copyright
83 |
and Similar Rights in Your contributions to Adapted Material in
84 |
accordance with the terms and conditions of this Public License.
85 |
86 |
c. Copyright and Similar Rights means copyright and/or similar rights
87 |
closely related to copyright including, without limitation,
88 |
performance, broadcast, sound recording, and Sui Generis Database
89 |
Rights, without regard to how the rights are labeled or
90 |
categorized. For purposes of this Public License, the rights
91 |
specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
92 |
93 |
d. Effective Technological Measures means those measures that, in the
94 |
absence of proper authority, may not be circumvented under laws
95 |
fulfilling obligations under Article 11 of the WIPO Copyright
96 |
Treaty adopted on December 20, 1996, and/or similar international agreements.
97 |
98 |
99 |
e. Exceptions and Limitations means fair use, fair dealing, and/or
100 |
any other exception or limitation to Copyright and Similar Rights
101 |
that applies to Your use of the Licensed Material.
102 |
103 |
f. Licensed Material means the artistic or literary work, database,
104 |
or other material to which the Licensor applied this Public License.
105 |
106 |
107 |
g. Licensed Rights means the rights granted to You subject to the
108 |
terms and conditions of this Public License, which are limited to
109 |
all Copyright and Similar Rights that apply to Your use of the
110 |
Licensed Material and that the Licensor has authority to license.
111 |
112 |
h. Licensor means the individual(s) or entity(ies) granting rights
113 |
under this Public License.
114 |
115 |
i. NonCommercial means not primarily intended for or directed towards
116 |
commercial advantage or monetary compensation. For purposes of
117 |
this Public License, the exchange of the Licensed Material for
118 |
other material subject to Copyright and Similar Rights by digital
119 |
file-sharing or similar means is NonCommercial provided there is
120 |
no payment of monetary compensation in connection with the exchange.
121 |
122 |
123 |
j. Share means to provide material to the public by any means or
124 |
process that requires permission under the Licensed Rights, such
125 |
as reproduction, public display, public performance, distribution,
126 |
dissemination, communication, or importation, and to make material
127 |
available to the public including in ways that members of the
128 |
public may access the material from a place and at a time
129 |
individually chosen by them.
130 |
131 |
k. Sui Generis Database Rights means rights other than copyright
132 |
resulting from Directive 96/9/EC of the European Parliament and of
133 |
the Council of 11 March 1996 on the legal protection of databases,
134 |
as amended and/or succeeded, as well as other essentially
135 |
equivalent rights anywhere in the world.
136 |
137 |
l. You means the individual or entity exercising the Licensed Rights
138 |
under this Public License. Your has a corresponding meaning.
139 |
140 |
Section 2 -- Scope.
141 |
142 |
a. License grant.
143 |
144 |
1. Subject to the terms and conditions of this Public License,
145 |
the Licensor hereby grants You a worldwide, royalty-free,
146 |
non-sublicensable, non-exclusive, irrevocable license to
147 |
exercise the Licensed Rights in the Licensed Material to:
148 |
149 |
a. reproduce and Share the Licensed Material, in whole or
150 |
in part, for NonCommercial purposes only; and
151 |
152 |
b. produce, reproduce, and Share Adapted Material for
153 |
NonCommercial purposes only.
154 |
155 |
2. Exceptions and Limitations. For the avoidance of doubt, where
156 |
Exceptions and Limitations apply to Your use, this Public
157 |
License does not apply, and You do not need to comply with
158 |
its terms and conditions.
159 |
160 |
3. Term. The term of this Public License is specified in Section 6(a).
161 |
162 |
163 |
4. Media and formats; technical modifications allowed. The
164 |
Licensor authorizes You to exercise the Licensed Rights in
165 |
all media and formats whether now known or hereafter created,
166 |
and to make technical modifications necessary to do so. The
167 |
Licensor waives and/or agrees not to assert any right or
168 |
authority to forbid You from making technical modifications
169 |
necessary to exercise the Licensed Rights, including
170 |
technical modifications necessary to circumvent Effective
171 |
Technological Measures. For purposes of this Public License,
172 |
simply making modifications authorized by this Section 2(a)
173 |
(4) never produces Adapted Material.
174 |
175 |
5. Downstream recipients.
176 |
177 |
a. Offer from the Licensor -- Licensed Material. Every
178 |
recipient of the Licensed Material automatically
179 |
receives an offer from the Licensor to exercise the
180 |
Licensed Rights under the terms and conditions of this
181 |
Public License.
182 |
183 |
b. No downstream restrictions. You may not offer or impose
184 |
any additional or different terms or conditions on, or
185 |
apply any Effective Technological Measures to, the
186 |
Licensed Material if doing so restricts exercise of the
187 |
Licensed Rights by any recipient of the Licensed Material.
188 |
189 |
190 |
6. No endorsement. Nothing in this Public License constitutes or
191 |
may be construed as permission to assert or imply that You
192 |
are, or that Your use of the Licensed Material is, connected
193 |
with, or sponsored, endorsed, or granted official status by,
194 |
the Licensor or others designated to receive attribution as
195 |
provided in Section 3(a)(1)(A)(i).
196 |
197 |
b. Other rights.
198 |
199 |
1. Moral rights, such as the right of integrity, are not
200 |
licensed under this Public License, nor are publicity,
201 |
privacy, and/or other similar personality rights; however, to
202 |
the extent possible, the Licensor waives and/or agrees not to
203 |
assert any such rights held by the Licensor to the limited
204 |
extent necessary to allow You to exercise the Licensed
205 |
Rights, but not otherwise.
206 |
207 |
2. Patent and trademark rights are not licensed under this
208 |
Public License.
209 |
210 |
3. To the extent possible, the Licensor waives any right to
211 |
collect royalties from You for the exercise of the Licensed
212 |
Rights, whether directly or through a collecting society
213 |
under any voluntary or waivable statutory or compulsory
214 |
licensing scheme. In all other cases the Licensor expressly
215 |
reserves any right to collect such royalties, including when
216 |
the Licensed Material is used other than for NonCommercial purposes.
217 |
218 |
219 |
Section 3 -- License Conditions.
220 |
221 |
Your exercise of the Licensed Rights is expressly made subject to the
222 |
following conditions.
223 |
224 |
a. Attribution.
225 |
226 |
1. If You Share the Licensed Material (including in modified
227 |
form), You must:
228 |
229 |
a. retain the following if it is supplied by the Licensor
230 |
with the Licensed Material:
231 |
232 |
i. identification of the creator(s) of the Licensed
233 |
Material and any others designated to receive
234 |
attribution, in any reasonable manner requested by
235 |
the Licensor (including by pseudonym if designated);
236 |
237 |
238 |
ii. a copyright notice;
239 |
240 |
iii. a notice that refers to this Public License;
241 |
242 |
iv. a notice that refers to the disclaimer of warranties;
243 |
244 |
245 |
v. a URI or hyperlink to the Licensed Material to the
246 |
extent reasonably practicable;
247 |
248 |
b. indicate if You modified the Licensed Material and
249 |
retain an indication of any previous modifications; and
250 |
251 |
c. indicate the Licensed Material is licensed under this
252 |
Public License, and include the text of, or the URI or
253 |
hyperlink to, this Public License.
254 |
255 |
2. You may satisfy the conditions in Section 3(a)(1) in any
256 |
reasonable manner based on the medium, means, and context in
257 |
which You Share the Licensed Material. For example, it may be
258 |
reasonable to satisfy the conditions by providing a URI or
259 |
hyperlink to a resource that includes the required information.
260 |
261 |
262 |
3. If requested by the Licensor, You must remove any of the
263 |
information required by Section 3(a)(1)(A) to the extent
264 |
reasonably practicable.
265 |
266 |
4. If You Share Adapted Material You produce, the Adapter's
267 |
License You apply must not prevent recipients of the Adapted
268 |
Material from complying with this Public License.
269 |
270 |
Section 4 -- Sui Generis Database Rights.
271 |
272 |
Where the Licensed Rights include Sui Generis Database Rights that
273 |
apply to Your use of the Licensed Material:
274 |
275 |
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 |
to extract, reuse, reproduce, and Share all or a substantial
277 |
portion of the contents of the database for NonCommercial purposes only;
278 |
279 |
280 |
b. if You include all or a substantial portion of the database
281 |
contents in a database in which You have Sui Generis Database
282 |
Rights, then the database in which You have Sui Generis Database
283 |
Rights (but not its individual contents) is Adapted Material; and
284 |
285 |
c. You must comply with the conditions in Section 3(a) if You Share
286 |
all or a substantial portion of the contents of the database.
287 |
288 |
For the avoidance of doubt, this Section 4 supplements and does not
289 |
replace Your obligations under this Public License where the Licensed
290 |
Rights include other Copyright and Similar Rights.
291 |
292 |
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
c. The disclaimer of warranties and limitation of liability provided
316 |
above shall be interpreted in a manner that, to the extent
317 |
possible, most closely approximates an absolute disclaimer and
318 |
waiver of all liability.
319 |
320 |
Section 6 -- Term and Termination.
321 |
322 |
a. This Public License applies for the term of the Copyright and
323 |
Similar Rights licensed here. However, if You fail to comply with
324 |
this Public License, then Your rights under this Public License
325 |
terminate automatically.
326 |
327 |
b. Where Your right to use the Licensed Material has terminated under
328 |
Section 6(a), it reinstates:
329 |
330 |
1. automatically as of the date the violation is cured, provided
331 |
it is cured within 30 days of Your discovery of the
332 |
violation; or
333 |
334 |
2. upon express reinstatement by the Licensor.
335 |
336 |
For the avoidance of doubt, this Section 6(b) does not affect any
337 |
right the Licensor may have to seek remedies for Your violations
338 |
of this Public License.
339 |
340 |
c. For the avoidance of doubt, the Licensor may also offer the
341 |
Licensed Material under separate terms or conditions or stop
342 |
distributing the Licensed Material at any time; however, doing so
343 |
will not terminate this Public License.
344 |
345 |
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
346 |
347 |
348 |
Section 7 -- Other Terms and Conditions.
349 |
350 |
a. The Licensor shall not be bound by any additional or different
351 |
terms or conditions communicated by You unless expressly agreed.
352 |
353 |
b. Any arrangements, understandings, or agreements regarding the
354 |
Licensed Material not stated herein are separate from and
355 |
independent of the terms and conditions of this Public License.
356 |
357 |
Section 8 -- Interpretation.
358 |
359 |
a. For the avoidance of doubt, this Public License does not, and
360 |
shall not be interpreted to, reduce, limit, restrict, or impose
361 |
conditions on any use of the Licensed Material that could lawfully
362 |
be made without permission under this Public License.
363 |
364 |
b. To the extent possible, if any provision of this Public License is
365 |
deemed unenforceable, it shall be automatically reformed to the
366 |
minimum extent necessary to make it enforceable. If the provision
367 |
cannot be reformed, it shall be severed from this Public License
368 |
without affecting the enforceability of the remaining terms and conditions.
369 |
370 |
371 |
c. No term or condition of this Public License will be waived and no
372 |
failure to comply consented to unless expressly agreed to by the Licensor.
373 |
374 |
375 |
d. Nothing in this Public License constitutes or may be interpreted
376 |
as a limitation upon, or waiver of, any privileges and immunities
377 |
that apply to the Licensor or You, including from the legal
378 |
processes of any jurisdiction or authority.
379 |
380 |
381 |
382 |
Creative Commons is not a party to its public
383 |
licenses. Notwithstanding, Creative Commons may elect to apply one of
384 |
its public licenses to material it publishes and in those instances
385 |
will be considered the "Licensor." The text of the Creative Commons
386 |
public licenses is dedicated to the public domain under the CC0 Public
387 |
Domain Dedication. Except for the limited purpose of indicating that
388 |
material is shared under a Creative Commons public license or as
389 |
otherwise permitted by the Creative Commons policies published at
390 |
+, Creative Commons does not authorize the
391 |
use of the trademark "Creative Commons" or any other trademark or logo
392 |
of Creative Commons without its prior written consent including,
393 |
without limitation, in connection with any unauthorized modifications
394 |
to any of its public licenses or any other arrangements,
395 |
understandings, or agreements concerning use of licensed material. For
396 |
the avoidance of doubt, this paragraph does not form part of the
397 |
public licenses.
398 |
399 |
Creative Commons may be contacted at
@@ -0,0 +1,88 @@
1 |
2 |
<img width="500" alt="LLM Transparency Tool" src="">
3 |
4 |
5 |
<img width="832" alt="screenshot" src="">
6 |
7 |
8 |
## Key functionality
9 |
10 |
* Choose your model, choose or add your prompt, run the inference.
11 |
* Browse contribution graph.
12 |
* Select the token to build the graph from.
13 |
* Tune the contribution threshold.
14 |
* Select representation of any token after any block.
15 |
* For the representation, see its projection to the output vocabulary, see which tokens
16 |
were promoted/suppressed by the previous block.
17 |
* The following things are clickable:
18 |
* Edges. That shows more info about the contributing attention head.
19 |
* Heads when an edge is selected. You can see what this head is promoting/suppressing.
20 |
* FFN blocks (little squares on the graph).
21 |
* Neurons when an FFN block is selected.
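
To make the "select a token" and "tune the contribution threshold" steps more concrete, here is a minimal sketch of how such pruning could work. It is an illustration under stated assumptions, not the tool's actual implementation: the edge tuple layout, the function name `prune_graph`, and the node names in the usage example are all hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

# ASSUMPTION: an edge is (source, target, contribution weight).
Edge = Tuple[str, str, float]


def prune_graph(edges: List[Edge], start: str, threshold: float) -> List[Edge]:
    """Keep edges with weight >= threshold that lead (backwards) to `start`."""
    # Drop weak contributions first.
    strong = [e for e in edges if e[2] >= threshold]

    # Index the surviving edges by target so the graph can be walked backwards.
    incoming: Dict[str, List[Edge]] = defaultdict(list)
    for edge in strong:
        incoming[edge[1]].append(edge)

    kept: List[Edge] = []
    seen: Set[str] = {start}
    stack = [start]
    while stack:
        node = stack.pop()
        for edge in incoming[node]:
            kept.append(edge)
            if edge[0] not in seen:
                seen.add(edge[0])
                stack.append(edge[0])
    return kept


if __name__ == "__main__":
    # Hypothetical node names, used only for illustration.
    example = [
        ("X0_0", "A0_1", 0.6),
        ("X0_1", "A0_1", 0.05),  # below the threshold, dropped
        ("A0_1", "I0_1", 0.9),
        ("X0_0", "A0_0", 0.8),   # strong, but does not lead to the start node
    ]
    for source, target, weight in prune_graph(example, "I0_1", 0.1):
        print(f"{source} -> {target} ({weight})")
```

Raising the threshold in this sketch removes more edges and can disconnect whole branches, which is the intuition behind tuning it in the UI.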
22 |
23 |
24 |
## Installation
25 |
26 |
### Dockerized running
27 |
28 |
# From the repository root directory
29 |
docker build -t llm_transparency_tool .
30 |
docker run --rm -p 7860:7860 llm_transparency_tool
31 |
32 |
33 |
### Local Installation
34 |
35 |
36 |
37 |
# download
38 |
git clone [email protected]:facebookresearch/llm-transparency-tool.git
39 |
cd llm-transparency-tool
40 |
41 |
# install the necessary packages
42 |
conda env create --name llmtt -f env.yaml
43 |
# install the `llm_transparency_tool` package
44 |
pip install -e .
45 |
46 |
# now, we need to build the frontend
47 |
# don't worry, even `yarn` comes preinstalled by `env.yaml`
48 |
cd llm_transparency_tool/components/frontend
49 |
yarn install
50 |
yarn build
51 |
52 |
53 |
### Launch
54 |
55 |
56 |
streamlit run llm_transparency_tool/server/ -- config/local.json
57 |
58 |
59 |
60 |
## Adding support for your LLM
61 |
62 |
Initially, the tool allows you to select from just a handful of models. Here are the
63 |
options you can try for using your model in the tool, from least to most
64 |
65 |
66 |
67 |
### The model is already supported by TransformerLens
68 |
69 |
Full list of models is [here](
70 |
In this case, the model can be added to the configuration json file.
71 |
72 |
73 |
### Tuned version of a model supported by TransformerLens
74 |
75 |
Add the official name of the model to the config along with the location to read the
76 |
weights from.
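
The two configuration-based options above can be sketched as a small helper that edits a config dictionary. This is an illustration only: the `models` and `default_model` keys come from the config files in this repository, but the helper name `register_model`, the weights path, and the assumption that the value next to a model name is either `null` (stock weights) or a string location for tuned weights are not confirmed by the source.

```python
import json
from typing import Optional


def register_model(
    config: dict, official_name: str, weights_path: Optional[str] = None
) -> dict:
    """Return a copy of `config` with `official_name` listed under "models".

    ASSUMPTION: the value stored next to a model name is either null (load the
    stock weights) or a string with the location of tuned weights.
    """
    updated = dict(config)
    models = dict(updated.get("models") or {})
    models[official_name] = weights_path
    updated["models"] = models
    return updated


base = json.loads('{"models": {"gpt2": null}, "default_model": "gpt2"}')

# Case 1: a model that TransformerLens already supports.
with_stock = register_model(base, "EleutherAI/gpt-neo-125M")

# Case 2: a tuned version; the path below is a hypothetical placeholder.
with_tuned = register_model(
    with_stock, "meta-llama/Llama-2-7b-hf", "/hypothetical/path/to/tuned-weights"
)

print(json.dumps(with_tuned, indent=4))
```

The helper returns a new dictionary instead of mutating its input, so the original config stays untouched until you decide to write the result back to a json file.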
77 |
78 |
79 |
### The model is not supported by TransformerLens
80 |
81 |
In this case the UI wouldn't know how to create proper hooks for the model. You'd need
82 |
to implement your version of [TransparentLlm](./llm_transparency_tool/models/ class and alter the
83 |
Streamlit app to use your implementation.
84 |
85 |
86 |
## License
87 |
This code is made available under a [CC BY-NC 4.0]( license, as found in the LICENSE file.
88 |
However, you may have other legal obligations that govern your use of other content, such as the terms of service for third-party models.
@@ -0,0 +1,13 @@
1 |
2 |
"allow_loading_dataset_files": false,
3 |
"max_user_string_length": 100,
4 |
"preloaded_dataset_filename": "sample_input.txt",
5 |
"debug": false,
6 |
"demo_mode": true,
7 |
"models": {
8 |
"facebook/opt-125m": null,
9 |
"gpt2": null,
10 |
"distilgpt2": null
11 |
12 |
"default_model": "gpt2"
13 |
@@ -0,0 +1,25 @@
1 |
2 |
"allow_loading_dataset_files": true,
3 |
"preloaded_dataset_filename": "sample_input.txt",
4 |
"debug": true,
5 |
"models": {
6 |
"": null,
7 |
"facebook/opt-125m": null,
8 |
"facebook/opt-1.3b": null,
9 |
"facebook/opt-2.7b": null,
10 |
"facebook/opt-6.7b": null,
11 |
"facebook/opt-13b": null,
12 |
"facebook/opt-30b": null,
13 |
"meta-llama/Llama-2-7b-hf": null,
14 |
"meta-llama/Llama-2-7b-chat-hf": null,
15 |
"meta-llama/Llama-2-13b-hf": null,
16 |
"meta-llama/Llama-2-13b-chat-hf": null,
17 |
"gpt2": null,
18 |
"gpt2-medium": null,
19 |
"gpt2-large": null,
20 |
"gpt2-xl": null,
21 |
"distilgpt2": null
22 |
23 |
"default_model": "distilgpt2",
24 |
"demo_mode": false
25 |
@@ -0,0 +1,47 @@
1 |
2 |
"allow_loading_dataset_files": true,
3 |
"preloaded_dataset_filename": "sample_input.txt",
4 |
"debug": true,
5 |
"models": {
6 |
"": null,
7 |
8 |
"gpt2": null,
9 |
"distilgpt2": null,
10 |
"facebook/opt-125m": null,
11 |
"facebook/opt-1.3b": null,
12 |
"EleutherAI/gpt-neo-125M": null,
13 |
"Qwen/Qwen-1_8B": null,
14 |
"Qwen/Qwen1.5-0.5B": null,
15 |
"Qwen/Qwen1.5-0.5B-Chat": null,
16 |
"Qwen/Qwen1.5-1.8B": null,
17 |
"Qwen/Qwen1.5-1.8B-Chat": null,
18 |
"microsoft/phi-1": null,
19 |
"microsoft/phi-1_5": null,
20 |
"microsoft/phi-2": null,
21 |
22 |
"meta-llama/Llama-2-7b-hf": null,
23 |
"meta-llama/Llama-2-7b-chat-hf": null,
24 |
25 |
"meta-llama/Llama-2-13b-hf": null,
26 |
"meta-llama/Llama-2-13b-chat-hf": null,
27 |
28 |
29 |
"gpt2-medium": null,
30 |
"gpt2-large": null,
31 |
"gpt2-xl": null,
32 |
33 |
"mistralai/Mistral-7B-v0.1": null,
34 |
"mistralai/Mistral-7B-Instruct-v0.1": null,
35 |
"mistralai/Mistral-7B-Instruct-v0.2": null,
36 |
37 |
"google/gemma-7b": null,
38 |
"google/gemma-2b": null,
39 |
40 |
"facebook/opt-2.7b": null,
41 |
"facebook/opt-6.7b": null,
42 |
"facebook/opt-13b": null,
43 |
"facebook/opt-30b": null
44 |
45 |
"default_model": "",
46 |
"demo_mode": false
47 |
@@ -0,0 +1,27 @@
1 |
name: llmtt
2 |
3 |
- pytorch
4 |
- nvidia
5 |
- conda-forge
6 |
7 |
- python
8 |
- pytorch
9 |
- pytorch-cuda=11.8
10 |
- nodejs
11 |
- yarn
12 |
- pip
13 |
- pip:
14 |
- datasets
15 |
- einops
16 |
- fancy_einsum
17 |
- jaxtyping
18 |
- networkx
19 |
- plotly
20 |
- pyinstrument
21 |
- setuptools
22 |
- streamlit
23 |
- streamlit_extras
24 |
- tokenizers
25 |
- transformer_lens
26 |
- transformers
27 |
- pytest # fixes wrong dependencies of transformer_lens
@@ -0,0 +1,111 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import os
8 |
from typing import List, Optional
9 |
10 |
import networkx as nx
11 |
import streamlit.components.v1 as components
12 |
13 |
from llm_transparency_tool.models.transparent_llm import ModelInfo
14 |
from llm_transparency_tool.server.graph_selection import GraphSelection, UiGraphNode
15 |
16 |
17 |
18 |
19 |
parent_dir = os.path.dirname(os.path.abspath(__file__))
20 |
config = {
21 |
"path": os.path.join(parent_dir, "frontend/build"),
22 |
23 |
24 |
config = {
25 |
"url": "http://localhost:3001",
26 |
27 |
28 |
_component_func = components.declare_component("contribution_graph", **config)
29 |
30 |
31 |
def is_node_valid(node: UiGraphNode, n_layers: int, n_tokens: int):
32 |
return node.layer < n_layers and node.token < n_tokens
33 |
34 |
35 |
def is_selection_valid(s: GraphSelection, n_layers: int, n_tokens: int):
36 |
if not s:
37 |
return True
38 |
if s.node:
39 |
if not is_node_valid(s.node, n_layers, n_tokens):
40 |
return False
41 |
if s.edge:
42 |
for node in [s.edge.source, s.edge.target]:
43 |
if not is_node_valid(node, n_layers, n_tokens):
44 |
return False
45 |
return True
46 |
47 |
48 |
def contribution_graph(
49 |
model_info: ModelInfo,
50 |
tokens: List[str],
51 |
graphs: List[nx.Graph],
52 |
key: str,
53 |
) -> Optional[GraphSelection]:
54 |
"""Create a new instance of contribution graph.
55 |
56 |
Returns selected graph node or None if nothing was selected.
57 |
58 |
assert len(tokens) == len(graphs)
59 |
60 |
result = _component_func(
61 |
62 |
63 |
64 |
edges_per_token=[nx.node_link_data(g)["links"] for g in graphs],
65 |
66 |
67 |
68 |
69 |
selection = GraphSelection.from_json(result)
70 |
71 |
n_tokens = len(tokens)
72 |
n_layers = model_info.n_layers
73 |
# We need this extra protection because even though the component has to check for
74 |
# the validity of the selection, sometimes it allows invalid output. It's some
75 |
# unexpected effect that has something to do with React and how the output value is
76 |
# set for the component.
77 |
if not is_selection_valid(selection, n_layers, n_tokens):
78 |
selection = None
79 |
80 |
return selection
81 |
82 |
83 |
def selector(
84 |
items: List[str],
85 |
indices: List[int],
86 |
temperatures: Optional[List[float]],
87 |
preselected_index: Optional[int],
88 |
key: str,
89 |
) -> Optional[int]:
90 |
"""Create a new instance of selector.
91 |
92 |
Returns selected item index.
93 |
94 |
n = len(items)
95 |
assert n == len(indices)
96 |
items = [{"index": i, "text": s} for s, i in zip(items, indices)]
97 |
98 |
if temperatures is not None:
99 |
assert n == len(temperatures)
100 |
for i, t in enumerate(temperatures):
101 |
items[i]["temperature"] = t
102 |
103 |
result = _component_func(
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
return None if result is None else int(result)
@@ -0,0 +1,5 @@
1 |
2 |
"endOfLine": "lf",
3 |
"semi": false,
4 |
"trailingComma": "es5"
5 |
@@ -0,0 +1,39 @@
1 |
2 |
"name": "contribution_graph",
3 |
"version": "0.1.0",
4 |
"private": true,
5 |
"dependencies": {
6 |
"@types/d3": "^7.4.0",
7 |
"d3": "^7.8.5",
8 |
"react": "^18.2.0",
9 |
"react-dom": "^18.2.0",
10 |
"streamlit-component-lib": "^2.0.0"
11 |
12 |
"scripts": {
13 |
"start": "react-scripts start",
14 |
"build": "react-scripts build",
15 |
"test": "react-scripts test",
16 |
"eject": "react-scripts eject"
17 |
18 |
"browserslist": {
19 |
"production": [
20 |
21 |
"not dead",
22 |
"not op_mini all"
23 |
24 |
"development": [
25 |
"last 1 chrome version",
26 |
"last 1 firefox version",
27 |
"last 1 safari version"
28 |
29 |
30 |
"homepage": ".",
31 |
"devDependencies": {
32 |
"@types/node": "^20.11.17",
33 |
"@types/react": "^18.2.55",
34 |
"@types/react-dom": "^18.2.19",
35 |
"eslint-config-react-app": "^7.0.1",
36 |
"react-scripts": "^5.0.1",
37 |
"typescript": "^5.3.3"
38 |
39 |
9 |
import {
10 |
11 |
12 |
13 |
} from 'streamlit-component-lib'
14 |
import React, { useEffect, useMemo, useRef, useState } from 'react';
15 |
import * as d3 from 'd3';
16 |
17 |
import {
18 |
19 |
20 |
} from './common';
21 |
import './LlmViewer.css';
export const renderParams = {
24 |
cellH: 32,
25 |
cellW: 32,
26 |
attnSize: 8,
27 |
afterFfnSize: 8,
28 |
ffnSize: 6,
29 |
tokenSelectorSize: 16,
30 |
layerCornerRadius: 6,
31 |
32 |
33 |
interface Cell {
34 |
layer: number
35 |
token: number
36 |
37 |
38 |
enum CellItem {
39 |
AfterAttn = 'after_attn',
40 |
AfterFfn = 'after_ffn',
41 |
Ffn = 'ffn',
42 |
Original = 'original', // They will only be at level = 0
43 |
44 |
45 |
interface Node {
46 |
cell: Cell | null
47 |
item: CellItem | null
48 |
49 |
50 |
interface NodeProps {
51 |
node: Node
52 |
pos: Point
53 |
isActive: boolean
54 |
55 |
56 |
interface EdgeRaw {
57 |
weight: number
58 |
source: string
59 |
target: string
60 |
61 |
62 |
interface Edge {
63 |
weight: number
64 |
from: Node
65 |
to: Node
66 |
fromPos: Point
67 |
toPos: Point
68 |
isSelectable: boolean
69 |
isFfn: boolean
70 |
71 |
72 |
interface Selection {
73 |
node: Node | null
74 |
edge: Edge | null
75 |
76 |
77 |
function tokenPointerPolygon(origin: Point) {
78 |
const r = renderParams.tokenSelectorSize / 2
79 |
const dy = r / 2
80 |
const dx = r * Math.sqrt(3.0) / 2
81 |
// Draw an arrow looking down
82 |
return [
83 |
[origin.x, origin.y + r],
84 |
[origin.x + dx, origin.y - dy],
85 |
[origin.x - dx, origin.y - dy],
86 |
87 |
88 |
89 |
function isSameCell(cell1: Cell | null, cell2: Cell | null) {
90 |
if (cell1 == null || cell2 == null) {
91 |
return false
92 |
93 |
return cell1.layer === cell2.layer && cell1.token === cell2.token
94 |
95 |
96 |
function isSameNode(node1: Node | null, node2: Node | null) {
97 |
if (node1 === null || node2 === null) {
98 |
return false
99 |
100 |
return isSameCell(node1.cell, node2.cell)
101 |
&& node1.item === node2.item;
102 |
103 |
104 |
function isSameEdge(edge1: Edge | null, edge2: Edge | null) {
105 |
if (edge1 === null || edge2 === null) {
106 |
return false
107 |
108 |
return isSameNode(edge1.from, edge2.from) && isSameNode(edge1.to, edge2.to);
109 |
110 |
111 |
function nodeFromString(name: string) {
112 |
const match = name.match(/([AIMX])(\d+)_(\d+)/)
113 |
if (match == null) {
114 |
return {
115 |
cell: null,
116 |
item: null,
117 |
118 |
119 |
const [, type, layerStr, tokenStr] = match
120 |
const layer = +layerStr
121 |
const token = +tokenStr
122 |
123 |
const typeToCellItem = new Map<string, CellItem>([
124 |
['A', CellItem.AfterAttn],
125 |
['I', CellItem.AfterFfn],
126 |
['M', CellItem.Ffn],
127 |
['X', CellItem.Original],
128 |
129 |
return {
130 |
cell: {
131 |
layer: layer,
132 |
token: token,
133 |
134 |
item: typeToCellItem.get(type) ?? null,
135 |
136 |
137 |
138 |
function isValidNode(node: Node, nLayers: number, nTokens: number) {
139 |
if (node.cell === null) {
140 |
return true
141 |
142 |
return node.cell.layer < nLayers && node.cell.token < nTokens
143 |
144 |
145 |
function isValidSelection(selection: Selection, nLayers: number, nTokens: number) {
146 |
if (selection.node !== null) {
147 |
return isValidNode(selection.node, nLayers, nTokens)
148 |
149 |
if (selection.edge !== null) {
150 |
return isValidNode(selection.edge.from, nLayers, nTokens) &&
151 |
isValidNode(selection.edge.to, nLayers, nTokens)
152 |
153 |
return true
154 |
155 |
156 |
const ContributionGraph = ({ args }: ComponentProps) => {
157 |
const modelInfo = args['model_info']
158 |
const tokens = args['tokens']
159 |
const edgesRaw: EdgeRaw[][] = args['edges_per_token']
160 |
161 |
const nLayers = modelInfo === null ? 0 : modelInfo.n_layers
162 |
const nTokens = tokens === null ? 0 : tokens.length
163 |
164 |
const [selection, setSelection] = useState<Selection>({
165 |
node: null,
166 |
edge: null,
167 |
168 |
var curSelection = selection
169 |
if (!isValidSelection(selection, nLayers, nTokens)) {
170 |
curSelection = {
171 |
node: null,
172 |
edge: null,
173 |
174 |
175 |
176 |
177 |
178 |
const [startToken, setStartToken] = useState<number>(nTokens - 1)
179 |
// We have startToken state var, but it won't be updated till next render, so use
180 |
// this var in the current render.
181 |
var curStartToken = startToken
182 |
if (startToken >= nTokens) {
183 |
curStartToken = nTokens - 1
184 |
185 |
186 |
187 |
const handleRepresentationClick = (node: Node) => {
188 |
const newSelection: Selection = {
189 |
node: node,
190 |
edge: null,
191 |
192 |
193 |
194 |
195 |
196 |
const handleEdgeClick = (edge: Edge) => {
197 |
if (!edge.isSelectable) {
198 |
199 |
200 |
const newSelection: Selection = {
201 |
202 |
edge: edge,
203 |
204 |
205 |
206 |
207 |
208 |
const handleTokenClick = (t: number) => {
209 |
210 |
211 |
212 |
const [xScale, yScale] = useMemo(() => {
213 |
const x = d3.scaleLinear()
214 |
.domain([-2, nTokens - 1])
215 |
.range([0, renderParams.cellW * (nTokens + 2)])
216 |
const y = d3.scaleLinear()
217 |
.domain([-1, nLayers])
218 |
.range([renderParams.cellH * (nLayers + 2), 0])
219 |
return [x, y]
220 |
}, [nLayers, nTokens])
221 |
222 |
const cells = useMemo(() => {
223 |
let result: Cell[] = []
224 |
for (let l = 0; l < nLayers; l++) {
225 |
for (let t = 0; t < nTokens; t++) {
226 |
227 |
layer: l,
228 |
token: t,
229 |
230 |
231 |
232 |
return result
233 |
}, [nLayers, nTokens])
234 |
235 |
const nodeCoords = useMemo(() => {
236 |
let result = new Map<string, Point>()
237 |
const w = renderParams.cellW
238 |
const h = renderParams.cellH
239 |
for (var cell of cells) {
240 |
const cx = xScale(cell.token + 0.5)
241 |
const cy = yScale(cell.layer - 0.5)
242 |
243 |
JSON.stringify({ cell: cell, item: CellItem.AfterAttn }),
244 |
{ x: cx, y: cy + h / 4 },
245 |
246 |
247 |
JSON.stringify({ cell: cell, item: CellItem.AfterFfn }),
248 |
{ x: cx, y: cy - h / 4 },
249 |
250 |
251 |
JSON.stringify({ cell: cell, item: CellItem.Ffn }),
252 |
{ x: cx + 5 * w / 16, y: cy },
253 |
254 |
255 |
for (let t = 0; t < nTokens; t++) {
256 |
cell = {
257 |
layer: 0,
258 |
token: t,
259 |
260 |
const cx = xScale(cell.token + 0.5)
261 |
const cy = yScale(cell.layer - 1.0)
262 |
263 |
JSON.stringify({ cell: cell, item: CellItem.Original }),
264 |
{ x: cx, y: cy + h / 4 },
265 |
266 |
267 |
return result
268 |
}, [cells, nTokens, xScale, yScale])
269 |
270 |
const edges: Edge[][] = useMemo(() => {
271 |
let result = []
272 |
for (var edgeList of edgesRaw) {
273 |
let edgesPerStartToken = []
274 |
for (var edge of edgeList) {
275 |
const u = nodeFromString(edge.source)
276 |
const v = nodeFromString(edge.target)
277 |
var isSelectable = (
278 |
u.cell !== null && v.cell !== null && v.item === CellItem.AfterAttn
279 |
280 |
var isFfn = (
281 |
u.cell !== null && v.cell !== null && (
282 |
u.item === CellItem.Ffn || v.item === CellItem.Ffn
283 |
284 |
285 |
286 |
weight: edge.weight,
287 |
from: u,
288 |
to: v,
289 |
fromPos: nodeCoords.get(JSON.stringify(u)) ?? { 'x': 0, 'y': 0 },
290 |
toPos: nodeCoords.get(JSON.stringify(v)) ?? { 'x': 0, 'y': 0 },
291 |
isSelectable: isSelectable,
292 |
isFfn: isFfn,
293 |
294 |
295 |
296 |
297 |
return result
298 |
}, [edgesRaw, nodeCoords])
299 |
300 |
const activeNodes = useMemo(() => {
301 |
let result = new Set<string>()
302 |
for (var edge of edges[curStartToken]) {
303 |
const u = JSON.stringify(edge.from)
304 |
const v = JSON.stringify(edge.to)
305 |
306 |
307 |
308 |
return result
309 |
}, [edges, curStartToken])
310 |
311 |
const nodeProps = useMemo(() => {
312 |
let result: Array<NodeProps> = []
313 |
nodeCoords.forEach((p: Point, node: string) => {
314 |
315 |
node: JSON.parse(node),
316 |
pos: p,
317 |
isActive: activeNodes.has(node),
318 |
319 |
320 |
return result
321 |
}, [nodeCoords, activeNodes])
322 |
323 |
const tokenLabels: Label[] = useMemo(() => {
324 |
if (!tokens) {
325 |
return []
326 |
327 |
return tokens.map((s: string, i: number) => ({
328 |
text: s.replace(/ /g, '·'),
329 |
pos: {
330 |
x: xScale(i + 0.5),
331 |
y: yScale(-1.5),
332 |
333 |
334 |
}, [tokens, xScale, yScale])
335 |
336 |
const layerLabels: Label[] = useMemo(() => {
337 |
return Array.from(Array(nLayers).keys()).map(i => ({
338 |
text: 'L' + i,
339 |
pos: {
340 |
x: xScale(-0.25),
341 |
y: yScale(i - 0.5),
342 |
343 |
344 |
}, [nLayers, xScale, yScale])
345 |
346 |
const tokenSelectors: Array<[number, Point]> = useMemo(() => {
347 |
return Array.from(Array(nTokens).keys()).map(i => ([
348 |
349 |
350 |
x: xScale(i + 0.5),
351 |
y: yScale(nLayers - 0.5),
352 |
353 |
354 |
}, [nTokens, nLayers, xScale, yScale])
355 |
356 |
const totalW = xScale(nTokens + 2)
357 |
const totalH = yScale(-4)
358 |
useEffect(() => {
359 |
360 |
}, [totalH])
361 |
362 |
const colorScale = d3.scaleLinear(
363 |
[0.0, 0.5, 1.0],
364 |
['#9eba66', 'darkolivegreen', 'darkolivegreen']
365 |
366 |
const ffnEdgeColorScale = d3.scaleLinear(
367 |
[0.0, 0.5, 1.0],
368 |
['orchid', 'purple', 'purple']
369 |
370 |
const edgeWidthScale = d3.scaleLinear([0.0, 0.5, 1.0], [2.0, 3.0, 3.0])
371 |
372 |
const svgRef = useRef(null);
373 |
374 |
useEffect(() => {
375 |
const getNodeStyle = (p: NodeProps, type: string) => {
376 |
if (isSameNode(p.node, curSelection.node)) {
377 |
return 'selectable-item selection'
378 |
379 |
if (p.isActive) {
380 |
return 'selectable-item active-' + type + '-node'
381 |
382 |
return 'selectable-item inactive-node'
383 |
384 |
385 |
const svg =
386 |
387 |
388 |
389 |
390 |
.data(Array.from(Array(nLayers).keys()).filter((x) => x % 2 === 1))
391 |
392 |
393 |
.attr('class', 'layer-highlight')
394 |
.attr('x', xScale(-1.0))
395 |
.attr('y', (layer) => yScale(layer))
396 |
.attr('width', xScale(nTokens + 0.25) - xScale(-1.0))
397 |
.attr('height', (layer) => yScale(layer) - yScale(layer + 1))
398 |
.attr('rx', renderParams.layerCornerRadius)
399 |
400 |
401 |
402 |
403 |
404 |
405 |
.style('stroke', (edge: Edge) => {
406 |
if (isSameEdge(edge, curSelection.edge)) {
407 |
return 'orange'
408 |
409 |
if (edge.isFfn) {
410 |
return ffnEdgeColorScale(edge.weight)
411 |
412 |
return colorScale(edge.weight)
413 |
414 |
.attr('class', (edge: Edge) => edge.isSelectable ? 'selectable-edge' : '')
415 |
.style('stroke-width', (edge: Edge) => edgeWidthScale(edge.weight))
416 |
.attr('x1', (edge: Edge) => edge.fromPos.x)
417 |
.attr('y1', (edge: Edge) => edge.fromPos.y)
418 |
.attr('x2', (edge: Edge) => edge.toPos.x)
419 |
.attr('y2', (edge: Edge) => edge.toPos.y)
420 |
.on('click', (event: PointerEvent, edge) => {
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
.filter((p) => {
429 |
return p.node.item === CellItem.AfterAttn
430 |
|| p.node.item === CellItem.AfterFfn
431 |
432 |
433 |
.attr('class', (p) => getNodeStyle(p, 'residual'))
434 |
.attr('cx', (p) => p.pos.x)
435 |
.attr('cy', (p) => p.pos.y)
436 |
.attr('r', renderParams.attnSize / 2)
437 |
.on('click', (event: PointerEvent, p) => {
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
.filter((p) => p.node.item === CellItem.Ffn && p.isActive)
446 |
447 |
.attr('class', (p) => getNodeStyle(p, 'ffn'))
448 |
.attr('x', (p) => p.pos.x - renderParams.ffnSize / 2)
449 |
.attr('y', (p) => p.pos.y - renderParams.ffnSize / 2)
450 |
.attr('width', renderParams.ffnSize)
451 |
.attr('height', renderParams.ffnSize)
452 |
.on('click', (event: PointerEvent, p) => {
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
.attr('x', (label: Label) => label.pos.x)
462 |
.attr('y', (label: Label) => label.pos.y)
463 |
.attr('text-anchor', 'end')
464 |
.attr('dominant-baseline', 'middle')
465 |
.attr('alignment-baseline', 'top')
466 |
.attr('transform', (label: Label) =>
467 |
'rotate(-40, ' + label.pos.x + ', ' + label.pos.y + ')')
468 |
.text((label: Label) => label.text)
469 |
470 |
471 |
472 |
473 |
474 |
475 |
.attr('x', (label: Label) => label.pos.x)
476 |
.attr('y', (label: Label) => label.pos.y)
477 |
.attr('text-anchor', 'middle')
478 |
.attr('alignment-baseline', 'middle')
479 |
.text((label: Label) => label.text)
480 |
481 |
482 |
483 |
484 |
485 |
486 |
.attr('class', ([i,]) => (
487 |
curStartToken === i
488 |
? 'selectable-item selection'
489 |
: 'selectable-item token-selector'
490 |
491 |
.attr('points', ([, p]) => tokenPointerPolygon(p))
492 |
.attr('r', renderParams.tokenSelectorSize / 2)
493 |
.on('click', (event: PointerEvent, [i,]) => {
494 |
495 |
496 |
}, [
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
return <svg ref={svgRef} width={totalW} height={totalH}></svg>
515 |
516 |
517 |
export default withStreamlitConnection(ContributionGraph)
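The node-name convention used by `nodeFromString` in the component above (a type letter, then layer, underscore, token) can be sketched in Python for readers who want to produce or check edge names on the backend side. This is an illustrative transliteration, not part of the commit; `ParsedNode` and `node_from_string` are hypothetical names, and the string values mirror the `CellItem` enum.

```python
import re
from typing import NamedTuple, Optional

# Same pattern as the TypeScript regex: type letter, layer, "_", token.
NODE_NAME = re.compile(r"([AIMX])(\d+)_(\d+)")

# Letter -> item kind, mirroring the CellItem enum values.
ITEM_BY_LETTER = {
    "A": "after_attn",
    "I": "after_ffn",
    "M": "ffn",
    "X": "original",
}


class ParsedNode(NamedTuple):
    layer: Optional[int]
    token: Optional[int]
    item: Optional[str]


def node_from_string(name: str) -> ParsedNode:
    """Parse a node name such as "A3_5"; unknown names map to all-None."""
    # re.search matches anywhere in the string, like String.match in JS.
    match = NODE_NAME.search(name)
    if match is None:
        return ParsedNode(None, None, None)
    letter, layer, token = match.groups()
    return ParsedNode(int(layer), int(token), ITEM_BY_LETTER[letter])
```

For example, `node_from_string("M2_7")` describes the FFN node of layer 2 at token 7, while a name that does not match the pattern yields a node with no cell, which the frontend treats as always valid.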
9 |
.graph-container {
10 |
display: flex;
11 |
justify-content: center;
12 |
align-items: center;
13 |
height: 100vh;
14 |
15 |
16 |
.svg {
17 |
border: 1px solid #ccc;
18 |
19 |
20 |
.layer-highlight {
21 |
fill: #f0f5f0;
22 |
23 |
24 |
.selectable-item {
25 |
stroke: black;
26 |
cursor: pointer;
27 |
28 |
29 |
30 |
.selection:hover {
31 |
fill: orange;
32 |
33 |
34 |
.active-residual-node {
35 |
fill: yellowgreen;
36 |
37 |
38 |
.active-residual-node:hover {
39 |
fill: olivedrab;
40 |
41 |
42 |
.active-ffn-node {
43 |
fill: orchid;
44 |
45 |
46 |
.active-ffn-node:hover {
47 |
fill: purple;
48 |
49 |
50 |
.inactive-node {
51 |
fill: lightgray;
52 |
stroke-width: 0.5px;
53 |
54 |
55 |
.inactive-node:hover {
56 |
fill: gray;
57 |
58 |
59 |
.selectable-edge {
60 |
cursor: pointer;
61 |
62 |
63 |
.token-selector {
64 |
fill: lightblue;
65 |
66 |
67 |
.token-selector:hover {
68 |
fill: cornflowerblue;
69 |
70 |
71 |
.selector-item {
72 |
fill: lightblue;
73 |
74 |
75 |
.selector-item:hover {
76 |
fill: cornflowerblue;
77 |
@@ -0,0 +1,154 @@
9 |
import {
10 |
11 |
12 |
13 |
} from "streamlit-component-lib"
14 |
import React, { useEffect, useMemo, useRef, useState } from 'react';
15 |
import * as d3 from 'd3';
16 |
17 |
import {
18 |
19 |
} from './common';
20 |
import './LlmViewer.css';
export const renderParams = {
23 |
verticalGap: 24,
24 |
horizontalGap: 24,
25 |
itemSize: 8,
26 |
27 |
28 |
interface Item {
29 |
index: number
30 |
text: string
31 |
temperature: number
32 |
33 |
34 |
const Selector = ({ args }: ComponentProps) => {
35 |
const items: Item[] = args["items"]
36 |
const preselected_index: number | null = args["preselected_index"]
37 |
const n = items.length
38 |
39 |
const [selection, setSelection] = useState<number | null>(null)
40 |
41 |
// Ensure the preselected element takes effect only when new data arrives.
42 |
var args_json = JSON.stringify(args)
43 |
useEffect(() => {
44 |
45 |
46 |
}, [args_json, preselected_index]);
47 |
48 |
const handleItemClick = (index: number) => {
49 |
50 |
51 |
52 |
53 |
const [xScale, yScale] = useMemo(() => {
54 |
const x = d3.scaleLinear()
55 |
.domain([0, 1])
56 |
.range([0, renderParams.horizontalGap])
57 |
const y = d3.scaleLinear()
58 |
.domain([0, n - 1])
59 |
.range([0, renderParams.verticalGap * (n - 1)])
60 |
return [x, y]
61 |
}, [n])
62 |
63 |
const itemCoords: Point[] = useMemo(() => {
64 |
return Array.from(Array(n).keys()).map(i => ({
65 |
x: xScale(0.5),
66 |
y: yScale(i + 0.5),
67 |
68 |
}, [n, xScale, yScale])
69 |
70 |
var hasTemperature = false
71 |
if (n > 0) {
72 |
var t = items[0].temperature
73 |
hasTemperature = (t !== null && t !== undefined)
74 |
75 |
const colorScale = useMemo(() => {
76 |
var min_t = 0.0
77 |
var max_t = 1.0
78 |
if (hasTemperature) {
79 |
min_t = items[0].temperature
80 |
max_t = items[0].temperature
81 |
for (var i = 0; i < n; i++) {
82 |
const t = items[i].temperature
83 |
min_t = Math.min(min_t, t)
84 |
max_t = Math.max(max_t, t)
85 |
86 |
87 |
const norm = d3.scaleLinear([min_t, max_t], [0.0, 1.0])
88 |
const colorScale = d3.scaleSequential(d3.interpolateYlGn);
89 |
return d3.scaleSequential(value => colorScale(norm(value)))
90 |
}, [items, hasTemperature, n])
91 |
92 |
const totalW = 100
93 |
const totalH = yScale(n)
94 |
useEffect(() => {
95 |
96 |
}, [totalH])
97 |
98 |
const svgRef = useRef(null);
99 |
100 |
useEffect(() => {
101 |
const svg =
102 |
103 |
104 |
const getItemClass = (index: number) => {
105 |
var style = 'selectable-item '
106 |
style += index === selection ? 'selection' : 'selector-item'
107 |
return style
108 |
109 |
110 |
const getItemColor = (item: Item) => {
111 |
var t = item.temperature ?? 0.0
112 |
return item.index === selection ? 'orange' : colorScale(t)
113 |
114 |
115 |
var icons = svg
116 |
117 |
118 |
119 |
120 |
.attr('cx', (i) => itemCoords[i].x)
121 |
.attr('cy', (i) => itemCoords[i].y)
122 |
.attr('r', renderParams.itemSize / 2)
123 |
.on('click', (event: PointerEvent, i) => {
124 |
125 |
126 |
.attr('class', (i) => getItemClass(items[i].index))
127 |
if (hasTemperature) {
128 |
icons.style('fill', (i) => getItemColor(items[i]))
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
.attr('x', (i) => itemCoords[i].x + renderParams.horizontalGap / 2)
137 |
.attr('y', (i) => itemCoords[i].y)
138 |
.attr('text-anchor', 'left')
139 |
.attr('alignment-baseline', 'middle')
140 |
.text((i) => items[i].text)
141 |
142 |
}, [
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
return <svg ref={svgRef} width={totalW} height={totalH}></svg>
152 |
153 |
154 |
export default withStreamlitConnection(Selector)
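The color scale in the selector above first finds the minimum and maximum temperature and then maps every value into the range [0, 1] before handing it to a color interpolator. A minimal Python sketch of that normalization step follows. The function name is hypothetical, and mapping a constant list to 0.5 is an assumption made for this sketch rather than something stated in the source.

```python
from typing import List, Sequence


def normalize_temperatures(temps: Sequence[float]) -> List[float]:
    """Linearly rescale temperatures so that min -> 0.0 and max -> 1.0."""
    if not temps:
        return []

    lo, hi = min(temps), max(temps)

    # Assumption for this sketch: when all values are equal there is no
    # meaningful spread, so every item gets the midpoint of the range.
    if hi == lo:
        return [0.5 for _ in temps]

    return [(t - lo) / (hi - lo) for t in temps]
```

The normalized value can then be passed to any sequential color map; the component above uses d3's yellow-green interpolator for that purpose.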
9 |
export interface Point {
10 |
x: number
11 |
y: number
12 |
13 |
14 |
export interface Label {
15 |
text: string
16 |
pos: Point
17 |
9 |
import React from "react"
10 |
import ReactDOM from "react-dom"
11 |
12 |
import {
13 |
14 |
15 |
} from "streamlit-component-lib"
16 |
17 |
18 |
import ContributionGraph from "./ContributionGraph"
19 |
import Selector from "./Selector"
20 |
21 |
const LlmViewerComponent = (props: ComponentProps) => {
22 |
switch (props.args['component']) {
23 |
case 'graph':
24 |
return <ContributionGraph />
25 |
case 'selector':
26 |
return <Selector />
27 |
28 |
return <></>
29 |
30 |
31 |
32 |
const StreamlitLlmViewerComponent = withStreamlitConnection(LlmViewerComponent)
33 |
34 |
35 |
36 |
<StreamlitLlmViewerComponent />
37 |
38 |
39 |
15 |
class TransparentLlmTestCase(unittest.TestCase):
16 |
17 |
def setUpClass(cls):
18 |
# Picking the smallest model possible so that the test runs faster. It's ok to
19 |
# change this model, but you'll need to update tokenization specifics in some
20 |
# tests.
21 |
cls._llm = TransformerLensTransparentLlm(
22 |
23 |
24 |
25 |
26 |
def setUp(self):
27 |
self._llm.run(["test", "test 1"])
28 |
self._eps = 1e-5
29 |
30 |
def test_model_info(self):
31 |
info = self._llm.model_info()
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
def test_tokens(self):
45 |
tokens = self._llm.tokens()
46 |
47 |
pad = 1
48 |
bos = 2
49 |
test = 21959
50 |
one = 112
51 |
52 |
self.assertEqual(tokens.tolist(), [[bos, test, pad], [bos, test, one]])
53 |
54 |
def test_tokens_to_strings(self):
55 |
s = self._llm.tokens_to_strings(torch.Tensor([2, 21959, 112]).to(
56 |
self.assertEqual(s, ["</s>", "test", " 1"])
57 |
58 |
def test_manage_state(self):
59 |
# `run` was already called once in setUp. Call it again and make sure the object
60 |
# returns values for the new state.
61 |
self._llm.run(["one", "two", "three", "four"])
62 |
self.assertEqual(self._llm.tokens().shape[0], 4)
63 |
64 |
def test_residual_in_and_out(self):
65 |
66 |
Test that residual_in equals the residual_out of the previous layer.
67 |
68 |
for layer in range(1, 12):
69 |
prev_residual_out = self._llm.residual_out(layer - 1)
70 |
residual_in = self._llm.residual_in(layer)
71 |
diff = torch.max(torch.abs(residual_in - prev_residual_out)).item()
72 |
self.assertLess(diff, self._eps, f"layer {layer}")
73 |
74 |
def test_residual_plus_block(self):
75 |
76 |
Make sure that new residual = old residual + block output. Here, block is an ffn
77 |
or attention. It's not that obvious because it could be that layer norm is
78 |
applied after the block output, but before saving the result to residual.
79 |
Luckily, this is not the case in TransformerLens, and we're relying on that.
80 |
81 |
layer = 3
82 |
batch = 0
83 |
pos = 0
84 |
85 |
residual_in = self._llm.residual_in(layer)[batch][pos]
86 |
residual_mid = self._llm.residual_after_attn(layer)[batch][pos]
87 |
residual_out = self._llm.residual_out(layer)[batch][pos]
88 |
ffn_out = self._llm.ffn_out(layer)[batch][pos]
89 |
attn_out = self._llm.attention_output(batch, layer, pos)
90 |
91 |
a = residual_mid
92 |
b = residual_in + attn_out
93 |
diff = torch.max(torch.abs(a - b)).item()
94 |
self.assertLess(diff, self._eps, "attn")
95 |
96 |
a = residual_out
97 |
b = residual_mid + ffn_out
98 |
diff = torch.max(torch.abs(a - b)).item()
99 |
self.assertLess(diff, self._eps, "ffn")
100 |
101 |
def test_tensor_shapes(self):
102 |
# Not much we can do about the tensors, but at least check their shapes and
103 |
# that they don't contain NaNs.
104 |
vocab_size = 50272
105 |
n_batch = 2
106 |
n_tokens = 3
107 |
d_model = 768
108 |
d_hidden = d_model * 4
109 |
n_heads = 12
110 |
layer = 5
111 |
112 |
device = self._llm.residual_in(0).device
113 |
114 |
for name, tensor, expected_shape in [
115 |
("r_in", self._llm.residual_in(layer), [n_batch, n_tokens, d_model]),
116 |
117 |
118 |
119 |
[n_batch, n_tokens, d_model],
120 |
121 |
("r_out", self._llm.residual_out(layer), [n_batch, n_tokens, d_model]),
122 |
("logits", self._llm.logits(), [n_batch, n_tokens, vocab_size]),
123 |
("ffn_out", self._llm.ffn_out(layer), [n_batch, n_tokens, d_model]),
124 |
125 |
126 |
self._llm.decomposed_ffn_out(0, 0, 0),
127 |
[d_hidden, d_model],
128 |
129 |
("neuron_activations", self._llm.neuron_activations(0, 0, 0), [d_hidden]),
130 |
("neuron_output", self._llm.neuron_output(0, 0), [d_model]),
131 |
132 |
133 |
self._llm.attention_matrix(0, 0, 0),
134 |
[n_tokens, n_tokens],
135 |
136 |
137 |
138 |
self._llm.attention_output_per_head(0, 0, 0, 0),
139 |
140 |
141 |
142 |
143 |
self._llm.attention_output(0, 0, 0),
144 |
145 |
146 |
147 |
148 |
self._llm.decomposed_attn(0, layer),
149 |
[n_tokens, n_tokens, n_heads, d_model],
150 |
151 |
152 |
153 |
self._llm.unembed(torch.zeros([d_model]).to(device), normalize=True),
154 |
155 |
156 |
157 |
self.assertEqual(list(tensor.shape), expected_shape, name)
158 |
self.assertFalse(torch.any(tensor.isnan()), name)
159 |
160 |
161 |
if __name__ == "__main__":
162 |
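The invariant checked by `test_residual_plus_block` above (new residual = old residual + block output, for both attention and FFN) can be illustrated without loading a model. The sketch below is a toy stand-in: `run_toy_block`, `BlockTrace`, and the lambda "sublayers" are hypothetical and only show the bookkeeping, not any real transformer computation.

```python
from typing import Callable, List, NamedTuple

Vector = List[float]


class BlockTrace(NamedTuple):
    resid_in: Vector
    attn_out: Vector
    resid_mid: Vector
    ffn_out: Vector
    resid_out: Vector


def add(a: Vector, b: Vector) -> Vector:
    """Element-wise sum of two equally sized vectors."""
    return [x + y for x, y in zip(a, b)]


def max_abs_diff(a: Vector, b: Vector) -> float:
    """Largest absolute element-wise difference, as used in the tests."""
    return max(abs(x - y) for x, y in zip(a, b))


def run_toy_block(
    resid_in: Vector,
    attn: Callable[[Vector], Vector],
    ffn: Callable[[Vector], Vector],
) -> BlockTrace:
    """Apply two sublayers, each adding its output to the residual stream."""
    attn_out = attn(resid_in)
    resid_mid = add(resid_in, attn_out)  # residual after attention
    ffn_out = ffn(resid_mid)
    resid_out = add(resid_mid, ffn_out)  # residual after the FFN
    return BlockTrace(resid_in, attn_out, resid_mid, ffn_out, resid_out)
```

With this structure, `max_abs_diff(trace.resid_mid, add(trace.resid_in, trace.attn_out))` is zero by construction; the real test asserts the same quantity stays below a small epsilon for the cached model tensors.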
7 |
from dataclasses import dataclass
8 |
from typing import List, Optional
9 |
10 |
import torch
11 |
import transformer_lens
12 |
import transformers
13 |
from fancy_einsum import einsum
14 |
from jaxtyping import Float, Int
15 |
from typeguard import typechecked
16 |
import streamlit as st
17 |
18 |
from llm_transparency_tool.models.transparent_llm import ModelInfo, TransparentLlm
19 |
20 |
28 |
29 |
30 |
31 |
32 |
transformers.PreTrainedModel: id,
33 |
transformers.PreTrainedTokenizer: id
34 |
35 |
36 |
def load_hooked_transformer(
37 |
model_name: str,
38 |
hf_model: Optional[transformers.PreTrainedModel] = None,
39 |
tlens_device: str = "cuda",
40 |
dtype: torch.dtype = torch.float32,
41 |
42 |
# if tlens_device == "cuda":
43 |
# n_devices = torch.cuda.device_count()
44 |
# else:
45 |
# n_devices = 1
46 |
tlens_model = transformer_lens.HookedTransformer.from_pretrained(
47 |
48 |
49 |
fold_ln=False, # Keep layer norm where it is.
50 |
51 |
52 |
53 |
# n_devices=n_devices,
54 |
55 |
56 |
57 |
return tlens_model
58 |
59 |
60 |
# TODO(igortufanov): If we want to scale the app to multiple users, we need more careful
61 |
# thread-safe implementation. The simplest option could be to wrap the existing methods
62 |
# in mutexes.
63 |
class TransformerLensTransparentLlm(TransparentLlm):
64 |
65 |
Implementation of Transparent LLM based on transformer lens.
66 |
67 |
68 |
- model_name: The official name of the model from HuggingFace. Even if the model was
69 |
patched or loaded locally, the name should still be official because that's how
70 |
transformer_lens treats the model.
71 |
- hf_model: The language model as a HuggingFace class.
72 |
- tokenizer,
73 |
- device: "gpu" or "cpu"
74 |
75 |
76 |
def __init__(
77 |
78 |
model_name: str,
79 |
hf_model: Optional[transformers.PreTrainedModel] = None,
80 |
tokenizer: Optional[transformers.PreTrainedTokenizer] = None,
81 |
device: str = "gpu",
82 |
dtype: torch.dtype = torch.float32,
83 |
84 |
if device == "gpu":
85 |
self.device = "cuda"
86 |
if not torch.cuda.is_available():
87 |
raise RuntimeError("Asked to run on gpu, but torch couldn't find cuda")
88 |
elif device == "cpu":
89 |
self.device = "cpu"
90 |
91 |
raise RuntimeError(f"Specified device {device} is not a valid option")
92 |
93 |
self.dtype = dtype
94 |
self.hf_tokenizer = tokenizer
95 |
self.hf_model = hf_model
96 |
97 |
# self._model = tlens_model
98 |
self._model_name = model_name
99 |
self._prepend_bos = True
100 |
self._last_run = None
101 |
self._run_exception = RuntimeError(
102 |
"Tried to use the model output before calling the `run` method"
103 |
104 |
105 |
def copy(self):
106 |
import copy
107 |
return copy.copy(self)
108 |
109 |
110 |
def _model(self):
111 |
tlens_model = load_hooked_transformer(
112 |
113 |
114 |
115 |
116 |
117 |
118 |
if self.hf_tokenizer is not None:
119 |
tlens_model.set_tokenizer(self.hf_tokenizer, default_padding_side="left")
120 |
121 |
122 |
123 |
124 |
125 |
return tlens_model
126 |
127 |
def model_info(self) -> ModelInfo:
128 |
cfg = self._model.cfg
129 |
return ModelInfo(
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
def run(self, sentences: List[str]) -> None:
140 |
tokens = self._model.to_tokens(sentences, prepend_bos=self._prepend_bos)
141 |
logits, cache = self._model.run_with_cache(tokens)
142 |
143 |
self._last_run = _RunInfo(
144 |
145 |
146 |
147 |
148 |
149 |
def batch_size(self) -> int:
150 |
if not self._last_run:
151 |
raise self._run_exception
152 |
return self._last_run.logits.shape[0]
153 |
154 |
155 |
def tokens(self) -> Int[torch.Tensor, "batch pos"]:
156 |
if not self._last_run:
157 |
raise self._run_exception
158 |
return self._last_run.tokens
159 |
160 |
161 |
def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
162 |
return self._model.to_str_tokens(tokens)
163 |
164 |
165 |
def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
166 |
if not self._last_run:
167 |
raise self._run_exception
168 |
return self._last_run.logits
169 |
170 |
171 |
172 |
def unembed(
173 |
174 |
t: Float[torch.Tensor, "d_model"],
175 |
normalize: bool,
176 |
) -> Float[torch.Tensor, "vocab"]:
177 |
# t: [d_model] -> [batch, pos, d_model]
178 |
tdim = t.unsqueeze(0).unsqueeze(0)
179 |
if normalize:
180 |
normalized = self._model.ln_final(tdim)
181 |
result = self._model.unembed(normalized)
182 |
183 |
result = self._model.unembed(tdim)
184 |
return result[0][0]
185 |
186 |
def _get_block(self, layer: int, block_name: str) -> torch.Tensor:
187 |
if not self._last_run:
188 |
raise self._run_exception
189 |
return self._last_run.cache[f"blocks.{layer}.{block_name}"]
190 |
191 |
# ================= Methods related to the residual stream =================
192 |
193 |
194 |
def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
195 |
if not self._last_run:
196 |
raise self._run_exception
197 |
return self._get_block(layer, "hook_resid_pre")
198 |
199 |
200 |
def residual_after_attn(
201 |
self, layer: int
202 |
) -> Float[torch.Tensor, "batch pos d_model"]:
203 |
if not self._last_run:
204 |
raise self._run_exception
205 |
return self._get_block(layer, "hook_resid_mid")
206 |
207 |
208 |
def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
209 |
if not self._last_run:
210 |
raise self._run_exception
211 |
return self._get_block(layer, "hook_resid_post")
212 |
213 |
# ================ Methods related to the feed-forward layer ===============
214 |
215 |
216 |
def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
217 |
if not self._last_run:
218 |
raise self._run_exception
219 |
return self._get_block(layer, "hook_mlp_out")
220 |
221 |
222 |
223 |
def decomposed_ffn_out(
224 |
225 |
batch_i: int,
226 |
layer: int,
227 |
pos: int,
228 |
) -> Float[torch.Tensor, "hidden d_model"]:
229 |
# Take activations right before they're multiplied by W_out, i.e. non-linearity
230 |
# and layer norm are already applied.
231 |
processed_activations = self._get_block(layer, "mlp.hook_post")[batch_i][pos]
232 |
return torch.mul(processed_activations.unsqueeze(-1), self._model.W_out[layer])
233 |
234 |
235 |
def neuron_activations(
236 |
237 |
batch_i: int,
238 |
layer: int,
239 |
pos: int,
240 |
) -> Float[torch.Tensor, "hidden"]:
241 |
return self._get_block(layer, "mlp.hook_pre")[batch_i][pos]
242 |
243 |
244 |
def neuron_output(
245 |
246 |
layer: int,
247 |
neuron: int,
248 |
) -> Float[torch.Tensor, "d_model"]:
249 |
return self._model.W_out[layer][neuron]
250 |
251 |
# ==================== Methods related to the attention ====================
252 |
253 |
254 |
def attention_matrix(
255 |
self, batch_i: int, layer: int, head: int
256 |
) -> Float[torch.Tensor, "query_pos key_pos"]:
257 |
return self._get_block(layer, "attn.hook_pattern")[batch_i][head]
258 |
259 |
260 |
def attention_output_per_head(
261 |
262 |
batch_i: int,
263 |
layer: int,
264 |
pos: int,
265 |
head: int,
266 |
) -> Float[torch.Tensor, "d_model"]:
267 |
return self._get_block(layer, "attn.hook_result")[batch_i][pos][head]
268 |
269 |
270 |
def attention_output(
271 |
272 |
batch_i: int,
273 |
layer: int,
274 |
pos: int,
275 |
) -> Float[torch.Tensor, "d_model"]:
276 |
return self._get_block(layer, "hook_attn_out")[batch_i][pos]
277 |
278 |
279 |
280 |
def decomposed_attn(
281 |
self, batch_i: int, layer: int
282 |
) -> Float[torch.Tensor, "pos key_pos head d_model"]:
283 |
if not self._last_run:
284 |
raise self._run_exception
285 |
hook_v = self._get_block(layer, "attn.hook_v")[batch_i]
286 |
b_v = self._model.b_V[layer]
287 |
v = hook_v + b_v
288 |
pattern = self._get_block(layer, "attn.hook_pattern")[batch_i].to(v.dtype)
289 |
z = einsum(
290 |
"key_pos head d_head, "
291 |
"head query_pos key_pos -> "
292 |
"query_pos key_pos head d_head",
293 |
294 |
295 |
296 |
decomposed_attn = einsum(
297 |
"pos key_pos head d_head, "
298 |
"head d_head d_model -> "
299 |
"pos key_pos head d_model",
300 |
301 |
302 |
303 |
return decomposed_attn
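The two einsum steps in `decomposed_attn` above split the attention output into one contribution per (query position, key position, head) triple. A dependency-free sketch of the same index pattern is given below, using nested lists instead of tensors. The operands are elided in the listing, so this sketch assumes they are the value vectors, the attention pattern, and an output projection of shape [head, d_head, d_model]; the function name and argument layout are hypothetical, and no output bias is included.

```python
from typing import List

Tensor3 = List[List[List[float]]]
Tensor4 = List[List[List[List[float]]]]


def decompose_attention(v: Tensor3, pattern: Tensor3, w_o: Tensor3) -> Tensor4:
    """Per-(query, key, head) contributions to the attention output.

    v:       [key_pos][head][d_head]      value vectors (bias already added)
    pattern: [head][query_pos][key_pos]   attention weights
    w_o:     [head][d_head][d_model]      assumed output projection
    Returns  [query_pos][key_pos][head][d_model].
    """
    n_key = len(v)
    n_heads = len(w_o)
    d_head = len(w_o[0])
    d_model = len(w_o[0][0])
    n_query = len(pattern[0])

    out: Tensor4 = [
        [[[0.0] * d_model for _ in range(n_heads)] for _ in range(n_key)]
        for _ in range(n_query)
    ]
    for q in range(n_query):
        for k in range(n_key):
            for h in range(n_heads):
                weight = pattern[h][q][k]
                for m in range(d_model):
                    # Step 1 (z): weight * v[k][h][d]; step 2: project with w_o.
                    out[q][k][h][m] = sum(
                        weight * v[k][h][d] * w_o[h][d][m] for d in range(d_head)
                    )
    return out
```

Summing the result over key positions and heads recovers the usual attention output for each query position (up to the output bias), which is what makes the per-edge contributions comparable.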
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from abc import ABC, abstractmethod
8 |
from dataclasses import dataclass
9 |
from typing import List
10 |
11 |
import torch
12 |
from jaxtyping import Float, Int
13 |
14 |
15 |
16 |
class ModelInfo:
17 |
name: str
18 |
19 |
# Not the actual number of parameters, but rather the order of magnitude
20 |
n_params_estimate: int
21 |
22 |
n_layers: int
23 |
n_heads: int
24 |
d_model: int
25 |
d_vocab: int
26 |
27 |
28 |
class TransparentLlm(ABC):
29 |
30 |
An abstract stateful interface for a language model. The model is supposed to be
31 |
loaded at the class initialization.
32 |
33 |
The internal state consists of the tensors produced by the last call of the `run` method.
34 |
Most methods simply return values from that state, but some may perform cheap
35 |
computations on top of it.
36 |
37 |
38 |
39 |
def model_info(self) -> ModelInfo:
40 |
41 |
Gives general info about the model. This method must be available before any
42 |
calls of the `run`.
43 |
44 |
45 |
46 |
47 |
def run(self, sentences: List[str]) -> None:
48 |
49 |
Run the inference on the given sentences in a single batch and store all
50 |
necessary info in the internal state.
51 |
52 |
53 |
54 |
55 |
def batch_size(self) -> int:
56 |
57 |
The size of the batch that was used for the last call of `run`.
58 |
59 |
60 |
61 |
62 |
def tokens(self) -> Int[torch.Tensor, "batch pos"]:
63 |
64 |
65 |
66 |
def tokens_to_strings(self, tokens: Int[torch.Tensor, "pos"]) -> List[str]:
67 |
68 |
69 |
70 |
def logits(self) -> Float[torch.Tensor, "batch pos d_vocab"]:
71 |
72 |
73 |
74 |
def unembed(
75 |
76 |
t: Float[torch.Tensor, "d_model"],
77 |
normalize: bool,
78 |
) -> Float[torch.Tensor, "vocab"]:
79 |
80 |
Project the given vector (for example, the state of the residual stream for a
81 |
layer and token) into the output vocabulary.
82 |
83 |
normalize: whether to apply the final normalization before the unembedding.
84 |
Setting it to True and applying to output of the last layer gives the output of
85 |
the model.
86 |
87 |
88 |
89 |
# ================= Methods related to the residual stream =================
90 |
91 |
92 |
def residual_in(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
93 |
94 |
The state of the residual stream before entering the layer. For example, when
95 |
layer == 0 these must be the embedded tokens (including positional embedding).
96 |
97 |
98 |
99 |
100 |
def residual_after_attn(
101 |
self, layer: int
102 |
) -> Float[torch.Tensor, "batch pos d_model"]:
103 |
104 |
The state of the residual stream after attention, but before the FFN in the
105 |
given layer.
106 |
107 |
108 |
109 |
110 |
def residual_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
111 |
112 |
The state of the residual stream after the given layer. This is equivalent to the
113 |
next layer's input.
114 |
115 |
116 |
117 |
# ================ Methods related to the feed-forward layer ===============
118 |
119 |
120 |
def ffn_out(self, layer: int) -> Float[torch.Tensor, "batch pos d_model"]:
121 |
122 |
The output of the FFN layer, before it gets merged into the residual stream.
123 |
124 |
125 |
126 |
127 |
def decomposed_ffn_out(
128 |
129 |
batch_i: int,
130 |
layer: int,
131 |
pos: int,
132 |
) -> Float[torch.Tensor, "hidden d_model"]:
133 |
134 |
A collection of vectors added to the residual stream by each neuron. It should
135 |
be the same as neuron activations multiplied by neuron outputs.
136 |
137 |
138 |
139 |
140 |
def neuron_activations(
141 |
142 |
batch_i: int,
143 |
layer: int,
144 |
pos: int,
145 |
) -> Float[torch.Tensor, "d_ffn"]:
146 |
147 |
The content of the hidden layer right after the activation function was applied.
148 |
149 |
150 |
151 |
152 |
def neuron_output(
153 |
154 |
layer: int,
155 |
neuron: int,
156 |
) -> Float[torch.Tensor, "d_model"]:
157 |
158 |
Return the value that the given neuron adds to the residual stream. It's a raw
159 |
vector from the model parameters, no activation involved.
160 |
161 |
162 |
163 |
# ==================== Methods related to the attention ====================
164 |
165 |
166 |
def attention_matrix(
167 |
self, batch_i, layer: int, head: int
168 |
) -> Float[torch.Tensor, "query_pos key_pos"]:
169 |
170 |
Return a lower-triangular attention matrix.
171 |
172 |
173 |
174 |
175 |
def attention_output(
176 |
177 |
batch_i: int,
178 |
layer: int,
179 |
pos: int,
180 |
head: int,
181 |
) -> Float[torch.Tensor, "d_model"]:
182 |
183 |
Return what the given head at the given layer and pos added to the residual
184 |
185 |
186 |
187 |
188 |
189 |
def decomposed_attn(
190 |
self, batch_i: int, layer: int
191 |
) -> Float[torch.Tensor, "source target head d_model"]:
192 |
193 |
194 |
- source: index of token from the previous layer
195 |
- target: index of token on the current layer
196 |
The decomposed attention tells what vector from source representation was used
197 |
in order to contribute to the target representation.
198 |
199 |
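The `decomposed_ffn_out` docstring above says the result should equal neuron activations multiplied by neuron outputs. The following is a minimal plain-Python sketch of that relation for a single token, not the tool's implementation. The function name `decomposed_ffn_out_sketch` and the toy numbers are hypothetical, and the sketch assumes the FFN output bias is ignored.

```python
from typing import List


def decomposed_ffn_out_sketch(
    activations: List[float], neuron_outputs: List[List[float]]
) -> List[List[float]]:
    """Scale each neuron's output vector by that neuron's activation.

    activations: one value per hidden neuron (shape "hidden").
    neuron_outputs: one d_model-sized vector per neuron (shape "hidden d_model").
    """
    return [[a * w for w in row] for a, row in zip(activations, neuron_outputs)]


# Toy layer: 3 hidden neurons writing into a 2-dimensional residual stream.
activations = [0.5, 0.0, 2.0]
neuron_outputs = [[1.0, 2.0], [3.0, 4.0], [-1.0, 0.5]]

decomposed = decomposed_ffn_out_sketch(activations, neuron_outputs)

# Summing over neurons gives the vector the FFN adds to the residual stream
# (assumption: no output bias).
ffn_out = [sum(column) for column in zip(*decomposed)]
```

A neuron with zero activation contributes a zero vector, which is why the second row of `decomposed` vanishes regardless of its output weights.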
@@ -0,0 +1,5 @@
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from typing import Tuple
8 |
9 |
import einops
10 |
import torch
11 |
from jaxtyping import Float
12 |
from typeguard import typechecked
13 |
14 |
15 |
16 |
17 |
def get_contributions(
18 |
parts: torch.Tensor,
19 |
whole: torch.Tensor,
20 |
distance_norm: int = 1,
21 |
) -> torch.Tensor:
22 |
23 |
Compute contributions of the `parts` vectors into the `whole` vector.
24 |
25 |
Shapes of the tensors are as follows:
26 |
parts: p_1 ... p_k, v_1 ... v_n, d
27 |
whole: v_1 ... v_n, d
28 |
result: p_1 ... p_k, v_1 ... v_n
29 |
30 |
31 |
* `p_1 ... p_k`: dimensions for enumerating the parts
32 |
* `v_1 ... v_n`: dimensions listing the independent cases (batching),
33 |
* `d` is the dimension to compute the distances on.
34 |
35 |
The resulting contributions will be normalized so that
36 |
for each v_: sum(over p_ of result(p_, v_)) = 1.
37 |
38 |
EPS = 1e-5
39 |
40 |
k = len(parts.shape) - len(whole.shape)
41 |
assert k >= 0
42 |
assert parts.shape[k:] == whole.shape
43 |
bc_whole = whole.expand(parts.shape) # new dims p_1 ... p_k are added to the front
44 |
45 |
distance = torch.nn.functional.pairwise_distance(parts, bc_whole, p=distance_norm)
46 |
47 |
whole_norm = torch.norm(whole, p=distance_norm, dim=-1)
48 |
distance = (whole_norm - distance).clip(min=EPS)
49 |
50 |
sum = distance.sum(dim=tuple(range(k)), keepdim=True)
51 |
52 |
return distance / sum
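To make the arithmetic in `get_contributions` easier to follow, here is a dependency-free sketch of the same steps for a single representation: take the norm of the whole, subtract each part's distance to the whole, clip at a small epsilon, and normalize. The helper names `p_norm` and `contributions_sketch` are hypothetical. Unlike `torch.nn.functional.pairwise_distance`, this sketch does not add a small epsilon to the difference before taking the norm, so results may differ in the last digits.

```python
from typing import List, Sequence


def p_norm(vector: Sequence[float], p: int = 1) -> float:
    """The p-norm of a vector."""
    return sum(abs(x) ** p for x in vector) ** (1.0 / p)


def contributions_sketch(
    parts: Sequence[Sequence[float]],
    whole: Sequence[float],
    p: int = 1,
    eps: float = 1e-5,
) -> List[float]:
    """Contribution of each part to `whole`, normalized to sum to 1."""
    whole_norm = p_norm(whole, p)
    scores = []
    for part in parts:
        distance = p_norm([a - b for a, b in zip(part, whole)], p)
        # Parts that lie far from the whole get clipped to eps.
        scores.append(max(whole_norm - distance, eps))
    total = sum(scores)
    return [s / total for s in scores]


# Two neuron vectors plus the residual stream [1, 1] as the last part.
same_direction = contributions_sketch(
    [[1.0, 1.0], [2.0, 2.0], [1.0, 1.0]], [4.0, 4.0], p=2
)

# Two attention-head vectors plus the residual stream [2, 1] as the last part.
mixed = contributions_sketch(
    [[1.0, 1.0], [-1.0, 0.0], [2.0, 1.0]], [2.0, 2.0], p=2
)
```

The inputs are taken from the unit tests for this module, so `same_direction` should come out close to 0.25, 0.5 and 0.25, and `mixed` close to 0.43613, 0 and 0.56387.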
53 |
54 |
55 |
56 |
57 |
def get_contributions_with_one_off_part(
58 |
parts: torch.Tensor,
59 |
one_off: torch.Tensor,
60 |
whole: torch.Tensor,
61 |
distance_norm: int = 1,
62 |
) -> Tuple[torch.Tensor, torch.Tensor]:
63 |
64 |
Same as computing the contributions, but there is one additional part. That's useful
65 |
because we always have the residual stream as one of the parts.
66 |
67 |
See `get_contributions` documentation about `parts` and `whole` dimensions. The
68 |
`one_off` should have the same dimensions as `whole`.
69 |
70 |
Returns a pair consisting of
71 |
1. contributions tensor for the `parts`
72 |
2. contributions tensor for the `one_off` vector
73 |
74 |
assert one_off.shape == whole.shape
75 |
76 |
k = len(parts.shape) - len(whole.shape)
77 |
assert k >= 0
78 |
79 |
# Flatten the p_ dimensions, get contributions for the list, unflatten.
80 |
flat = parts.flatten(start_dim=0, end_dim=k - 1)
81 |
flat = torch.cat([flat, one_off.unsqueeze(0)])
82 |
contributions = get_contributions(flat, whole, distance_norm)
83 |
parts_contributions, one_off_contributions = torch.split(
84 |
contributions, flat.shape[0] - 1
85 |
86 |
return (
87 |
parts_contributions.unflatten(0, parts.shape[0:k]),
88 |
89 |
90 |
91 |
92 |
93 |
94 |
def get_attention_contributions(
95 |
resid_pre: Float[torch.Tensor, "batch pos d_model"],
96 |
resid_mid: Float[torch.Tensor, "batch pos d_model"],
97 |
decomposed_attn: Float[torch.Tensor, "batch pos key_pos head d_model"],
98 |
distance_norm: int = 1,
99 |
) -> Tuple[
100 |
Float[torch.Tensor, "batch pos key_pos head"],
101 |
Float[torch.Tensor, "batch pos"],
102 |
103 |
104 |
Returns a pair of
105 |
- a tensor of contributions of each token via each head
106 |
- the contribution of the residual stream.
107 |
108 |
109 |
# part dimensions | batch dimensions | vector dimension
110 |
# ----------------+------------------+-----------------
111 |
# key_pos, head | batch, pos | d_model
112 |
parts = einops.rearrange(
113 |
114 |
"batch pos key_pos head d_model -> key_pos head batch pos d_model",
115 |
116 |
attn_contribution, residual_contribution = get_contributions_with_one_off_part(
117 |
parts, resid_pre, resid_mid, distance_norm
118 |
119 |
return (
120 |
121 |
attn_contribution, "key_pos head batch pos -> batch pos key_pos head"
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
def get_mlp_contributions(
130 |
resid_mid: Float[torch.Tensor, "batch pos d_model"],
131 |
resid_post: Float[torch.Tensor, "batch pos d_model"],
132 |
mlp_out: Float[torch.Tensor, "batch pos d_model"],
133 |
distance_norm: int = 1,
134 |
) -> Tuple[Float[torch.Tensor, "batch pos"], Float[torch.Tensor, "batch pos"]]:
135 |
136 |
Returns a pair of (mlp, residual) contributions for each sentence and token.
137 |
138 |
139 |
contributions = get_contributions(
140 |
torch.stack((mlp_out, resid_mid)), resid_post, distance_norm
141 |
142 |
return contributions[0], contributions[1]
143 |
144 |
145 |
146 |
147 |
def get_decomposed_mlp_contributions(
148 |
resid_mid: Float[torch.Tensor, "d_model"],
149 |
resid_post: Float[torch.Tensor, "d_model"],
150 |
decomposed_mlp_out: Float[torch.Tensor, "hidden d_model"],
151 |
distance_norm: int = 1,
152 |
) -> Tuple[Float[torch.Tensor, "hidden"], float]:
153 |
154 |
Similar to `get_mlp_contributions`, but it takes the MLP output for each neuron of
155 |
the hidden layer and thus computes a contribution per neuron.
156 |
157 |
Doesn't contain batch and token dimensions for the sake of saving memory, but we may
consider adding them.
159 |
160 |
161 |
neuron_contributions, residual_contribution = get_contributions_with_one_off_part(
162 |
decomposed_mlp_out, resid_mid, resid_post, distance_norm
163 |
164 |
return neuron_contributions, residual_contribution.item()
165 |
166 |
167 |
168 |
def apply_threshold_and_renormalize(
169 |
threshold: float,
170 |
c_blocks: torch.Tensor,
171 |
c_residual: torch.Tensor,
172 |
) -> Tuple[torch.Tensor, torch.Tensor]:
173 |
174 |
Thresholding mechanism used in the original graphs paper. After the threshold is
175 |
applied, the remaining contributions are renormalized in order to sum up to 1 for
176 |
each representation.
177 |
178 |
threshold: The threshold.
179 |
c_residual: Contribution of the residual stream for each representation. This tensor
180 |
should contain 1 element per representation, i.e., its dimensions are all batch
181 |
182 |
c_blocks: Contributions of the blocks. Could be 1 block per representation, like
183 |
ffn, or heads*tokens blocks in case of attention. The shape of `c_residual`
184 |
must be a prefix of the shape of this tensor. The remaining dimensions are for
185 |
listing the blocks.
186 |
187 |
188 |
block_dims = len(c_blocks.shape)
189 |
resid_dims = len(c_residual.shape)
190 |
bound_dims = block_dims - resid_dims
191 |
assert bound_dims >= 0
192 |
assert c_blocks.shape[0:resid_dims] == c_residual.shape
193 |
194 |
c_blocks = c_blocks * (c_blocks > threshold)
195 |
c_residual = c_residual * (c_residual > threshold)
196 |
197 |
denom = c_residual + c_blocks.sum(dim=tuple(range(resid_dims, block_dims)))
198 |
return (
199 |
c_blocks / denom.reshape(denom.shape + (1,) * bound_dims),
200 |
c_residual / denom,
201 |
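The thresholding step can be illustrated without tensors. This is a rough sketch for the case of one batch dimension and one block dimension. The name `threshold_and_renormalize_sketch` is hypothetical, and the sketch does not guard against every contribution falling below the threshold, which would make the denominator zero.

```python
from typing import List, Sequence, Tuple


def threshold_and_renormalize_sketch(
    threshold: float,
    c_blocks: Sequence[Sequence[float]],
    c_residual: Sequence[float],
) -> Tuple[List[List[float]], List[float]]:
    """Zero out contributions at or below `threshold`, then renormalize each row."""
    out_blocks, out_residual = [], []
    for blocks, residual in zip(c_blocks, c_residual):
        kept_blocks = [b if b > threshold else 0.0 for b in blocks]
        kept_residual = residual if residual > threshold else 0.0
        # Whatever survives the threshold must sum to 1 per representation.
        denom = kept_residual + sum(kept_blocks)
        out_blocks.append([b / denom for b in kept_blocks])
        out_residual.append(kept_residual / denom)
    return out_blocks, out_residual


norm_blocks, norm_residual = threshold_and_renormalize_sketch(
    0.1, [[0.05, 0.15], [0.05, 0.05]], [0.8, 0.9]
)
```

With the inputs above, which mirror the module's unit test, the first representation keeps one block (0.15 / 0.95) next to the residual (0.8 / 0.95), and the second keeps only the residual.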
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from typing import List, Optional
8 |
9 |
import networkx as nx
10 |
import torch
11 |
12 |
import llm_transparency_tool.routes.contributions as contributions
13 |
from llm_transparency_tool.models.transparent_llm import TransparentLlm
14 |
15 |
16 |
class GraphBuilder:
17 |
18 |
Constructs the contributions graph with edges given one by one. The resulting graph
19 |
is a networkx graph that can be accessed via the `graph` field. It contains the
20 |
following types of nodes:
21 |
22 |
- X0_<token>: the original token.
23 |
- A<layer>_<token>: the residual stream after attention at the given layer for the
24 |
given token.
25 |
- M<layer>_<token>: the ffn block.
26 |
- I<layer>_<token>: the residual stream after the ffn block.
27 |
28 |
29 |
def __init__(self, n_layers: int, n_tokens: int):
30 |
self._n_layers = n_layers
31 |
self._n_tokens = n_tokens
32 |
33 |
self.graph = nx.DiGraph()
34 |
for layer in range(n_layers):
35 |
for token in range(n_tokens):
36 |
37 |
38 |
39 |
for token in range(n_tokens):
40 |
41 |
42 |
def get_output_node(self, token: int):
43 |
return f"I{self._n_layers - 1}_{token}"
44 |
45 |
def _add_edge(self, u: str, v: str, weight: float):
46 |
# TODO(igortufanov): Here we sum up weights for multi-edges. It happens with
47 |
# attention from the current token and the residual edge. Ideally these need to
48 |
# be 2 separate edges, but then we need to do a MultiGraph. Multigraph is fine,
49 |
# but when we try to traverse it, we face some NetworkX issue with EDGE_OK
50 |
# receiving 3 arguments instead of 2.
51 |
if self.graph.has_edge(u, v):
52 |
self.graph[u][v]["weight"] += weight
53 |
54 |
self.graph.add_edge(u, v, weight=weight)
55 |
56 |
def add_attention_edge(self, layer: int, token_from: int, token_to: int, w: float):
57 |
58 |
f"I{layer-1}_{token_from}" if layer > 0 else f"X0_{token_from}",
59 |
60 |
61 |
62 |
63 |
def add_residual_to_attn(self, layer: int, token: int, w: float):
64 |
65 |
f"I{layer-1}_{token}" if layer > 0 else f"X0_{token}",
66 |
67 |
68 |
69 |
70 |
def add_ffn_edge(self, layer: int, token: int, w: float):
71 |
self._add_edge(f"A{layer}_{token}", f"M{layer}_{token}", w)
72 |
self._add_edge(f"M{layer}_{token}", f"I{layer}_{token}", w)
73 |
74 |
def add_residual_to_ffn(self, layer: int, token: int, w: float):
75 |
self._add_edge(f"A{layer}_{token}", f"I{layer}_{token}", w)
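The TODO in `_add_edge` notes that weights are summed when the same pair of nodes is connected twice. The sketch below shows that behavior with a plain dictionary instead of a NetworkX graph. The class `EdgeWeights` and the weights are hypothetical. The node names follow the naming scheme in the class docstring, under the assumption that both the attention edge and the residual edge end at `A<layer>_<token>`.

```python
from collections import defaultdict
from typing import DefaultDict, Tuple


class EdgeWeights:
    """Accumulates weights per (source, target) pair, like a simple DiGraph."""

    def __init__(self) -> None:
        self.weights: DefaultDict[Tuple[str, str], float] = defaultdict(float)

    def add_edge(self, u: str, v: str, weight: float) -> None:
        # A repeated edge adds to the existing weight instead of replacing it.
        self.weights[(u, v)] += weight


edges = EdgeWeights()
edges.add_edge("X0_0", "A0_1", 0.25)  # attention from token 0 to token 1
edges.add_edge("X0_1", "A0_1", 0.25)  # attention from token 1 to itself
edges.add_edge("X0_1", "A0_1", 0.5)   # residual connection for token 1

merged = edges.weights[("X0_1", "A0_1")]
```

Here the self-attention edge and the residual edge for token 1 collapse into one edge with weight 0.75, which is the trade-off the TODO describes: the graph stays simple to traverse, but the two sources can no longer be told apart.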
76 |
77 |
78 |
79 |
def build_full_graph(
80 |
model: TransparentLlm,
81 |
batch_i: int = 0,
82 |
renormalizing_threshold: Optional[float] = None,
83 |
) -> nx.Graph:
84 |
85 |
Build the contribution graph for all blocks of the model and all tokens.
86 |
87 |
model: The transparent LLM that has already run the inference.
88 |
batch_i: Which sentence to use from the batch that was given to the model.
89 |
renormalizing_threshold: If specified, will apply renormalizing thresholding to the
90 |
contributions. All contributions below the threshold will be erased and the rest
91 |
will be renormalized.
92 |
93 |
n_layers = model.model_info().n_layers
94 |
n_tokens = model.tokens()[batch_i].shape[0]
95 |
96 |
builder = GraphBuilder(n_layers, n_tokens)
97 |
98 |
for layer in range(n_layers):
99 |
c_attn, c_resid_attn = contributions.get_attention_contributions(
100 |
101 |
102 |
decomposed_attn=model.decomposed_attn(batch_i, layer).unsqueeze(0),
103 |
104 |
if renormalizing_threshold is not None:
105 |
c_attn, c_resid_attn = contributions.apply_threshold_and_renormalize(
106 |
renormalizing_threshold, c_attn, c_resid_attn
107 |
108 |
for token_from in range(n_tokens):
109 |
for token_to in range(n_tokens):
110 |
# Sum attention contributions over heads.
111 |
c = c_attn[batch_i, token_to, token_from].sum().item()
112 |
builder.add_attention_edge(layer, token_from, token_to, c)
113 |
for token in range(n_tokens):
114 |
115 |
layer, token, c_resid_attn[batch_i, token].item()
116 |
117 |
118 |
c_ffn, c_resid_ffn = contributions.get_mlp_contributions(
119 |
120 |
121 |
122 |
123 |
if renormalizing_threshold is not None:
124 |
c_ffn, c_resid_ffn = contributions.apply_threshold_and_renormalize(
125 |
renormalizing_threshold, c_ffn, c_resid_ffn
126 |
127 |
for token in range(n_tokens):
128 |
builder.add_ffn_edge(layer, token, c_ffn[batch_i, token].item())
129 |
130 |
layer, token, c_resid_ffn[batch_i, token].item()
131 |
132 |
133 |
return builder.graph
134 |
135 |
136 |
def build_paths_to_predictions(
137 |
graph: nx.Graph,
138 |
n_layers: int,
139 |
n_tokens: int,
140 |
starting_tokens: List[int],
141 |
threshold: float,
142 |
) -> List[nx.Graph]:
143 |
144 |
Given the full graph, this function returns only the trees leading to the specified
145 |
tokens. Edges with weight below `threshold` will be ignored.
146 |
147 |
builder = GraphBuilder(n_layers, n_tokens)
148 |
149 |
rgraph = graph.reverse()
150 |
search_graph = nx.subgraph_view(
151 |
rgraph, filter_edge=lambda u, v: rgraph[u][v]["weight"] > threshold
152 |
153 |
154 |
result = []
155 |
for start in starting_tokens:
156 |
assert start < n_tokens
157 |
assert start >= 0
158 |
edges = nx.edge_dfs(search_graph, source=builder.get_output_node(start))
159 |
tree = search_graph.edge_subgraph(edges)
160 |
# Reverse the edges because the dfs was going from upper layer downwards.
161 |
162 |
163 |
return result
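As a rough illustration of what `build_paths_to_predictions` does, the sketch below walks a weighted edge list backwards from an output node and keeps only edges above the threshold. It is a stand-in written without NetworkX, not the function's actual traversal (the original relies on `nx.edge_dfs` over a reversed, filtered view). The function name `edges_leading_to` and the toy graph are hypothetical.

```python
from typing import Dict, List, Set, Tuple

Edge = Tuple[str, str]


def edges_leading_to(
    weights: Dict[Edge, float], output_node: str, threshold: float
) -> List[Edge]:
    """Collect edges on paths into `output_node`, ignoring weak edges."""
    # Index predecessors, dropping edges at or below the threshold.
    incoming: Dict[str, List[str]] = {}
    for (u, v), w in weights.items():
        if w > threshold:
            incoming.setdefault(v, []).append(u)

    kept: List[Edge] = []
    visited: Set[str] = {output_node}
    stack = [output_node]
    while stack:
        node = stack.pop()
        for pred in incoming.get(node, []):
            kept.append((pred, node))  # stored in forward direction
            if pred not in visited:
                visited.add(pred)
                stack.append(pred)
    return kept


toy_weights = {
    ("X0_0", "A0_0"): 1.0,
    ("A0_0", "I0_0"): 1.0,
    ("X0_0", "A0_1"): 0.3,
    ("X0_1", "A0_1"): 0.7,
    ("A0_1", "M0_1"): 0.05,
    ("M0_1", "I0_1"): 0.05,
    ("A0_1", "I0_1"): 0.9,
}

tree_edges = sorted(edges_leading_to(toy_weights, "I0_1", threshold=0.1))
```

In this toy graph the FFN branch through `M0_1` falls below the threshold and is dropped, and the nodes belonging only to token 0's own output (`A0_0`, `I0_0`) are never reached from `I0_1`.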
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from dataclasses import dataclass
8 |
from enum import Enum
9 |
from typing import List, Optional
10 |
11 |
12 |
class NodeType(Enum):
13 |
AFTER_ATTN = "after_attn"
14 |
AFTER_FFN = "after_ffn"
15 |
FFN = "ffn"
16 |
ORIGINAL = "original" # The original tokens
17 |
18 |
19 |
def _format_block_hierachy_string(blocks: List[str]) -> str:
20 |
return " ▸ ".join(blocks)
21 |
22 |
23 |
24 |
class GraphNode:
25 |
layer: int
26 |
token: int
27 |
type: NodeType
28 |
29 |
def is_in_residual_stream(self) -> bool:
30 |
return self.type in [NodeType.AFTER_ATTN, NodeType.AFTER_FFN]
31 |
32 |
def get_residual_predecessor(self) -> Optional["GraphNode"]:
33 |
34 |
Get another graph node which points to the state of the residual stream before
35 |
this node.
36 |
37 |
Return None if the current representation is the first one in the residual stream.
38 |
39 |
scheme = {
40 |
NodeType.AFTER_ATTN: GraphNode(
41 |
layer=max(self.layer - 1, 0),
42 |
43 |
type=NodeType.AFTER_FFN if self.layer > 0 else NodeType.ORIGINAL,
44 |
45 |
NodeType.AFTER_FFN: GraphNode(
46 |
47 |
48 |
49 |
50 |
NodeType.FFN: GraphNode(
51 |
52 |
53 |
54 |
55 |
NodeType.ORIGINAL: None,
56 |
57 |
node = scheme[self.type]
58 |
if node is None or node.layer < 0:
59 |
return None
60 |
return node
61 |
62 |
def get_name(self) -> str:
63 |
return _format_block_hierachy_string(
64 |
[f"L{self.layer}", f"T{self.token}", str(self.type.value)]
65 |
66 |
67 |
def get_predecessor_block_name(self) -> str:
68 |
69 |
Return the name of the block standing between current node and its predecessor
70 |
in the residual stream.
71 |
72 |
scheme = {
73 |
NodeType.AFTER_ATTN: [f"L{self.layer}", "attn"],
74 |
NodeType.AFTER_FFN: [f"L{self.layer}", "ffn"],
75 |
NodeType.FFN: [f"L{self.layer}", "ffn"],
76 |
NodeType.ORIGINAL: ["Nothing"],
77 |
78 |
return _format_block_hierachy_string(scheme[self.type])
79 |
80 |
def get_head_name(self, head: Optional[int]) -> str:
81 |
path = [f"L{self.layer}", "attn"]
82 |
if head is not None:
83 |
84 |
return _format_block_hierachy_string(path)
85 |
86 |
def get_neuron_name(self, neuron: Optional[int]) -> str:
87 |
path = [f"L{self.layer}", "ffn"]
88 |
if neuron is not None:
89 |
90 |
return _format_block_hierachy_string(path)
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import unittest
8 |
from typing import Any, List
9 |
10 |
import torch
11 |
12 |
import llm_transparency_tool.routes.contributions as contributions
13 |
14 |
15 |
class TestContributions(unittest.TestCase):
16 |
def setUp(self):
17 |
18 |
19 |
self.eps = 1e-4
20 |
21 |
# It may be useful to run the test on GPU in case there are any issues with
22 |
# creating temporary tensors on another device. But turn this off by default.
23 |
self.test_on_gpu = False
24 |
25 |
self.device = "cuda" if self.test_on_gpu else "cpu"
26 |
27 |
self.batch = 4
28 |
self.tokens = 5
29 |
self.heads = 6
30 |
self.d_model = 10
31 |
32 |
self.decomposed_attn = torch.rand(
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
self.mlp_out = torch.rand(
41 |
self.batch, self.tokens, self.d_model, device=self.device
42 |
43 |
self.resid_pre = torch.rand(
44 |
self.batch, self.tokens, self.d_model, device=self.device
45 |
46 |
self.resid_mid = torch.rand(
47 |
self.batch, self.tokens, self.d_model, device=self.device
48 |
49 |
self.resid_post = torch.rand(
50 |
self.batch, self.tokens, self.d_model, device=self.device
51 |
52 |
53 |
def _assert_tensor_eq(self, t: torch.Tensor, expected: List[Any]):
54 |
55 |
torch.isclose(t, torch.Tensor(expected), atol=self.eps).all(),
56 |
57 |
58 |
59 |
def test_mlp_contributions(self):
60 |
mlp_out = torch.tensor([[[1.0, 1.0]]])
61 |
resid_mid = torch.tensor([[[0.0, 0.0]]])
62 |
resid_post = torch.tensor([[[1.0, 1.0]]])
63 |
64 |
c_mlp, c_residual = contributions.get_mlp_contributions(
65 |
resid_mid, resid_post, mlp_out
66 |
67 |
self.assertAlmostEqual(c_mlp.item(), 1.0, delta=self.eps)
68 |
self.assertAlmostEqual(c_residual.item(), 0.0, delta=self.eps)
69 |
70 |
def test_decomposed_attn_contributions(self):
71 |
resid_pre = torch.tensor([[[2.0, 1.0]]])
72 |
resid_mid = torch.tensor([[[2.0, 2.0]]])
73 |
decomposed_attn = torch.tensor(
74 |
75 |
76 |
77 |
78 |
[1.0, 1.0],
79 |
[-1.0, 0.0],
80 |
81 |
82 |
83 |
84 |
85 |
86 |
c_attn, c_residual = contributions.get_attention_contributions(
87 |
resid_pre, resid_mid, decomposed_attn, distance_norm=2
88 |
89 |
self._assert_tensor_eq(c_attn, [[[[0.43613, 0]]]])
90 |
self.assertAlmostEqual(c_residual.item(), 0.56387, delta=self.eps)
91 |
92 |
def test_decomposed_mlp_contributions(self):
93 |
pre = torch.tensor([10.0, 10.0])
94 |
post = torch.tensor([-10.0, 10.0])
95 |
neuron_impacts = torch.tensor(
96 |
97 |
[0.0, 1.0],
98 |
[1.0, 0.0],
99 |
[-21.0, -1.0],
100 |
101 |
102 |
c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
103 |
pre, post, neuron_impacts, distance_norm=2
104 |
105 |
# A bit counter-intuitive, but the only vector pointing from 0 towards the
106 |
# output is the first one.
107 |
self._assert_tensor_eq(c_mlp, [1, 0, 0])
108 |
self.assertAlmostEqual(c_residual, 0, delta=self.eps)
109 |
110 |
def test_decomposed_mlp_contributions_single_direction(self):
111 |
pre = torch.tensor([1.0, 1.0])
112 |
post = torch.tensor([4.0, 4.0])
113 |
neuron_impacts = torch.tensor(
114 |
115 |
[1.0, 1.0],
116 |
[2.0, 2.0],
117 |
118 |
119 |
c_mlp, c_residual = contributions.get_decomposed_mlp_contributions(
120 |
pre, post, neuron_impacts, distance_norm=2
121 |
122 |
self._assert_tensor_eq(c_mlp, [0.25, 0.5])
123 |
self.assertAlmostEqual(c_residual, 0.25, delta=self.eps)
124 |
125 |
def test_attention_contributions_shape(self):
126 |
c_attn, c_residual = contributions.get_attention_contributions(
127 |
self.resid_pre, self.resid_mid, self.decomposed_attn
128 |
129 |
130 |
list(c_attn.shape), [self.batch, self.tokens, self.tokens, self.heads]
131 |
132 |
self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
133 |
134 |
def test_mlp_contributions_shape(self):
135 |
c_mlp, c_residual = contributions.get_mlp_contributions(
136 |
self.resid_mid, self.resid_post, self.mlp_out
137 |
138 |
self.assertEqual(list(c_mlp.shape), [self.batch, self.tokens])
139 |
self.assertEqual(list(c_residual.shape), [self.batch, self.tokens])
140 |
141 |
def test_renormalizing_threshold(self):
142 |
c_blocks = torch.Tensor([[0.05, 0.15], [0.05, 0.05]])
143 |
c_residual = torch.Tensor([0.8, 0.9])
144 |
norm_blocks, norm_residual = contributions.apply_threshold_and_renormalize(
145 |
0.1, c_blocks, c_residual
146 |
147 |
self._assert_tensor_eq(norm_blocks, [[0.0, 0.157894], [0.0, 0.0]])
148 |
self._assert_tensor_eq(norm_residual, [0.842105, 1.0])
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import argparse
8 |
from dataclasses import dataclass, field
9 |
from typing import Dict, List, Optional, Tuple
10 |
11 |
import networkx as nx
12 |
import pandas as pd
13 |
14 |
import plotly.graph_objects as go
15 |
import streamlit as st
16 |
import streamlit_extras.row as st_row
17 |
import torch
18 |
from jaxtyping import Float
19 |
from torch.amp import autocast
20 |
from transformers import HfArgumentParser
21 |
22 |
import llm_transparency_tool.components
23 |
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
24 |
import llm_transparency_tool.routes.contributions as contributions
25 |
import llm_transparency_tool.routes.graph
26 |
from llm_transparency_tool.models.transparent_llm import TransparentLlm
27 |
from llm_transparency_tool.routes.graph_node import NodeType
28 |
from llm_transparency_tool.server.graph_selection import (
29 |
30 |
31 |
32 |
33 |
from llm_transparency_tool.server.styles import (
34 |
35 |
36 |
37 |
38 |
39 |
from llm_transparency_tool.server.utils import (
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
from llm_transparency_tool.server.monitor import SystemMonitor
49 |
50 |
from networkx.classes.digraph import DiGraph
51 |
52 |
53 |
54 |
55 |
nx.Graph: id,
56 |
DiGraph: id
57 |
58 |
59 |
def cached_build_paths_to_predictions(
60 |
graph: nx.Graph,
61 |
n_layers: int,
62 |
n_tokens: int,
63 |
starting_tokens: List[int],
64 |
threshold: float,
65 |
66 |
return llm_transparency_tool.routes.graph.build_paths_to_predictions(
67 |
graph, n_layers, n_tokens, starting_tokens, threshold
68 |
69 |
70 |
71 |
72 |
TransformerLensTransparentLlm: id
73 |
74 |
75 |
def cached_run_inference_and_populate_state(
76 |
77 |
78 |
79 |
stateful_model = stateless_model.copy()
80 |
81 |
return stateful_model
82 |
83 |
84 |
85 |
class LlmViewerConfig:
86 |
debug: bool = field(
87 |
88 |
metadata={"help": "Show debugging information, like the time profile."},
89 |
90 |
91 |
preloaded_dataset_filename: Optional[str] = field(
92 |
93 |
metadata={"help": "The name of the text file to load the lines from."},
94 |
95 |
96 |
demo_mode: bool = field(
97 |
98 |
metadata={"help": "Whether the app should be in the demo mode."},
99 |
100 |
101 |
allow_loading_dataset_files: bool = field(
102 |
103 |
metadata={"help": "Whether the app should be able to load the dataset files on the server side."},
104 |
105 |
106 |
max_user_string_length: Optional[int] = field(
107 |
108 |
"help": "Limit for the length of user-provided sentences (in characters), or None if there is no limit."
110 |
111 |
112 |
113 |
models: Dict[str, str] = field(
114 |
115 |
"help": "Locations of models which are stored locally. Dictionary: official "
"HuggingFace name -> path to dir. If None is specified, the model will be "
"downloaded from HuggingFace."
119 |
120 |
121 |
122 |
default_model: str = field(
123 |
124 |
metadata={"help": "The model to load once the UI is started."},
125 |
126 |
127 |
128 |
class App:
129 |
_stateful_model: TransparentLlm = None
130 |
render_settings = RenderSettings()
131 |
_graph: Optional[nx.Graph] = None
132 |
_contribution_threshold: float = 0.0
133 |
_renormalize_after_threshold: bool = False
134 |
_normalize_before_unembedding: bool = True
135 |
136 |
137 |
def stateful_model(self) -> TransparentLlm:
138 |
return self._stateful_model
139 |
140 |
def __init__(self, config: LlmViewerConfig):
141 |
self._config = config
142 |
143 |
st.markdown(margins_css, unsafe_allow_html=True)
144 |
145 |
def _get_representation(self, node: Optional[UiGraphNode]) -> Optional[Float[torch.Tensor, "d_model"]]:
146 |
if node is None:
147 |
return None
148 |
fn = {
149 |
NodeType.AFTER_ATTN: self.stateful_model.residual_after_attn,
150 |
NodeType.AFTER_FFN: self.stateful_model.residual_out,
151 |
NodeType.FFN: None,
152 |
NodeType.ORIGINAL: self.stateful_model.residual_in,
153 |
154 |
return fn[node.type](node.layer)[B0][node.token]
155 |
156 |
def draw_model_info(self):
157 |
info = self.stateful_model.model_info().__dict__
158 |
df = pd.DataFrame(
159 |
data=[str(x) for x in info.values()],
160 |
161 |
columns=["Model parameter"],
162 |
163 |
st.dataframe(df, use_container_width=False)
164 |
165 |
def draw_dataset_selection(self) -> int:
166 |
def update_dataset(filename: Optional[str]):
167 |
dataset = load_dataset(filename) if filename is not None else []
168 |
st.session_state["dataset"] = dataset
169 |
st.session_state["dataset_file"] = filename
170 |
171 |
if "dataset" not in st.session_state:
172 |
173 |
174 |
175 |
if not self._config.demo_mode:
176 |
if self._config.allow_loading_dataset_files:
177 |
row_f = st_row.row([2, 1], vertical_align="bottom")
178 |
filename = row_f.text_input("Dataset", value=st.session_state.dataset_file or "")
179 |
if row_f.button("Load"):
180 |
181 |
row_s = st_row.row([2, 1], vertical_align="bottom")
182 |
new_sentence = row_s.text_input("New sentence")
183 |
new_sentence_added = False
184 |
185 |
if row_s.button("Add"):
186 |
max_len = self._config.max_user_string_length
187 |
n = len(new_sentence)
188 |
if max_len is None or n <= max_len:
189 |
190 |
new_sentence_added = True
191 |
st.session_state.sentence_selector = new_sentence
192 |
193 |
st.warning(f"Sentence length {n} is larger than the configured limit of {max_len}")
194 |
195 |
sentences = st.session_state.dataset
196 |
selection = st.selectbox(
197 |
198 |
199 |
index=len(sentences) - 1,
200 |
201 |
202 |
return selection
203 |
204 |
def _unembed(
205 |
206 |
representation: torch.Tensor,
207 |
) -> torch.Tensor:
208 |
return self.stateful_model.unembed(representation, normalize=self._normalize_before_unembedding)
209 |
210 |
def draw_graph(self, contribution_threshold: float) -> Optional[GraphSelection]:
211 |
tokens = self.stateful_model.tokens()[B0]
212 |
n_tokens = tokens.shape[0]
213 |
model_info = self.stateful_model.model_info()
214 |
215 |
graphs = cached_build_paths_to_predictions(
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
return llm_transparency_tool.components.contribution_graph(
224 |
225 |
226 |
227 |
228 |
229 |
230 |
def draw_token_matrix(
231 |
232 |
values: Float[torch.Tensor, "t t"],
233 |
tokens: List[str],
234 |
value_name: str,
235 |
title: str,
236 |
237 |
assert values.shape[0] == len(tokens)
238 |
labels = {
239 |
"x": "<b>src</b>",
240 |
"y": "<b>tgt</b>",
241 |
"color": value_name,
242 |
243 |
244 |
captions = [f"({i}){t}" for i, t in enumerate(tokens)]
245 |
246 |
fig =
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
l=50, # left margin
259 |
r=0, # right margin
260 |
b=100, # bottom margin
261 |
t=100, # top margin
262 |
# pad=10 # padding
263 |
264 |
265 |
266 |
267 |
268 |
269 |
st.plotly_chart(fig, use_container_width=True, theme=None)
270 |
271 |
def draw_attn_info(self, edge: UiGraphEdge, container_attention_map) -> Optional[int]:
272 |
273 |
Returns: the index of the selected head.
274 |
275 |
276 |
n_heads = self.stateful_model.model_info().n_heads
277 |
278 |
layer =
279 |
280 |
head_contrib, _ = contributions.get_attention_contributions(
281 |
282 |
283 |
decomposed_attn=self.stateful_model.decomposed_attn(B0, layer).unsqueeze(0),
284 |
285 |
286 |
# [batch pos key_pos head] -> [head]
287 |
flat_contrib = head_contrib[0, edge.target.token, edge.source.token, :]
288 |
assert flat_contrib.shape[0] == n_heads, f"{flat_contrib.shape} vs {n_heads}"
289 |
290 |
selected_head = llm_transparency_tool.components.selector(
291 |
items=[f"H{h}" if h >= 0 else "All" for h in range(-1, n_heads)],
292 |
indices=range(-1, n_heads),
293 |
temperatures=[sum(flat_contrib).item()] + flat_contrib.tolist(),
294 |
295 |
key=f"head_selector_layer_{layer}" #_from_tok_{edge.source.token}_to_tok_{}",
296 |
297 |
298 |
if selected_head == -1 or selected_head is None:
299 |
# selected_head = None
300 |
selected_head = flat_contrib.argmax().item()
301 |
print('****\n' * 3 + f"selected_head: {selected_head}" + '\n****\n' * 3)
302 |
303 |
# Draw attention matrix and contributions for the selected head.
304 |
if selected_head is not None:
305 |
tokens = [
306 |
string_to_display(s) for s in self.stateful_model.tokens_to_strings(self.stateful_model.tokens()[B0])
307 |
308 |
309 |
with container_attention_map:
310 |
attn_container, contrib_container = st.columns([1, 1])
311 |
with attn_container:
312 |
attn = self.stateful_model.attention_matrix(B0, layer, selected_head)
313 |
314 |
315 |
316 |
317 |
f"Attention map L{layer} H{selected_head}",
318 |
319 |
with contrib_container:
320 |
contrib = head_contrib[B0, :, :, selected_head]
321 |
322 |
323 |
324 |
325 |
f"Contribution map L{layer} H{selected_head}",
326 |
327 |
328 |
return selected_head
329 |
330 |
def draw_ffn_info(self, node: UiGraphNode) -> Optional[int]:
331 |
332 |
Returns: the index of the selected neuron.
333 |
334 |
335 |
resid_mid = self.stateful_model.residual_after_attn(node.layer)[B0][node.token]
336 |
resid_post = self.stateful_model.residual_out(node.layer)[B0][node.token]
337 |
decomposed_ffn = self.stateful_model.decomposed_ffn_out(B0, node.layer, node.token)
338 |
c_ffn, _ = contributions.get_decomposed_mlp_contributions(resid_mid, resid_post, decomposed_ffn)
339 |
340 |
top_values, top_i = c_ffn.sort(descending=True)
341 |
n = min(self.render_settings.n_top_neurons, c_ffn.shape[0])
342 |
top_neurons = top_i[0:n].tolist()
343 |
344 |
selected_neuron = llm_transparency_tool.components.selector(
345 |
items=[f"{top_neurons[i]}" if i >= 0 else "All" for i in range(-1, n)],
346 |
indices=range(-1, n),
347 |
temperatures=[0.0] + top_values[0:n].tolist(),
348 |
349 |
350 |
351 |
if selected_neuron is None:
352 |
selected_neuron = -1
353 |
selected_neuron = None if selected_neuron == -1 else top_neurons[selected_neuron]
354 |
355 |
return selected_neuron
356 |
357 |
def _draw_token_table(
358 |
359 |
n_top: int,
360 |
n_bottom: int,
361 |
representation: torch.Tensor,
362 |
predecessor: Optional[torch.Tensor] = None,
363 |
364 |
n_total = n_top + n_bottom
365 |
366 |
logits = self._unembed(representation)
367 |
n_vocab = logits.shape[0]
368 |
scores, indices = torch.topk(logits, n_top, largest=True)
369 |
positions = list(range(n_top))
370 |
        if n_bottom > 0:
            low_scores, low_indices = torch.topk(logits, n_bottom, largest=False)
            # Append the lowest-scoring tokens so that the very worst one comes last.
            indices = torch.cat((indices, low_indices.flip(0)))
            scores = torch.cat((scores, low_scores.flip(0)))
            positions += range(n_vocab - n_bottom, n_vocab)
376 |
377 |
tokens = [string_to_display(w) for w in self.stateful_model.tokens_to_strings(indices)]
378 |
        if predecessor is not None:
            pre_logits = self._unembed(predecessor)
            _, sorted_pre_indices = pre_logits.sort(descending=True)
            pre_indices_dict = {index: pos for pos, index in enumerate(sorted_pre_indices.tolist())}
            old_positions = [pre_indices_dict[i] for i in indices.tolist()]

            def pos_gain_string(pos, old_pos):
                if pos == old_pos:
                    return ""
                # A larger position means the token dropped in the ranking.
                sign = "↓" if pos > old_pos else "↑"
                return f"({sign}{abs(pos - old_pos)})"

            position_strings = [f"{i} {pos_gain_string(i, old_i)}" for (i, old_i) in zip(positions, old_positions)]
        else:
            position_strings = [str(pos) for pos in positions]

        def pos_gain_color(s):
            color = "black"
            if isinstance(s, str):
                if "↓" in s:
                    color = "red"
                if "↑" in s:
                    color = "green"
            return f"color: {color}"
403 |
404 |
top_df = pd.DataFrame(
405 |
data=zip(position_strings, tokens, scores.tolist()),
406 |
columns=["Pos", "Token", "Score"],
407 |
408 |
409 |
410 |
411 |
412 |
413 |
cmap=logits_color_map(positive_and_negative=n_bottom > 0),
414 |
415 |
416 |
417 |
height=self.render_settings.table_cell_height * (n_total + 1),
418 |
419 |
420 |
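The rank-change labelling used in the token table can be sketched without torch or Streamlit. This is a minimal illustration, not the tool's implementation: `rank_change_labels` and its list-based inputs are hypothetical, and the `↑`/`↓` markers are assumed to mean "moved up" and "moved down" in the ranking.

```python
from typing import List


def rank_change_labels(new_scores: List[float], old_scores: List[float], n_top: int) -> List[str]:
    """Label the n_top best-scoring items with their rank and rank change."""
    # Rank 0 is the highest score; sorted() is stable, so ties keep index order.
    new_order = sorted(range(len(new_scores)), key=lambda i: -new_scores[i])
    old_order = sorted(range(len(old_scores)), key=lambda i: -old_scores[i])
    old_rank = {item: rank for rank, item in enumerate(old_order)}

    labels = []
    for rank, item in enumerate(new_order[:n_top]):
        delta = rank - old_rank[item]
        if delta == 0:
            labels.append(f"{rank}")
        else:
            # A positive delta means the item now sits lower in the ranking.
            arrow = "↓" if delta > 0 else "↑"
            labels.append(f"{rank} ({arrow}{abs(delta)})")
    return labels
```

A lookup dictionary from item to old rank keeps the comparison linear in the number of displayed rows, which mirrors the `pre_indices_dict` idea in the method above.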
421 |
def draw_token_dynamics(self, representation: torch.Tensor, block_name: str) -> None:
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
def draw_top_tokens(
431 |
432 |
node: UiGraphNode,
433 |
434 |
435 |
) -> None:
436 |
pre_node = node.get_residual_predecessor()
437 |
if pre_node is None:
438 |
439 |
440 |
representation = self._get_representation(node)
441 |
predecessor = self._get_representation(pre_node)
442 |
443 |
with container_top_tokens:
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
if container_token_dynamics is not None:
452 |
with container_token_dynamics:
453 |
self.draw_token_dynamics(representation - predecessor, node.get_predecessor_block_name())
454 |
455 |
def draw_attention_dynamics(self, node: UiGraphNode, head: Optional[int]):
456 |
block_name = node.get_head_name(head)
457 |
block_output = (
458 |
self.stateful_model.attention_output_per_head(B0, node.layer, node.token, head)
459 |
if head is not None
460 |
else self.stateful_model.attention_output(B0, node.layer, node.token)
461 |
462 |
self.draw_token_dynamics(block_output, block_name)
463 |
464 |
def draw_ffn_dynamics(self, node: UiGraphNode, neuron: Optional[int]):
465 |
block_name = node.get_neuron_name(neuron)
466 |
block_output = (
467 |
self.stateful_model.neuron_output(node.layer, neuron)
468 |
if neuron is not None
469 |
else self.stateful_model.ffn_out(node.layer)[B0][node.token]
470 |
471 |
self.draw_token_dynamics(block_output, block_name)
472 |
473 |
def draw_precision_controls(self, device: str) -> Tuple[torch.dtype, bool]:
474 |
475 |
Draw fp16/fp32 switch and AMP control.
476 |
477 |
return: The selected precision and whether AMP should be enabled.
478 |
479 |
480 |
if device == "cpu":
481 |
dtype = torch.float32
482 |
483 |
dtype = st.selectbox(
484 |
485 |
[torch.float16, torch.bfloat16, torch.float32],
486 |
487 |
488 |
489 |
amp_enabled = dtype != torch.float32
490 |
491 |
return dtype, amp_enabled
492 |
493 |
def draw_controls(self):
494 |
# model_container, data_container = st.columns([1, 1])
495 |
with st.sidebar.expander("Model", expanded=True):
496 |
list_of_devices = possible_devices()
497 |
if len(list_of_devices) > 1:
498 |
self.device = st.selectbox(
499 |
500 |
501 |
502 |
503 |
504 |
self.device = list_of_devices[0]
505 |
506 |
self.dtype, self.amp_enabled = self.draw_precision_controls(self.device)
507 |
508 |
model_list = list(self._config.models)
509 |
default_choice = model_list.index(self._config.default_model)
510 |
511 |
self.model_name = st.selectbox(
512 |
513 |
514 |
515 |
516 |
517 |
if self.model_name:
518 |
self._stateful_model = load_model(
519 |
520 |
521 |
522 |
523 |
524 |
self.model_key = self.model_name # TODO maybe something else?
525 |
526 |
527 |
self.sentence = self.draw_dataset_selection()
528 |
529 |
with st.sidebar.expander("Graph", expanded=True):
530 |
self._contribution_threshold = st.slider(
531 |
532 |
533 |
534 |
535 |
536 |
label="Contribution threshold",
537 |
538 |
self._renormalize_after_threshold = st.checkbox("Renormalize after threshold", value=True)
539 |
self._normalize_before_unembedding = st.checkbox("Normalize before unembedding", value=True)
540 |
541 |
def run_inference(self):
542 |
543 |
with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
544 |
self._stateful_model = cached_run_inference_and_populate_state(self.stateful_model, [self.sentence])
545 |
546 |
with autocast(enabled=self.amp_enabled, device_type="cuda", dtype=self.dtype):
547 |
self._graph = get_contribution_graph(
548 |
549 |
550 |
551 |
(self._contribution_threshold if self._renormalize_after_threshold else 0.0),
552 |
553 |
554 |
def draw_graph_and_selection(
555 |
556 |
) -> None:
557 |
558 |
559 |
560 |
) = st.columns(self.render_settings.column_proportions)
561 |
562 |
container_graph_left, container_graph_right = container_graph.columns([5, 1])
563 |
564 |
container_graph_left.write('##### Graph')
565 |
heads_placeholder = container_graph_right.empty()
566 |
heads_placeholder.write('##### Blocks')
567 |
container_graph_right_used = False
568 |
569 |
container_top_tokens, container_token_dynamics = container_tokens.columns([1, 1])
570 |
container_top_tokens.write('##### Top Tokens')
571 |
container_top_tokens_used = False
572 |
container_token_dynamics.write('##### Promoted Tokens')
573 |
container_token_dynamics_used = False
574 |
575 |
576 |
577 |
if self.sentence is None:
578 |
579 |
580 |
with container_graph_left:
581 |
selection = self.draw_graph(self._contribution_threshold if not self._renormalize_after_threshold else 0.0)
582 |
583 |
if selection is None:
584 |
585 |
586 |
node = selection.node
587 |
edge = selection.edge
588 |
        if edge is not None and edge.target.type == NodeType.AFTER_ATTN:
            with container_graph_right:
                container_graph_right_used = True
                heads_placeholder.write('##### Heads')
                head = self.draw_attn_info(edge, container_graph)
            with container_token_dynamics:
                self.draw_attention_dynamics(edge.target, head)
                container_token_dynamics_used = True
        elif node is not None and node.type == NodeType.FFN:
            with container_graph_right:
                container_graph_right_used = True
                heads_placeholder.write('##### Neurons')
                neuron = self.draw_ffn_info(node)
            with container_token_dynamics:
                self.draw_ffn_dynamics(node, neuron)
                container_token_dynamics_used = True
605 |
606 |
if node is not None and node.is_in_residual_stream():
607 |
608 |
609 |
610 |
container_token_dynamics if not container_token_dynamics_used else None,
611 |
612 |
container_top_tokens_used = True
613 |
container_token_dynamics_used = True
614 |
615 |
if not container_graph_right_used:
616 |
st_placeholder('Click on an edge to see head contributions. \n\n'
617 |
'Or click on FFN to see individual neuron contributions.', container_graph_right, height=1100)
618 |
if not container_top_tokens_used:
619 |
st_placeholder('Select a node from residual stream to see its top tokens.', container_top_tokens, height=1100)
620 |
if not container_token_dynamics_used:
621 |
st_placeholder('Select a node to see its promoted tokens.', container_token_dynamics, height=1100)
622 |
623 |
624 |
def run(self):
625 |
626 |
with st.sidebar.expander("About", expanded=True):
627 |
if self._config.demo_mode:
628 |
629 |
The app is deployed in Demo Mode, thus only predefined models and inputs are available.\n
630 |
You can still install the app locally and use your own models and inputs.\n
631 |
See for more information.
632 |
633 |
634 |
635 |
636 |
if not self.model_name:
637 |
st.warning("No model selected")
638 |
639 |
640 |
if self.sentence is None:
641 |
st.warning("No sentence selected")
642 |
643 |
with torch.inference_mode():
644 |
645 |
646 |
647 |
648 |
649 |
if __name__ == "__main__":
650 |
top_parser = argparse.ArgumentParser()
651 |
652 |
args = top_parser.parse_args()
653 |
654 |
parser = HfArgumentParser([LlmViewerConfig])
655 |
config = parser.parse_json_file(args.config_file)[0]
656 |
657 |
with SystemMonitor(config.debug) as prof:
658 |
app = App(config)
659 |
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from dataclasses import dataclass
8 |
from typing import Any, Dict, Optional
9 |
10 |
from llm_transparency_tool.routes.graph_node import GraphNode, NodeType
11 |
12 |
class UiGraphNode(GraphNode):
    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["UiGraphNode"]:
        # Returns None when the payload is missing or malformed.
        try:
            layer = json["cell"]["layer"]
            token = json["cell"]["token"]
            type = NodeType(json["item"])
            return UiGraphNode(layer, token, type)
        except (TypeError, KeyError):
            return None


@dataclass
class UiGraphEdge:
    source: UiGraphNode
    target: UiGraphNode
    weight: float

    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["UiGraphEdge"]:
        try:
            source = UiGraphNode.from_json(json["from"])
            target = UiGraphNode.from_json(json["to"])
            if source is None or target is None:
                return None
            weight = float(json["weight"])
            return UiGraphEdge(source, target, weight)
        except (TypeError, KeyError):
            return None


@dataclass
class GraphSelection:
    node: Optional[UiGraphNode]
    edge: Optional[UiGraphEdge]

    @staticmethod
    def from_json(json: Dict[str, Any]) -> Optional["GraphSelection"]:
        try:
            node = UiGraphNode.from_json(json["node"])
            edge = UiGraphEdge.from_json(json["edge"])
            return GraphSelection(node, edge)
        except (TypeError, KeyError):
            return None
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import torch
8 |
import streamlit as st
9 |
from pyinstrument import Profiler
10 |
from typing import Dict
11 |
import pandas as pd
12 |
13 |
14 |
@st.cache_resource(max_entries=1, show_spinner=False)
15 |
def init_gpu_memory():
    """
    Initializing CUDA occupies some GPU memory, and this overhead can make it
    difficult to tell how much memory the model actually uses.

    This function initializes CUDA and measures that overhead.
    """
23 |
if not torch.cuda.is_available():
24 |
return {}
25 |
26 |
    # Let's initialize the torch GPU context for a moment
27 |
gpu_memory_overhead = {}
28 |
for i in range(torch.cuda.device_count()):
29 |
30 |
free, total = torch.cuda.mem_get_info(i)
31 |
occupied = total - free
32 |
gpu_memory_overhead[i] = occupied
33 |
34 |
return gpu_memory_overhead
35 |
36 |
37 |
class SystemMonitor:
38 |
39 |
This class is used to monitor the system resources such as GPU memory and CPU
40 |
usage. It uses the pyinstrument library to profile the code and measure the
41 |
execution time of different parts of the code.
42 |
43 |
44 |
def __init__(
45 |
46 |
enabled: bool = False,
47 |
48 |
self.enabled = enabled
49 |
self.profiler = Profiler()
50 |
self.overhead: Dict[int, int]
51 |
52 |
def __enter__(self):
53 |
if not self.enabled:
54 |
55 |
56 |
self.overhead = init_gpu_memory()
57 |
58 |
59 |
60 |
def __exit__(self, exc_type, exc_value, traceback):
61 |
if not self.enabled:
62 |
63 |
64 |
self.profiler.__exit__(exc_type, exc_value, traceback)
65 |
66 |
67 |
68 |
69 |
with st.expander("Session state"):
70 |
71 |
72 |
return None
73 |
74 |
def report_gpu_usage(self):
75 |
76 |
if not torch.cuda.is_available():
77 |
78 |
79 |
data = []
80 |
81 |
for i in range(torch.cuda.device_count()):
82 |
free, total = torch.cuda.mem_get_info(i)
83 |
occupied = total - free
84 |
85 |
'overhead': self.overhead[i],
86 |
'occupied': occupied - self.overhead[i],
87 |
'free': free,
88 |
89 |
df = pd.DataFrame(data, columns=["overhead", "occupied", "free"])
90 |
91 |
with st.sidebar.expander("System"):
92 |
st.write("GPU memory on server")
93 |
            df /= 1024 ** 3  # Convert bytes to GiB
94 |
st.bar_chart(df, width=200, height=200, color=["#fefefe", "#84c9ff", "#fe2b2b"])
95 |
96 |
def report_profiler(self):
97 |
html_code = self.profiler.output_html()
98 |
with st.expander("Profiler", expanded=False):
99 |
st.components.v1.html(html_code, height=1000, scrolling=True)
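The arithmetic behind the GPU memory chart can be checked in isolation. The sketch below is illustrative only: `memory_breakdown_gib` is a hypothetical helper, and it assumes the `free` and `total` byte counts come from a call such as `torch.cuda.mem_get_info` while `overhead` is the value measured right after CUDA initialization.

```python
from typing import Dict

GIB = 1024 ** 3  # bytes per gibibyte


def memory_breakdown_gib(free: int, total: int, overhead: int) -> Dict[str, float]:
    """Split device memory into CUDA overhead, model usage, and free space."""
    occupied = total - free
    return {
        "overhead": overhead / GIB,
        # Memory attributable to the application, excluding the CUDA context.
        "occupied": (occupied - overhead) / GIB,
        "free": free / GIB,
    }
```

Subtracting the overhead first keeps the three bars additive, so they sum to the device total.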
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from dataclasses import dataclass
8 |
9 |
import matplotlib
10 |
11 |
# Unofficial way to make the padding a bit smaller.
12 |
margins_css = """
13 |
14 |
.main > div {
15 |
padding: 1rem;
16 |
padding-top: 2rem; # Still need this gap for the top bar
17 |
gap: 0rem;
18 |
19 |
20 |
section[data-testid="stSidebar"] {
21 |
width: 300px !important; # Set the width to your desired value
22 |
23 |
24 |
25 |
26 |
27 |
28 |
class RenderSettings:
29 |
column_proportions = [50, 30]
30 |
31 |
# We don't know the actual height. This will be used in order to compute the table
32 |
# viewport height when needed.
33 |
table_cell_height = 36
34 |
35 |
n_top_tokens = 30
36 |
n_promoted_tokens = 15
37 |
n_suppressed_tokens = 15
38 |
39 |
n_top_neurons = 20
40 |
41 |
attention_color_map = "Blues"
42 |
43 |
no_model_alt_text = "<no model selected>"
44 |
45 |
46 |
def string_to_display(s: str) -> str:
    return s.replace(" ", "·")
48 |
49 |
50 |
def logits_color_map(positive_and_negative: bool) -> matplotlib.colors.Colormap:
51 |
background_colors = {
52 |
"red": [
53 |
[0.0, 0.40, 0.40],
54 |
[0.1, 0.69, 0.69],
55 |
[0.2, 0.83, 0.83],
56 |
[0.3, 0.95, 0.95],
57 |
[0.4, 0.99, 0.99],
58 |
[0.5, 1.0, 1.0],
59 |
[0.6, 0.90, 0.90],
60 |
[0.7, 0.72, 0.72],
61 |
[0.8, 0.49, 0.49],
62 |
[0.9, 0.30, 0.30],
63 |
[1.0, 0.15, 0.15],
64 |
65 |
"green": [
66 |
[0.0, 0.0, 0.0],
67 |
[0.1, 0.09, 0.09],
68 |
[0.2, 0.37, 0.37],
69 |
[0.3, 0.64, 0.64],
70 |
[0.4, 0.85, 0.85],
71 |
[0.5, 1.0, 1.0],
72 |
[0.6, 0.96, 0.96],
73 |
[0.7, 0.88, 0.88],
74 |
[0.8, 0.73, 0.73],
75 |
[0.9, 0.57, 0.57],
76 |
[1.0, 0.39, 0.39],
77 |
78 |
"blue": [
79 |
[0.0, 0.12, 0.12],
80 |
[0.1, 0.16, 0.16],
81 |
[0.2, 0.30, 0.30],
82 |
[0.3, 0.50, 0.50],
83 |
[0.4, 0.78, 0.78],
84 |
[0.5, 1.0, 1.0],
85 |
[0.6, 0.81, 0.81],
86 |
[0.7, 0.52, 0.52],
87 |
[0.8, 0.25, 0.25],
88 |
[0.9, 0.12, 0.12],
89 |
[1.0, 0.09, 0.09],
90 |
91 |
92 |
93 |
if not positive_and_negative:
94 |
# Stretch the top part to the whole range
95 |
new_colors = {}
96 |
for channel, colors in background_colors.items():
97 |
new_colors[channel] = [
98 |
[(value - 0.5) * 2, color, color]
99 |
for value, color, _ in colors
100 |
if value >= 0.5
101 |
102 |
background_colors = new_colors
103 |
104 |
return matplotlib.colors.LinearSegmentedColormap(
105 |
106 |
107 |
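The "stretch the top part to the whole range" step can be illustrated on plain dictionaries, without matplotlib. This is a sketch under assumptions: `stretch_upper_half` is a hypothetical name, and the three-anchor `diverging` table is a shortened stand-in for the full color table above.

```python
from typing import Dict, List

Segments = Dict[str, List[List[float]]]


def stretch_upper_half(segments: Segments) -> Segments:
    """Keep anchors at x >= 0.5 and rescale them from [0.5, 1] to [0, 1]."""
    stretched: Segments = {}
    for channel, anchors in segments.items():
        stretched[channel] = [
            [(x - 0.5) * 2, y0, y1] for x, y0, y1 in anchors if x >= 0.5
        ]
    return stretched


# Shortened diverging map: dark at both ends, white in the middle.
diverging: Segments = {
    "red": [[0.0, 0.40, 0.40], [0.5, 1.0, 1.0], [1.0, 0.15, 0.15]],
    "green": [[0.0, 0.0, 0.0], [0.5, 1.0, 1.0], [1.0, 0.39, 0.39]],
    "blue": [[0.0, 0.12, 0.12], [0.5, 1.0, 1.0], [1.0, 0.09, 0.09]],
}
sequential = stretch_upper_half(diverging)
```

After the rescale the white midpoint becomes the start of the map, so a table that shows only positive scores still uses the full color range.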
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
import uuid
8 |
from typing import List, Optional, Tuple
9 |
10 |
import networkx as nx
11 |
import streamlit as st
12 |
import torch
13 |
import transformers
14 |
15 |
import llm_transparency_tool.routes.graph
16 |
from llm_transparency_tool.models.tlens_model import TransformerLensTransparentLlm
17 |
from llm_transparency_tool.models.transparent_llm import TransparentLlm
18 |
19 |
GPU = "gpu"
20 |
CPU = "cpu"
21 |
22 |
# Named constant for batch index 0, which reads better than a bare 0
# in indexing expressions.
24 |
B0 = 0
25 |
26 |
27 |
def possible_devices() -> List[str]:
28 |
devices = []
29 |
if torch.cuda.is_available():
30 |
31 |
32 |
return devices
33 |
34 |
35 |
def load_dataset(filename) -> List[str]:
36 |
with open(filename) as f:
37 |
dataset = [s.strip("\n") for s in f.readlines()]
38 |
print(f"Loaded {len(dataset)} sentences from {filename}")
39 |
return dataset
40 |
41 |
42 |
43 |
44 |
TransformerLensTransparentLlm: id
45 |
46 |
47 |
def load_model(
48 |
model_name: str,
49 |
_device: str,
50 |
_model_path: Optional[str] = None,
51 |
_dtype: torch.dtype = torch.float32,
52 |
) -> TransparentLlm:
    """
    Returns the loaded model.
    """
57 |
assert _device in possible_devices()
58 |
59 |
causal_lm = None
60 |
tokenizer = None
61 |
62 |
tl_lm = TransformerLensTransparentLlm(
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
return tl_lm
71 |
72 |
73 |
def run_model(model: TransparentLlm, sentence: str) -> None:
74 |
print(f"Running inference for '{sentence}'")
75 |
76 |
77 |
78 |
def load_model_with_session_caching(
79 |
80 |
) -> Tuple[TransparentLlm, str]:
81 |
return load_model(**kwargs)
82 |
83 |
def run_model_with_session_caching(
84 |
_model: TransparentLlm,
85 |
model_key: str,
86 |
sentence: str,
87 |
88 |
LAST_RUN_MODEL_KEY = "last_run_model_key"
89 |
LAST_RUN_SENTENCE = "last_run_sentence"
90 |
state = st.session_state
91 |
92 |
if (
93 |
state.get(LAST_RUN_MODEL_KEY, None) == model_key
94 |
and state.get(LAST_RUN_SENTENCE, None) == sentence
95 |
96 |
97 |
98 |
run_model(_model, sentence)
99 |
state[LAST_RUN_MODEL_KEY] = model_key
100 |
state[LAST_RUN_SENTENCE] = sentence
101 |
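The "skip the run when nothing changed" check above can be sketched with a plain dictionary standing in for `st.session_state`. This is an illustration under assumptions, not the tool's code: `run_if_inputs_changed` and its `run` callback are hypothetical names.

```python
from typing import Callable, Dict

LAST_RUN_MODEL_KEY = "last_run_model_key"
LAST_RUN_SENTENCE = "last_run_sentence"


def run_if_inputs_changed(
    state: Dict[str, str],
    model_key: str,
    sentence: str,
    run: Callable[[str], None],
) -> bool:
    """Call run(sentence) unless the same (model_key, sentence) pair ran last.

    Returns True when inference ran and False when the call was skipped.
    """
    if state.get(LAST_RUN_MODEL_KEY) == model_key and state.get(LAST_RUN_SENTENCE) == sentence:
        return False
    run(sentence)
    # Record the inputs only after a successful run.
    state[LAST_RUN_MODEL_KEY] = model_key
    state[LAST_RUN_SENTENCE] = sentence
    return True
```

Only the most recent pair is remembered, so switching back to an earlier sentence triggers a fresh run.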
102 |
103 |
104 |
105 |
TransformerLensTransparentLlm: id
106 |
107 |
108 |
def get_contribution_graph(
109 |
model: TransparentLlm, # TODO bug here
110 |
model_key: str,
111 |
tokens: List[str],
112 |
threshold: float,
113 |
) -> nx.Graph:
114 |
115 |
The `model_key` and `tokens` are used only for caching. The model itself is not
116 |
hashed, hence the `_` in the beginning.
117 |
118 |
return llm_transparency_tool.routes.graph.build_full_graph(
119 |
120 |
121 |
122 |
123 |
124 |
125 |
def st_placeholder(
126 |
text: str,
127 |
128 |
border: bool = True,
129 |
height: Optional[int] = 500,
130 |
131 |
empty = container.empty()
132 |
empty.container(border=border, height=height).write(f'<small>{text}</small>', unsafe_allow_html=True)
133 |
return empty
1 |
2 |
line-length = 120
1 |
The war lasted from the year 1732 to the year 17
2 |
5 + 4 = 9, 2 + 3 =
3 |
When Mary and John went to the store, John gave a drink to
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
2 |
# All rights reserved.
3 |
4 |
# This source code is licensed under the license found in the
5 |
# LICENSE file in the root directory of this source tree.
6 |
7 |
from setuptools import setup
8 |
9 |
10 |
11 |
12 |
13 |