Ais committed on
Commit 730d86c · verified · 1 Parent(s): 4ca2587

Update app/main.py

Files changed (1): app/main.py +36 -10
app/main.py CHANGED
@@ -53,7 +53,7 @@ async def chat(request: Request):
     auth_header = request.headers.get("Authorization", "")
     if not auth_header.startswith("Bearer "):
         return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
-
+
     token = auth_header.replace("Bearer ", "").strip()
     if token != API_KEY:
         return JSONResponse(status_code=401, content={"error": "Invalid API key."})
@@ -64,31 +64,57 @@ async def chat(request: Request):
         messages = body.get("messages", [])
         if not messages or not isinstance(messages, list):
             raise ValueError("Invalid or missing 'messages' field.")
-        user_prompt = messages[-1]["content"]
+
+        # ✅ FIXED: Process full conversation history, not just last message
+        temperature = body.get("temperature", 0.7)
+        max_tokens = body.get("max_tokens", 512)
+
     except Exception as e:
         return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
 
-    # ✅ Format Prompt for Qwen
-    formatted_prompt = (
-        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-        f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
-        "<|im_start|>assistant\n"
-    )
+    # ✅ FIXED: Build full conversation prompt with history
+    formatted_prompt = ""
+
+    for message in messages:
+        role = message.get("role", "")
+        content = message.get("content", "")
+
+        if role == "system":
+            formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
+        elif role == "user":
+            formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
+        elif role == "assistant":
+            formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
+
+    # Add the assistant start token for generation
+    formatted_prompt += "<|im_start|>assistant\n"
+
+    print(f"🤖 Processing conversation with {len(messages)} messages")
+    print(f"📝 Full prompt preview: {formatted_prompt[:200]}...")
+
     inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
 
     # ✅ Generate Response
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=512,
-            temperature=0.7,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
             top_p=0.9,
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id
         )
 
     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # ✅ FIXED: Clean extraction of only the new assistant response
     final_answer = decoded.split("<|im_start|>assistant\n")[-1].strip()
+
+    # Remove any potential end tokens or artifacts
+    if "<|im_end|>" in final_answer:
+        final_answer = final_answer.split("<|im_end|>")[0].strip()
+
+    print(f"✅ Generated response: {final_answer[:100]}...")
 
     # ✅ OpenAI-style Response
     return {
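
To make the prompt-building change concrete, here is a standalone re-run of the new loop on a sample multi-turn history. The sample messages list is invented for illustration, but the loop body and the trailing assistant token match the diff above.

# Standalone re-run of the prompt-building loop from this commit, applied to a
# sample conversation to show the ChatML string the model actually receives.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]

formatted_prompt = ""
for message in messages:
    role = message.get("role", "")
    content = message.get("content", "")
    if role == "system":
        formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
    elif role == "user":
        formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
    elif role == "assistant":
        formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"

# Leave the prompt open at an assistant turn so generation continues from here.
formatted_prompt += "<|im_start|>assistant\n"

print(formatted_prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi!<|im_end|>
# <|im_start|>assistant
# Hello! How can I help?<|im_end|>
# <|im_start|>user
# Tell me a joke.<|im_end|>
# <|im_start|>assistant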
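
And a minimal client-side sketch of a request that exercises the updated handler. The route path, host, and API key below are assumptions (none of them appear in this diff); only the Bearer header and the body fields messages, temperature, and max_tokens come from the code above.

# Hypothetical client call for the updated endpoint. API_URL and API_KEY are
# placeholders -- the actual route and key are not shown in this diff.
import requests

API_URL = "http://localhost:8000/chat"  # assumed route; adjust to the deployment
API_KEY = "your-api-key"                # must match API_KEY on the server

payload = {
    # Full conversation history; the handler now formats every turn, in order.
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
        {"role": "user", "content": "And roughly how many people live there?"},
    ],
    # Optional fields read via body.get(...); defaults mirror the old hardcoded values.
    "temperature": 0.7,
    "max_tokens": 512,
}

response = requests.post(
    API_URL,
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},  # required by the handler
    timeout=120,
)
print(response.status_code, response.json())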