Update app.py
Browse files
app.py
CHANGED
@@ -75,7 +75,7 @@ def initialize_gpu():
|
|
75 |
pass
|
76 |
|
77 |
|
78 |
-
def reset_model(model_name,
|
79 |
# extract model info
|
80 |
model_args = deepcopy(model_info[model_name])
|
81 |
model_path = model_args.pop('model_path')
|
@@ -91,8 +91,7 @@ def reset_model(model_name, return_demo_blocks=True):
|
|
91 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
92 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
93 |
gc.collect()
|
94 |
-
|
95 |
-
return demo_blocks
|
96 |
|
97 |
|
98 |
def get_hidden_states(raw_original_prompt):
|
@@ -149,7 +148,7 @@ torch.set_grad_enabled(False)
|
|
149 |
global_state = GlobalState()
|
150 |
|
151 |
model_name = 'LLAMA2-7B'
|
152 |
-
reset_model(model_name
|
153 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
154 |
tokens_container = []
|
155 |
|
@@ -234,7 +233,9 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
234 |
|
235 |
|
236 |
# event listeners
|
237 |
-
|
|
|
|
|
238 |
|
239 |
for i, btn in enumerate(tokens_container):
|
240 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
|
|
75 |
pass
|
76 |
|
77 |
|
78 |
+
def reset_model(model_name, *extra_components):
|
79 |
# extract model info
|
80 |
model_args = deepcopy(model_info[model_name])
|
81 |
model_path = model_args.pop('model_path')
|
|
|
91 |
global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
|
92 |
global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
|
93 |
gc.collect()
|
94 |
+
return extra_components
|
|
|
95 |
|
96 |
|
97 |
def get_hidden_states(raw_original_prompt):
|
|
|
148 |
global_state = GlobalState()
|
149 |
|
150 |
model_name = 'LLAMA2-7B'
|
151 |
+
reset_model(model_name)
|
152 |
original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
|
153 |
tokens_container = []
|
154 |
|
|
|
233 |
|
234 |
|
235 |
# event listeners
|
236 |
+
all_components = [*interpretation_bubbles, *tokens_container, original_prompt_btn,
|
237 |
+
original_prompt_raw]
|
238 |
+
model_chooser.change(reset_model, [model_chooser, *all_components], all_components)
|
239 |
|
240 |
for i, btn in enumerate(tokens_container):
|
241 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|