SivaResearch committed on
Commit
fe967b9
·
verified ·
1 Parent(s): 12e3bb4

updated inference function

Browse files
Files changed (1) hide show
  1. app.py +8 -13
app.py CHANGED
@@ -35,26 +35,21 @@ def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True
35
  formatted_text = bos + formatted_text if add_bos else formatted_text
36
  return formatted_text
37
 
 
 
38
 
39
def inference(input_prompts, model, tokenizer):
    """Generate model responses for a batch of user prompts.

    Each prompt is wrapped in the project's chat format (via the sibling
    ``create_prompt_with_chat_format`` helper), tokenized as a padded batch,
    and passed to ``model.generate``. The echoed prompt text is then stripped
    from each decoded output so only the model's continuation is returned.

    Args:
        input_prompts: list of raw user prompt strings.
        model: a Hugging Face causal-LM with a ``.generate`` method.
            # NOTE(review): assumed to already live on ``device`` — confirm.
        tokenizer: the matching tokenizer (must support batch padding).

    Returns:
        list[str]: one generated continuation per input prompt, in order.
    """
    # Wrap every raw prompt in the chat template; add_bos=False because the
    # tokenizer adds special tokens itself during encoding below.
    input_prompts = [
        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
        for input_prompt in input_prompts
    ]

    # Padded batch encoding; ``device`` is a module-level global.
    encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
    encodings = encodings.to(device)

    # inference_mode disables autograd bookkeeping for faster generation.
    with torch.inference_mode():
        outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)

    output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)

    # Round-trip the prompts through the tokenizer so their length matches
    # how they appear inside the decoded outputs (special tokens dropped),
    # then slice the echoed prompt prefix off each generation.
    input_prompts = [
        tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
        for input_prompt in input_prompts
    ]
    output_texts = [
        output_text[len(input_prompt):]
        for input_prompt, output_text in zip(input_prompts, output_texts)
    ]
    return output_texts
58
 
59
 
60
  def chat_interface(message,history):
 
35
  formatted_text = bos + formatted_text if add_bos else formatted_text
36
  return formatted_text
37
 
38
def inference(input_prompt, model, tokenizer):
    """Generate a model response for a single user prompt.

    The prompt is wrapped in the project's chat format (via the sibling
    ``create_prompt_with_chat_format`` helper), tokenized, and passed to
    ``model.generate``; the echoed prompt text is stripped from the decoded
    output so only the model's continuation is returned.

    Args:
        input_prompt: raw user prompt string.
        model: a Hugging Face causal-LM with a ``.generate`` method.
            # NOTE(review): assumed to already live on ``device`` — confirm.
        tokenizer: the matching tokenizer.

    Returns:
        str: the generated continuation (prompt prefix removed).
    """
    # Wrap the raw prompt in the chat template; add_bos=False because the
    # tokenizer adds special tokens itself during encoding below.
    input_prompt = create_prompt_with_chat_format(
        [{"role": "user", "content": input_prompt}], add_bos=False
    )

    # ``device`` is a module-level global.
    encodings = tokenizer(input_prompt, padding=True, return_tensors="pt")
    encodings = encodings.to(device)

    # inference_mode disables autograd bookkeeping for faster generation.
    with torch.inference_mode():
        outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)

    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Round-trip the prompt through the tokenizer so its length matches how
    # it appears inside the decoded output (special tokens dropped), then
    # slice the echoed prompt prefix off the generation.
    input_prompt = tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True)
    output_text = output_text[len(input_prompt):]

    return output_text
 
 
 
 
53
 
54
 
55
  def chat_interface(message,history):