Ais
committed on
Update app/main.py
app/main.py: +131 -92
app/main.py CHANGED
@@ -7,7 +7,7 @@ from peft import PeftModel
 from starlette.middleware.cors import CORSMiddleware
 
 # === Setup FastAPI ===
-app = FastAPI(title="Apollo AI Backend - Qwen2-0.5B", version="4.…
+app = FastAPI(title="Apollo AI Backend - Qwen2-0.5B", version="4.1.0-TRULY-FIXED")
 
 # === CORS ===
 app.add_middleware(
@@ -46,39 +46,31 @@ print("✅ Qwen2-0.5B model ready!")
 def create_conversation_prompt(messages: list, is_force_mode: bool) -> str:
     """Create a conversation prompt with clear mode instructions"""
 
+    # Get the last user message
+    last_message = messages[-1].get("content", "") if messages else ""
+
     if is_force_mode:
-…
+        # FORCE MODE: Direct, complete answers
+        system_instruction = """You are a helpful programming assistant. Answer directly and completely. Provide clear explanations with code examples when relevant. Don't ask questions back to the user."""
+
+        prompt = f"""<|im_start|>system
+{system_instruction}<|im_end|>
+<|im_start|>user
+{last_message}<|im_end|>
+<|im_start|>assistant
+"""
     else:
-…
-    # Build conversation
-    conversation = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
-…
-    # Add conversation history (last 4 messages for context)
-    recent_messages = messages[-4:] if len(messages) > 4 else messages
-…
-    for msg in recent_messages:
-        role = msg.get("role", "")
-        content = msg.get("content", "")
-        conversation += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+        # MENTOR MODE: Guide with questions
+        system_instruction = """You are a programming mentor. Guide students to discover answers through questions and hints. Ask questions to help them think, rather than giving direct answers."""
+
+        prompt = f"""<|im_start|>system
+{system_instruction}<|im_end|>
+<|im_start|>user
+{last_message}<|im_end|>
+<|im_start|>assistant
+"""
 
-…
-    return conversation
+    return prompt
 
 def generate_response(messages: list, is_force_mode: bool = False, max_tokens: int = 200, temperature: float = 0.7) -> str:
     """Generate response using the AI model"""
@@ -88,19 +80,30 @@ def generate_response(messages: list, is_force_mode: bool = False, max_tokens: int = 200, temperature: float = 0.7) -> str:
 
         print(f"🎯 Generating {'FORCE (Direct)' if is_force_mode else 'MENTOR (Questions)'} response")
         print(f"🔍 Mode flag: {is_force_mode}")
+        print(f"📝 Prompt preview: {prompt[:200]}...")
 
         # Adjust parameters based on mode
         if is_force_mode:
-            generation_temp = 0.…
-            generation_tokens = min(max_tokens,…
+            generation_temp = 0.4  # More focused for direct answers
+            generation_tokens = min(max_tokens, 350)
+            top_p = 0.8
         else:
-            generation_temp = 0.…
+            generation_temp = 0.6  # More creative for questions
             generation_tokens = min(max_tokens, 250)
+            top_p = 0.9
 
-        # Tokenize input
-        inputs = tokenizer(…
+        # Tokenize input with proper truncation
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+            max_length=1024,  # Shorter context for better responses
+            truncation=True,
+            padding=False
+        )
 
-…
+        print(f"🔢 Input tokens: {inputs.input_ids.shape[1]}")
+
+        # Generate response with better parameters
         with torch.no_grad():
             outputs = model.generate(
                 inputs.input_ids,
@@ -109,65 +112,77 @@ def generate_response(messages: list, is_force_mode: bool = False, max_tokens: int = 200, temperature: float = 0.7) -> str:
                 do_sample=True,
                 pad_token_id=tokenizer.eos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
-                top_p=…
-                repetition_penalty=1.…
-                no_repeat_ngram_size=…
+                top_p=top_p,
+                repetition_penalty=1.05,  # Reduced repetition penalty
+                no_repeat_ngram_size=2,  # Reduced n-gram size
+                early_stopping=True
             )
 
-        # Decode response
-…
+        # Decode response properly
+        generated_ids = outputs[0][inputs.input_ids.shape[1]:]  # Only new tokens
+        response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
 
-…
-        response = full_response[len(prompt):].strip()
+        print(f"🔍 Raw response: {response[:150]}...")
 
         # Clean up response
         response = response.replace("<|im_end|>", "").strip()
 
-        # Remove conversation…
-…
-            if not line.startswith(('<|im_start|>', '<|im_end|>', 'system:', 'user:', 'assistant:')):
-                clean_lines.append(line)
-…
-        response = '\n'.join(clean_lines).strip()
-…
-        # Take first paragraph if too long
-        if len(response) > max_tokens * 4:
-            paragraphs = response.split('\n\n')
-            response = paragraphs[0] if paragraphs else response[:max_tokens * 4]
+        # Remove any leftover conversation markers
+        unwanted_prefixes = ["<|im_start|>", "assistant:", "user:", "system:"]
+        for prefix in unwanted_prefixes:
+            if response.startswith(prefix):
+                response = response[len(prefix):].strip()
 
-…
-        # Simple validation - no template injection
-        if not response or len(response) < 10:
+        # Handle empty or very short responses
+        if not response or len(response) < 5:
             if is_force_mode:
-                return "I need more specific information to provide a…
+                return "I need more specific information to provide a helpful answer. Could you please clarify your question?"
             else:
-                return "That's…
+                return "That's an interesting question! What do you think the answer might be? Have you tried experimenting with it?"
+
+        # Truncate if too long but ensure complete sentences
+        if len(response) > max_tokens * 6:  # Rough character to token ratio
+            sentences = response.split('. ')
+            truncated = ""
+            for sentence in sentences:
+                if len(truncated + sentence + '. ') <= max_tokens * 5:
+                    truncated += sentence + '. '
+                else:
+                    break
+            response = truncated.rstrip()
+
+        print(f"✅ Final response length: {len(response)}")
+        print(f"📝 Response preview: {response[:100]}...")
 
         return response
 
     except Exception as e:
         print(f"❌ Generation error: {e}")
+        import traceback
+        traceback.print_exc()
+
         if is_force_mode:
-            return "I encountered an error generating a…
+            return "I encountered an error generating a response. Please try rephrasing your question."
         else:
-            return "…
+            return "That's a challenging question! What approach do you think might work? Let's explore this step by step."
 
 # === Routes ===
 @app.get("/")
 def root():
     return {
-        "message": "🤖 Apollo AI Backend v4.…
+        "message": "🤖 Apollo AI Backend v4.1-TRULY-FIXED - Qwen2-0.5B",
         "model": "Qwen/Qwen2-0.5B-Instruct with LoRA",
         "status": "ready",
         "modes": {
-            "mentor": "Guides learning with questions -…
-            "force": "Provides direct answers -…
+            "mentor": "Guides learning with questions - FIXED GENERATION",
+            "force": "Provides direct answers - FIXED GENERATION"
         },
-        "fixes":…
+        "fixes": [
+            "Fixed prompt truncation",
+            "Improved token generation",
+            "Better response cleaning",
+            "Proper mode detection"
+        ]
     }
 
 @app.get("/health")
@@ -176,7 +191,7 @@ def health():
         "status": "healthy",
         "model_loaded": True,
         "model_size": "0.5B",
-        "version": "4.…
+        "version": "4.1-TRULY-FIXED"
     }
 
 @app.post("/v1/chat/completions")
@@ -200,19 +215,28 @@ async def chat_completions(request: Request):
     try:
         body = await request.json()
         messages = body.get("messages", [])
-        max_tokens = min(body.get("max_tokens",…
+        max_tokens = min(body.get("max_tokens", 300), 500)  # Increased default
        temperature = max(0.1, min(body.get("temperature", 0.7), 1.0))
 
-        # Get force mode flag
-        is_force_mode =…
+        # CRITICAL: Get force mode flag - check multiple possible names
+        is_force_mode = (
+            body.get("force_mode", False) or
+            body.get("forceMode", False) or
+            body.get("force", False)
+        )
 
-        print(f"🚨 REQUEST RECEIVED…
-        print(f"…
+        print(f"🚨 REQUEST RECEIVED")
+        print(f"🎯 Force mode detected: {is_force_mode}")
+        print(f"📊 Max tokens: {max_tokens}, Temperature: {temperature}")
+        print(f"📝 Messages count: {len(messages)}")
+        if messages:
+            print(f"📝 Last message: {messages[-1].get('content', '')[:100]}...")
 
         if not messages or not isinstance(messages, list):
             raise ValueError("Messages field is required and must be a list")
 
     except Exception as e:
+        print(f"❌ Request parsing error: {e}")
         return JSONResponse(
             status_code=400,
             content={"error": f"Invalid request body: {str(e)}"}
@@ -227,9 +251,10 @@ async def chat_completions(request: Request):
         )
 
     try:
-        print(f"…
+        print(f"🔄 Processing with {len(messages)} messages")
+        print(f"🎯 Mode: {'FORCE (Direct Answer)' if is_force_mode else 'MENTOR (Guiding Questions)'}")
 
-        # Generate response
+        # Generate response
         response_content = generate_response(
             messages=messages,
             is_force_mode=is_force_mode,
@@ -237,13 +262,19 @@ async def chat_completions(request: Request):
             temperature=temperature
         )
 
-…
+        # Validate response
+        if not response_content or len(response_content.strip()) < 10:
+            response_content = "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
+
+        print(f"✅ Response generated successfully")
+        print(f"📊 Response length: {len(response_content)}")
+        print(f"🔍 Mode used: {'force_direct' if is_force_mode else 'mentor_questions'}")
 
         return {
-            "id": f"chatcmpl-apollo-{hash(str(messages)) % 10000}",
+            "id": f"chatcmpl-apollo-{abs(hash(str(messages))) % 10000}",
             "object": "chat.completion",
-            "created":…
-            "model": f"qwen2-0.5b-{'force' if is_force_mode else 'mentor'}-…
+            "created": 1704067200,  # Fixed timestamp
+            "model": f"qwen2-0.5b-{'force' if is_force_mode else 'mentor'}-v4.1",
             "choices": [
                 {
                     "index": 0,
@@ -255,26 +286,34 @@ async def chat_completions(request: Request):
                 }
             ],
             "usage": {
-                "prompt_tokens": len(…
-                "completion_tokens": len(response_content),
-                "total_tokens": len(…
+                "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages) // 4,  # Rough estimate
+                "completion_tokens": len(response_content) // 4,  # Rough estimate
+                "total_tokens": (sum(len(msg.get("content", "")) for msg in messages) + len(response_content)) // 4
            },
-            "apollo_mode": "…
-            "pure_ai_response": True
+            "apollo_mode": "force_direct_v4.1" if is_force_mode else "mentor_questions_v4.1",
+            "pure_ai_response": True,
+            "generation_success": True
         }
 
     except Exception as e:
         print(f"❌ Chat completion error: {e}")
+        import traceback
+        traceback.print_exc()
+
         return JSONResponse(
             status_code=500,
-            content={…
+            content={
+                "error": f"Internal server error: {str(e)}",
+                "type": "generation_error",
+                "mode_requested": "force" if is_force_mode else "mentor"
+            }
         )
 
 if __name__ == "__main__":
     import uvicorn
-    print("🚀 Starting Apollo AI Backend v4.…
+    print("🚀 Starting Apollo AI Backend v4.1-TRULY-FIXED")
     print("🧠 Model: Qwen/Qwen2-0.5B-Instruct (500M parameters)")
-    print("…
-    print("…
-    print("…
+    print("🔧 Fixed: Prompt generation, token handling, response cleaning")
+    print("🎯 Mentor Mode: Guides with questions")
+    print("⚡ Force Mode: Provides direct answers")
     uvicorn.run(app, host="0.0.0.0", port=7860)
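
For reference, a minimal client sketch against the endpoint changed above, assuming the server is running locally on port 7860 as in the uvicorn.run call. The requests client, the sample question, and the OpenAI-style choices[0]["message"]["content"] access are illustrative assumptions, not part of this commit; the request fields mirror what chat_completions() reads from the body.

# Hypothetical client for the Apollo backend above (sketch, not from the commit).
import requests

payload = {
    "messages": [{"role": "user", "content": "How do I reverse a list in Python?"}],
    "max_tokens": 200,
    "temperature": 0.7,
    "force_mode": True,  # False (or omitted) selects mentor mode
}

resp = requests.post("http://localhost:7860/v1/chat/completions", json=payload)
resp.raise_for_status()
data = resp.json()

print(data["apollo_mode"])  # e.g. "force_direct_v4.1"
# Assumed OpenAI-style choice layout; the exact message shape is outside the lines shown in this diff.
print(data["choices"][0]["message"]["content"])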