Update app.py
Browse files
app.py
CHANGED
@@ -29,11 +29,12 @@ class GlobalState:
|
|
29 |
local_state : LocalState = LocalState()
|
30 |
wait_with_hidden_state : bool = False
|
31 |
interpretation_prompt_template : str = '{prompt}'
|
32 |
-
original_prompt_template : str = 'User: [X]\n\
|
33 |
layers_format : str = 'model.layers.{k}'
|
34 |
|
35 |
|
36 |
suggested_interpretation_prompts = [
|
|
|
37 |
"Sure, here's a bullet list of the key words in your message:",
|
38 |
"Sure, I'll summarize your message:",
|
39 |
"Sure, here are the words in your message:",
|
@@ -139,7 +140,7 @@ global_state = GlobalState()
|
|
139 |
|
140 |
model_name = 'LLAMA2-7B'
|
141 |
reset_model(model_name, with_extra_components=False)
|
142 |
-
|
143 |
tokens_container = []
|
144 |
|
145 |
for i in range(MAX_PROMPT_TOKENS):
|
@@ -185,17 +186,17 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
185 |
dataset = dataset.filter(info['filter'])
|
186 |
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
|
187 |
dataset = [[row[info['text_col']]] for row in dataset]
|
188 |
-
gr.Examples(dataset, [
|
189 |
|
190 |
with gr.Group():
|
191 |
-
|
192 |
original_prompt_btn = gr.Button('Output Token List', variant='primary')
|
193 |
|
194 |
gr.Markdown('## Choose Your Interpretation Prompt')
|
195 |
with gr.Group('Interpretation'):
|
196 |
-
|
197 |
interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
|
198 |
-
[
|
199 |
|
200 |
with gr.Accordion(open=False, label='Generation Settings'):
|
201 |
with gr.Row():
|
@@ -225,17 +226,17 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
225 |
|
226 |
# event listeners
|
227 |
for i, btn in enumerate(tokens_container):
|
228 |
-
btn.click(partial(run_interpretation, i=i), [
|
229 |
num_tokens, do_sample, temperature,
|
230 |
top_k, top_p, repetition_penalty, length_penalty
|
231 |
], [progress_dummy, *interpretation_bubbles])
|
232 |
|
233 |
original_prompt_btn.click(get_hidden_states,
|
234 |
-
[
|
235 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
236 |
-
|
237 |
|
238 |
-
extra_components = [
|
239 |
model_chooser.change(reset_model, [model_chooser, *extra_components],
|
240 |
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
241 |
|
|
|
29 |
local_state : LocalState = LocalState()
|
30 |
wait_with_hidden_state : bool = False
|
31 |
interpretation_prompt_template : str = '{prompt}'
|
32 |
+
original_prompt_template : str = 'User: [X]\n\nAssistant: {prompt}'
|
33 |
layers_format : str = 'model.layers.{k}'
|
34 |
|
35 |
|
36 |
suggested_interpretation_prompts = [
|
37 |
+
"The meaning of [X] is",
|
38 |
"Sure, here's a bullet list of the key words in your message:",
|
39 |
"Sure, I'll summarize your message:",
|
40 |
"Sure, here are the words in your message:",
|
|
|
140 |
|
141 |
model_name = 'LLAMA2-7B'
|
142 |
reset_model(model_name, with_extra_components=False)
|
143 |
+
raw_original_prompt = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
144 |
tokens_container = []
|
145 |
|
146 |
for i in range(MAX_PROMPT_TOKENS):
|
|
|
186 |
dataset = dataset.filter(info['filter'])
|
187 |
dataset = dataset.shuffle(buffer_size=2000).take(num_examples)
|
188 |
dataset = [[row[info['text_col']]] for row in dataset]
|
189 |
+
gr.Examples(dataset, [raw_original_prompt], cache_examples=False)
|
190 |
|
191 |
with gr.Group():
|
192 |
+
raw_original_prompt.render()
|
193 |
original_prompt_btn = gr.Button('Output Token List', variant='primary')
|
194 |
|
195 |
gr.Markdown('## Choose Your Interpretation Prompt')
|
196 |
with gr.Group('Interpretation'):
|
197 |
+
raw_interpretation_prompt = gr.Text(suggested_interpretation_prompts[0], label='Interpretation Prompt')
|
198 |
interpretation_prompt_examples = gr.Examples([[p] for p in suggested_interpretation_prompts],
|
199 |
+
[raw_interpretation_prompt], cache_examples=False)
|
200 |
|
201 |
with gr.Accordion(open=False, label='Generation Settings'):
|
202 |
with gr.Row():
|
|
|
226 |
|
227 |
# event listeners
|
228 |
for i, btn in enumerate(tokens_container):
|
229 |
+
btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
|
230 |
num_tokens, do_sample, temperature,
|
231 |
top_k, top_p, repetition_penalty, length_penalty
|
232 |
], [progress_dummy, *interpretation_bubbles])
|
233 |
|
234 |
original_prompt_btn.click(get_hidden_states,
|
235 |
+
[raw_original_prompt],
|
236 |
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
237 |
+
raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
238 |
|
239 |
+
extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
|
240 |
model_chooser.change(reset_model, [model_chooser, *extra_components],
|
241 |
[welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
|
242 |
|