dar-tau committed on
Commit
4009e7f
·
verified ·
1 Parent(s): acb4a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -57,7 +57,7 @@ suggested_interpretation_prompts = [
57
  def initialize_gpu():
58
  pass
59
 
60
- def reset_model(global_state, model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
61
  # extract model info
62
  model_args = deepcopy(model_info[model_name])
63
  model_path = model_args.pop('model_path')
@@ -84,15 +84,15 @@ def reset_model(global_state, model_name, load_on_gpu, *extra_components, reset_
84
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
85
  gc.collect()
86
  if with_extra_components:
87
- return ([global_state, welcome_message.format(model_name=model_name)]
88
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
89
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
90
  + [*extra_components])
91
  else:
92
- return global_state
93
 
94
 
95
- def get_hidden_states(global_state, raw_original_prompt, force_hidden_states=False):
96
  model, tokenizer = global_state.model, global_state.tokenizer
97
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
98
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
@@ -120,7 +120,7 @@ def get_hidden_states(global_state, raw_original_prompt, force_hidden_states=Fal
120
 
121
 
122
  @spaces.GPU
123
- def run_interpretation(global_state, raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
124
  temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
125
  num_beams=1):
126
  model = global_state.model
@@ -197,8 +197,8 @@ for i in range(MAX_PROMPT_TOKENS):
197
  tokens_container.append(btn)
198
 
199
  with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
200
- global_state = gr.State(partial(reset_model, GlobalState(),
201
- model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True))
202
  with gr.Row():
203
  with gr.Column(scale=5):
204
  gr.Markdown('# 😎 Self-Interpreting Models')
@@ -278,19 +278,19 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
278
 
279
  # event listeners
280
  for i, btn in enumerate(tokens_container):
281
- btn.click(partial(run_interpretation, i=i), [global_state, raw_original_prompt, raw_interpretation_prompt,
282
  num_tokens, do_sample, temperature,
283
  top_k, top_p, repetition_penalty, length_penalty,
284
  use_gpu
285
  ], [progress_dummy, *interpretation_bubbles])
286
 
287
  original_prompt_btn.click(get_hidden_states,
288
- [global_state, raw_original_prompt],
289
  [progress_dummy, *tokens_container, *interpretation_bubbles])
290
  raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
291
 
292
  extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
293
- model_chooser.change(reset_model, [global_state, model_chooser, load_on_gpu, *extra_components],
294
- [global_state, welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
295
 
296
  demo.launch()
 
57
  def initialize_gpu():
58
  pass
59
 
60
+ def reset_model(model_name, load_on_gpu, *extra_components, reset_sentence_transformer=False, with_extra_components=True):
61
  # extract model info
62
  model_args = deepcopy(model_info[model_name])
63
  model_path = model_args.pop('model_path')
 
84
  global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
85
  gc.collect()
86
  if with_extra_components:
87
+ return ([welcome_message.format(model_name=model_name)]
88
  + [gr.Textbox('', visible=False) for _ in range(len(interpretation_bubbles))]
89
  + [gr.Button('', visible=False) for _ in range(len(tokens_container))]
90
  + [*extra_components])
91
  else:
92
+ return None
93
 
94
 
95
+ def get_hidden_states(raw_original_prompt, force_hidden_states=False):
96
  model, tokenizer = global_state.model, global_state.tokenizer
97
  original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
98
  model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
 
120
 
121
 
122
  @spaces.GPU
123
+ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_tokens, do_sample,
124
  temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
125
  num_beams=1):
126
  model = global_state.model
 
197
  tokens_container.append(btn)
198
 
199
  with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
200
+ global_state = GlobalState()
201
+ reset_model(model_name, load_on_gpu=True, with_extra_components=False, reset_sentence_transformer=True)
202
  with gr.Row():
203
  with gr.Column(scale=5):
204
  gr.Markdown('# 😎 Self-Interpreting Models')
 
278
 
279
  # event listeners
280
  for i, btn in enumerate(tokens_container):
281
+ btn.click(partial(run_interpretation, i=i), [raw_original_prompt, raw_interpretation_prompt,
282
  num_tokens, do_sample, temperature,
283
  top_k, top_p, repetition_penalty, length_penalty,
284
  use_gpu
285
  ], [progress_dummy, *interpretation_bubbles])
286
 
287
  original_prompt_btn.click(get_hidden_states,
288
+ [raw_original_prompt],
289
  [progress_dummy, *tokens_container, *interpretation_bubbles])
290
  raw_original_prompt.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
291
 
292
  extra_components = [raw_interpretation_prompt, raw_original_prompt, original_prompt_btn]
293
+ model_chooser.change(reset_model, [model_chooser, load_on_gpu, *extra_components],
294
+ [welcome_model, *interpretation_bubbles, *tokens_container, *extra_components])
295
 
296
  demo.launch()