Spaces: Running on Zero

Commit · 2062515
1 Parent(s): eb1a863

remove some unneeded lines, fix pipe issue

utils/models.py  +5 -16

utils/models.py  CHANGED
@@ -135,15 +135,10 @@ def run_inference(model_name, context, question):
     # Common arguments for tokenizer loading
     tokenizer_load_args = {"padding_side": "left", "token": True}
 
-    # Determine the Hugging Face model name for the tokenizer
     actual_model_name_for_tokenizer = model_name
     if "icecream" in model_name.lower():
         actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
 
-    # Note: tokenizer_kwargs (defined earlier, with add_generation_prompt etc.)
-    # is intended for tokenizer.apply_chat_template, not for AutoTokenizer.from_pretrained generally.
-    # If a specific tokenizer (e.g., Qwen) needs special __init__ args that happen to be in tokenizer_kwargs,
-    # that would require more specific handling here. For now, we assume general constructor args.
     tokenizer = AutoTokenizer.from_pretrained(actual_model_name_for_tokenizer, **tokenizer_load_args)
     tokenizer_cache[model_name] = tokenizer
 
@@ -201,8 +196,6 @@ def run_inference(model_name, context, question):
     elif "icecream" in model_name.lower():
 
         print("ICECREAM")
-        # text_input is the list of messages from format_rag_prompt
-        # tokenizer_kwargs (e.g., {"add_generation_prompt": True}) are correctly passed to apply_chat_template
         model_inputs = tokenizer.apply_chat_template(
             text_input,
             tokenize=True,
@@ -211,38 +204,34 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-
+
         model_inputs = model_inputs.to(model.device)
 
         input_ids = model_inputs.input_ids
-        attention_mask = model_inputs.attention_mask
+        attention_mask = model_inputs.attention_mask
 
-        prompt_tokens_length = input_ids.shape[1]
+        prompt_tokens_length = input_ids.shape[1]
 
         with torch.inference_mode():
             # Check interrupt before generation
             if generation_interrupt.is_set():
                 return ""
 
-            # Explicitly pass input_ids, attention_mask, and pad_token_id
-            # tokenizer.pad_token is set to tokenizer.eos_token if None, earlier in the code.
            output_sequences = model.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_new_tokens=512,
-                eos_token_id=tokenizer.eos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id # Addresses the warning
             )
 
-            # output_sequences[0] contains the full sequence (prompt + generation)
-            # Decode only the newly generated tokens
             generated_token_ids = output_sequences[0][prompt_tokens_length:]
             result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
 
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
-            tokenize=
+            tokenize=False,
             **tokenizer_kwargs,
         )
 
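For the "fix pipe issue" half of this commit: a transformers text-generation pipeline expects a prompt string rather than token ids, so the chat template has to be rendered with tokenize=False before being handed to pipe. A minimal standalone sketch of that pattern follows; the checkpoint name and messages are illustrative assumptions, not the Space's actual format_rag_prompt output.

# Standalone sketch (not part of utils/models.py): render the chat template to a
# string so the text-generation pipeline receives a prompt it can tokenize itself.
from transformers import pipeline

pipe = pipeline("text-generation", model="meta-llama/llama-3.2-3b-instruct")  # assumed checkpoint

text_input = [  # illustrative messages; the Space builds these in format_rag_prompt
    {"role": "system", "content": "Answer using only the provided context."},
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},
]

# tokenize=False returns the formatted prompt string; tokenize=True would return
# token ids, which the pipeline does not accept as its input.
formatted = pipe.tokenizer.apply_chat_template(
    text_input,
    tokenize=False,
    add_generation_prompt=True,
)

outputs = pipe(formatted, max_new_tokens=512, return_full_text=False)
print(outputs[0]["generated_text"])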
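The manual path kept for the "icecream" branch follows the usual decoder-only pattern: tokenize via the chat template, call model.generate with an explicit attention_mask and pad_token_id, then decode only the tokens after the prompt. A self-contained sketch of that pattern, assuming the same Llama 3.2 3B checkpoint as the diff; the repo's tokenizer caching and generation_interrupt handling are omitted.

# Standalone sketch of the generate-then-slice pattern used in the icecream branch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/llama-3.2-3b-instruct"  # assumed checkpoint, as in the diff
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to("cuda" if torch.cuda.is_available() else "cpu")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # pad-token fallback noted in the removed comments

messages = [{"role": "user", "content": "Summarize the context in one sentence."}]  # illustrative

# return_dict=True so both input_ids and attention_mask come back as tensors.
model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

prompt_tokens_length = model_inputs.input_ids.shape[1]

with torch.inference_mode():
    output_sequences = model.generate(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=512,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

# output_sequences[0] is prompt + generation; decode only the newly generated tokens.
result = tokenizer.decode(output_sequences[0][prompt_tokens_length:], skip_special_tokens=True)
print(result)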