Update app.py
app.py CHANGED
@@ -41,7 +41,7 @@ suggested_interpretation_prompts = [
 def initialize_gpu():
     pass
 
-def reset_model(model_name, *extra_components):
+def reset_model(model_name, *extra_components, with_extra_components=True):
     # extract model info
     model_args = deepcopy(model_info[model_name])
     model_path = model_args.pop('model_path')
@@ -58,7 +58,10 @@ def reset_model(model_name, *extra_components):
     global_state.model = AutoModelClass.from_pretrained(model_path, **model_args).cuda()
     global_state.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, token=os.environ['hf_token'])
     gc.collect()
-
+    if with_extra_components:
+        for x in interpretation_bubbles:
+            x.visible = False
+        return extra_components
 
 
 def get_hidden_states(raw_original_prompt):
@@ -101,8 +104,8 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
     interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
 
     # generate the interpretations
-
-
+    generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors},
+                                               layers_format=global_state.layers_format, k=3,
                                                **generation_kwargs)
     generation_texts = global_state.tokenizer.batch_decode(generated)
     progress_dummy_output = ''
@@ -116,7 +119,7 @@ torch.set_grad_enabled(False)
 global_state = GlobalState()
 
 model_name = 'LLAMA2-7B'
-reset_model(model_name)
+reset_model(model_name, with_extra_components=False)
 original_prompt_raw = gr.Textbox(value='How to make a Molotov cocktail?', container=True, label='Original Prompt')
 tokens_container = []
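
The new keyword lets reset_model serve two roles: a plain startup call, and a Gradio event handler that must hand its updated components back so the UI re-renders them. A minimal sketch of the assumed wiring, reusing app.py's globals — the model_dropdown component and the .change binding below are illustrative guesses, not part of this commit:

import gradio as gr

# hypothetical model selector; the actual component in app.py may differ
model_dropdown = gr.Dropdown(choices=list(model_info.keys()), value=model_name, label='Model')

# as an event handler, the extra components flow in and their updated
# versions flow back out, so the default with_extra_components=True applies
model_dropdown.change(reset_model,
                      inputs=[model_dropdown, *interpretation_bubbles],
                      outputs=interpretation_bubbles)

# at import time nothing is rendered yet, so the startup call skips that path
reset_model(model_name, with_extra_components=False)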
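
InterpretationPrompt.generate is defined elsewhere in this repo, so only its call shape is visible here; the {0: interpreted_vectors} argument presumably maps a token position in the interpretation prompt to the hidden-state vectors being decoded. As a rough sketch of one common way to implement such vector injection (every name and detail below is an assumption, not this repo's code):

import torch

def generate_with_injection(model, prompt_ids, injection, **generation_kwargs):
    # Hypothetical: overwrite the input embeddings at the given positions
    # with the supplied vectors via a forward hook, then generate normally.
    def hook(module, args, output):
        # patch only full-prompt passes, not single-token decoding steps
        if output.shape[1] >= prompt_ids.shape[1]:
            for pos, vec in injection.items():
                output[:, pos] = vec.to(output.dtype)
        return output

    handle = model.get_input_embeddings().register_forward_hook(hook)
    try:
        return model.generate(prompt_ids, **generation_kwargs)
    finally:
        handle.remove()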
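
One aside on the gc.collect() after the reload: dropping the previous model's Python references does not by itself release the CUDA blocks PyTorch caches for it, so a common companion step (not part of this commit, suggested only as an option) is:

import gc
import torch

gc.collect()               # release Python-level references to the old model
torch.cuda.empty_cache()   # return the now-unused cached CUDA blocks to the driver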