Ais committed (verified)
Commit 988fa7f · 1 Parent(s): 730d86c

Update app/inference.py

Files changed (1): app/inference.py (+46 -3)
app/inference.py CHANGED
@@ -17,9 +17,34 @@ model.eval()
 
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-def generate_response(prompt: str) -> str:
-    formatted = f"<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+def generate_response(prompt: str, conversation_history: list = None) -> str:
+    """
+    Generate response with optional conversation history
+
+    Args:
+        prompt: Current user message
+        conversation_history: List of {"role": "user/assistant", "content": "..."}
+    """
+
+    # Build conversation format
+    formatted = "<|im_start|>system\nYou are a helpful AI assistant.<|im_end|>\n"
+
+    # Add conversation history if provided
+    if conversation_history:
+        for msg in conversation_history:
+            role = msg.get("role", "")
+            content = msg.get("content", "")
+
+            if role == "user":
+                formatted += f"<|im_start|>user\n{content}<|im_end|>\n"
+            elif role == "assistant":
+                formatted += f"<|im_start|>assistant\n{content}<|im_end|>\n"
+
+    # Add current prompt
+    formatted += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
     inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         output = model.generate(
             **inputs,
@@ -29,6 +54,24 @@ def generate_response(prompt: str) -> str:
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
+
     decoded = tokenizer.decode(output[0], skip_special_tokens=True)
     answer = decoded.split("<|im_start|>assistant\n")[-1].strip()
-    return answer
+
+    # Clean up any end tokens
+    if "<|im_end|>" in answer:
+        answer = answer.split("<|im_end|>")[0].strip()
+
+    return answer
+
+# Example usage with conversation history
+if __name__ == "__main__":
+    # Test with conversation history
+    history = [
+        {"role": "user", "content": "What is Python?"},
+        {"role": "assistant", "content": "Python is a high-level programming language..."},
+    ]
+
+    # This should now consider the conversation context
+    response = generate_response("Can you show me a simple example?", conversation_history=history)
+    print("Response:", response)