Spaces · Runtime error

frankaging committed · Commit 2698ee0 · 1 Parent(s): 77bd93c

initial commit

Files changed:
- .gitignore +1 -0
- app.py +3 -5
.gitignore ADDED

@@ -0,0 +1 @@
+.ipynb_checkpoints/
app.py CHANGED

@@ -37,12 +37,10 @@ if not torch.cuda.is_available():
 if torch.cuda.is_available():
     model_id = "meta-llama/Llama-2-7b-chat-hf" # not gated version.
     model = AutoModelForCausalLM.from_pretrained(
-        model_id, device_map="
+        model_id, device_map="cuda", torch_dtype=torch.bfloat16
     )
     reft_model = ReftModel.load("pyvene/reft_goody2", model, from_huggingface_hub=True)
-
-    for k, v in reft_model.interventions.items():
-        v[0].to(model.device)
+    reft_model.set_device("cuda")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = True

@@ -77,7 +75,7 @@ def generate(

     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = {
-        "base": {"input_ids":
+        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
         "unit_locations": {"sources->base": (None, [[[base_unit_location]]])},
         "max_new_tokens": max_new_tokens,
         "intervene_on_prompt": True,
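The second hunk adds the attention mask to the "base" inputs that the ReFT model forwards to the underlying model. A sketch of how such kwargs are typically consumed with TextIteratorStreamer; everything outside generate_kwargs here is an assumption (the prompt text, the max_new_tokens value, and the last-prompt-token choice for base_unit_location are illustrative, since the rest of generate() is not part of this diff):

from threading import Thread
from transformers import TextIteratorStreamer

inputs = tokenizer("Hello!", return_tensors="pt").to(model.device)
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

# Illustrative choice: intervene on the last prompt token.
base_unit_location = input_ids.shape[-1] - 1

streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = {
    "base": {"input_ids": input_ids, "attention_mask": attention_mask},
    "unit_locations": {"sources->base": (None, [[[base_unit_location]]])},
    "max_new_tokens": 256,
    "intervene_on_prompt": True,
    "streamer": streamer,
}

# Standard streaming pattern: run generation on a worker thread and
# consume decoded text as it arrives.
thread = Thread(target=reft_model.generate, kwargs=generate_kwargs)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)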