Spaces:
Paused
Paused
updated inference function
Browse files
app.py
CHANGED
@@ -35,26 +35,21 @@ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True
|
|
35 |
formatted_text = bos + formatted_text if add_bos else formatted_text
|
36 |
return formatted_text
|
37 |
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
input_prompts = [
|
41 |
-
create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
|
42 |
-
for input_prompt in input_prompts
|
43 |
-
]
|
44 |
-
|
45 |
-
encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
|
46 |
encodings = encodings.to(device)
|
47 |
|
48 |
with torch.inference_mode(): # Add missing import statement for torch.inference_mode()
|
49 |
outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
|
50 |
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
-
|
54 |
-
tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
|
55 |
-
]
|
56 |
-
output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
|
57 |
-
return output_texts
|
58 |
|
59 |
|
60 |
def chat_interface(message,history):
|
|
|
35 |
formatted_text = bos + formatted_text if add_bos else formatted_text
|
36 |
return formatted_text
|
37 |
|
38 |
def inference(input_prompt, model, tokenizer):
    """Generate a model response for a single user prompt.

    Wraps *input_prompt* in the chat template, runs greedy decoding
    (up to 250 new tokens), and strips the echoed prompt prefix from
    the decoded output so only the model's reply is returned.

    Parameters
    ----------
    input_prompt : str
        Raw user message, not yet chat-formatted.
    model :
        A transformers causal-LM with a ``generate`` method.
    tokenizer :
        The tokenizer matching *model*.

    Returns
    -------
    str
        The generated continuation, without the prompt text.
    """
    # Apply the chat template; add_bos=False carried over from the original
    # code — presumably the tokenizer adds BOS itself (confirm via config).
    input_prompt = create_prompt_with_chat_format(
        [{"role": "user", "content": input_prompt}], add_bos=False
    )

    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)  # `device` is a module-level global

    with torch.inference_mode():
        # Fix: pass attention_mask explicitly. Without it, generate() assumes
        # every position is a real token, which is wrong whenever the
        # padding=True call above actually pads the batch.
        outputs = model.generate(
            encodings.input_ids,
            attention_mask=encodings.attention_mask,
            do_sample=False,
            max_new_tokens=250,
        )

    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Round-trip the prompt through the tokenizer so its text matches the
    # decoded output exactly (tokenization may normalize whitespace), then
    # slice it off to keep only the newly generated reply.
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    return output_text[len(input_prompt):]
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
def chat_interface(message,history):
|