Ais committed
Commit 3afe501 · verified · 1 Parent(s): 45afec6

Update app/main.py

Files changed (1)
  1. app/main.py  +107 -56
app/main.py CHANGED

@@ -41,6 +41,57 @@ model.eval()
 
 print("✅ Model and adapter loaded successfully.")
 
+def clean_response(raw_response):
+    """
+    Clean the model response by removing unwanted artifacts while preserving the actual answer.
+    """
+    if not raw_response or len(raw_response.strip()) < 2:
+        return "I apologize, but I couldn't generate a proper response. Please try again."
+
+    # Remove common chat template artifacts
+    cleaned = raw_response.strip()
+
+    # Remove system/user/assistant prefixes that might leak through
+    prefixes_to_remove = [
+        "system\n", "user\n", "assistant\n",
+        "System:", "User:", "Assistant:",
+        "<|im_start|>", "<|im_end|>",
+        "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+        "You are a helpful assistant.",
+        "I am a helpful assistant.",
+        "As a helpful assistant,",
+    ]
+
+    for prefix in prefixes_to_remove:
+        if cleaned.lower().startswith(prefix.lower()):
+            cleaned = cleaned[len(prefix):].strip()
+
+    # Remove any remaining template artifacts
+    lines = cleaned.split('\n')
+    filtered_lines = []
+
+    for line in lines:
+        line_stripped = line.strip()
+
+        # Skip empty lines at the beginning
+        if not line_stripped and not filtered_lines:
+            continue
+
+        # Skip obvious template artifacts
+        if line_stripped in ["system", "user", "assistant"]:
+            continue
+
+        filtered_lines.append(line)
+
+    cleaned = '\n'.join(filtered_lines).strip()
+
+    # If we still have content, return it
+    if cleaned and len(cleaned) > 5:
+        return cleaned
+
+    # Fallback only if truly empty
+    return "I understand your question. Let me help you with that."
+
 # === Root Route ===
 @app.get("/")
 def root():
@@ -64,79 +115,74 @@ async def chat(request: Request):
         messages = body.get("messages", [])
         if not messages or not isinstance(messages, list):
             raise ValueError("Invalid or missing 'messages' field.")
-
-        # Extract system and user messages
-        system_message = ""
-        user_messages = []
-
-        for msg in messages:
-            if msg.get("role") == "system":
-                system_message = msg.get("content", "")
-            elif msg.get("role") in ["user", "assistant"]:
-                user_messages.append(msg)
-
-        # Get the last user message
-        if not user_messages:
-            raise ValueError("No user messages found.")
-
-        user_prompt = user_messages[-1]["content"]
-
     except Exception as e:
         return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
 
-    # ✅ FIXED: Simplified prompt formatting - no system message in prompt
-    # The system message is handled by the frontend logic, not in the model prompt
-    formatted_prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
+    # ✅ FIXED: Use proper Qwen2.5 chat template formatting
+    try:
+        # Use the tokenizer's built-in chat template - this is the correct way!
+        formatted_prompt = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        print(f"🔍 Formatted prompt: {formatted_prompt}")
+
+    except Exception as e:
+        print(f"❌ Chat template error: {e}")
+        # Fallback to manual formatting if needed
+        formatted_prompt = ""
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+            if role == "system":
+                formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
+            elif role == "user":
+                formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
+            elif role == "assistant":
+                formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
+        formatted_prompt += "<|im_start|>assistant\n"
 
     inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
 
-    # ✅ Generate Response with better settings for small model
+    # ✅ Generate Response with optimized settings
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=400,  # Reduced for more focused responses
+            max_new_tokens=512,  # Increased for better responses
             temperature=0.7,
             top_p=0.9,
             do_sample=True,
             pad_token_id=tokenizer.eos_token_id,
-            repetition_penalty=1.1,  # Prevent repetition
-            length_penalty=1.0
+            repetition_penalty=1.05,  # Slightly reduced
+            length_penalty=1.0,
+            early_stopping=True
         )
 
-    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    # FIXED: Better response extraction
+    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # FIXED: Better extraction - remove the prompt part completely
-    final_answer = decoded.split("<|im_start|>assistant\n")[-1].strip()
+    print(f"🔍 Full generated response: {full_response}")
 
-    # Additional cleaning to prevent system message leakage
-    if final_answer.lower().startswith(("you are a helpful", "i am a helpful", "as a helpful")):
-        # If the response starts with system-like text, try to extract actual content
-        lines = final_answer.split('\n')
-        cleaned_lines = []
-        found_content = False
-
-        for line in lines:
-            line = line.strip()
-            if not line:
-                continue
-
-            # Skip system-like phrases
-            if any(phrase in line.lower() for phrase in [
-                "you are a helpful", "i am a helpful", "as a helpful assistant",
-                "how can i help", "what can i help", "i'm here to help"
-            ]):
-                continue
-
-            # This looks like actual content
-            found_content = True
-            cleaned_lines.append(line)
-
-        if found_content:
-            final_answer = '\n'.join(cleaned_lines)
+    # Extract only the new generated part (after the prompt)
+    if formatted_prompt in full_response:
+        generated_part = full_response.split(formatted_prompt)[-1].strip()
+    else:
+        # If we can't find the exact prompt, try to extract the assistant's response
+        assistant_marker = "<|im_start|>assistant\n"
+        if assistant_marker in full_response:
+            parts = full_response.split(assistant_marker)
+            generated_part = parts[-1].split("<|im_end|>")[0].strip() if len(parts) > 1 else full_response
+        else:
+            generated_part = full_response
 
-    # Fallback if response is too short or looks like system message
-    if len(final_answer.strip()) < 10 or final_answer.lower().startswith(("system", "user", "assistant")):
-        final_answer = "I understand your question. Let me help you with that."
+    print(f"🔍 Extracted generated part: {generated_part}")
+
+    # Clean the response but keep it intact
+    final_answer = clean_response(generated_part)
+
+    print(f"🔍 Final cleaned answer: {final_answer}")
 
     # ✅ OpenAI-style Response
     return {
@@ -152,5 +198,10 @@ async def chat(request: Request):
                 },
                 "finish_reason": "stop"
             }
-        ]
+        ],
+        "usage": {
+            "prompt_tokens": len(inputs.input_ids[0]),
+            "completion_tokens": len(outputs[0]) - len(inputs.input_ids[0]),
+            "total_tokens": len(outputs[0])
+        }
     }
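The new clean_response helper is pure string handling, so it can be exercised on its own. A minimal sanity-check sketch is shown below; it assumes the function is importable from app.main (importing that module also runs the model-loading code at the top of the file, so copying the function into a scratch file is often more practical).

# Quick check for clean_response. The import below also triggers the
# module-level model loading in app/main.py; a copied-out version of the
# function works just as well for testing.
from app.main import clean_response

cases = [
    # A leaked system-style prefix should be stripped away.
    ("You are a helpful assistant.\nParis is the capital of France.",
     "Paris is the capital of France."),
    # A bare role prefix such as "assistant\n" should be removed.
    ("assistant\nSure, here is the answer.",
     "Sure, here is the answer."),
    # An empty generation should fall back to the apology message.
    ("",
     "I apologize, but I couldn't generate a proper response. Please try again."),
]

for raw, expected in cases:
    got = clean_response(raw)
    assert got == expected, f"{raw!r} -> {got!r}, expected {expected!r}"

print("clean_response sanity checks passed.")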
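The second hunk replaces the hand-built prompt string with tokenizer.apply_chat_template. For reference, here is a standalone sketch of what that call produces for a Qwen2.5 instruct tokenizer; the checkpoint name is only an assumption for illustration, since the commit does not show which base model the adapter sits on.

from transformers import AutoTokenizer

# Assumed checkpoint purely for illustration; the app may load a different Qwen2.5 variant.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a plain string instead of token ids
    add_generation_prompt=True,  # append the opening assistant turn
)
print(prompt)
# Qwen2.5 uses ChatML-style markers, so the string looks roughly like:
# <|im_start|>system
# You are a concise assistant.<|im_end|>
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant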
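On the client side, the handler still returns an OpenAI-style body, now with the usage counters added in the last hunk. A minimal request sketch follows; it assumes the app runs locally on port 7860, that the chat handler is mounted at /chat (the route decorator is not visible in this diff), and that each choice carries a message object with a content field.

import requests

# Assumed URL; adjust host, port, and route to the actual deployment.
url = "http://localhost:7860/chat"

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]
}

resp = requests.post(url, json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

# Assumed OpenAI-style layout: choices[0]["message"]["content"] plus the new usage block.
print(data["choices"][0]["message"]["content"])
print(data["usage"])  # {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...}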