Ais committed (verified)
Commit 0ee4730 · 1 Parent(s): 3afe501

Update app/main.py

Files changed (1): app/main.py +236 -133
app/main.py CHANGED
@@ -5,11 +5,12 @@ from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware

# === Setup FastAPI ===
- app = FastAPI()

- # === CORS (optional for frontend access) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
@@ -18,190 +19,292 @@ app.add_middleware(
    allow_headers=["*"],
)

- # === Load API Key from Hugging Face Secrets ===
- API_KEY = os.getenv("API_KEY", "undefined")  # Add API_KEY in your HF Space Secrets
-
- # === Model Settings ===
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"

print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

- print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
-     torch_dtype=torch.float32
- ).cpu()

print("🔗 Applying LoRA adapter...")
- model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()

- print("✅ Model and adapter loaded successfully.")

- def clean_response(raw_response):
    """
-     Clean the model response by removing unwanted artifacts while preserving the actual answer.
    """
-     if not raw_response or len(raw_response.strip()) < 2:
-         return "I apologize, but I couldn't generate a proper response. Please try again."

-     # Remove common chat template artifacts
-     cleaned = raw_response.strip()

-     # Remove system/user/assistant prefixes that might leak through
-     prefixes_to_remove = [
-         "system\n", "user\n", "assistant\n",
-         "System:", "User:", "Assistant:",
-         "<|im_start|>", "<|im_end|>",
-         "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
-         "You are a helpful assistant.",
-         "I am a helpful assistant.",
-         "As a helpful assistant,",
    ]

-     for prefix in prefixes_to_remove:
-         if cleaned.lower().startswith(prefix.lower()):
-             cleaned = cleaned[len(prefix):].strip()

-     # Remove any remaining template artifacts
-     lines = cleaned.split('\n')
-     filtered_lines = []

    for line in lines:
-         line_stripped = line.strip()

-         # Skip empty lines at the beginning
-         if not line_stripped and not filtered_lines:
            continue

-         # Skip obvious template artifacts
-         if line_stripped in ["system", "user", "assistant"]:
            continue

-         filtered_lines.append(line)

-     cleaned = '\n'.join(filtered_lines).strip()

-     # If we still have content, return it
-     if cleaned and len(cleaned) > 5:
-         return cleaned

-     # Fallback only if truly empty
-     return "I understand your question. Let me help you with that."

- # === Root Route ===
@app.get("/")
def root():
-     return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}

- # === Chat Completion API ===
@app.post("/v1/chat/completions")
- async def chat(request: Request):
-     # API Key Authorization
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
-         return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})

    token = auth_header.replace("Bearer ", "").strip()
    if token != API_KEY:
-         return JSONResponse(status_code=401, content={"error": "Invalid API key."})

-     # Parse Request
    try:
        body = await request.json()
        messages = body.get("messages", [])
        if not messages or not isinstance(messages, list):
-             raise ValueError("Invalid or missing 'messages' field.")
    except Exception as e:
-         return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})

-     # ✅ FIXED: Use proper Qwen2.5 chat template formatting
    try:
-         # Use the tokenizer's built-in chat template - this is the correct way!
-         formatted_prompt = tokenizer.apply_chat_template(
-             messages,
-             tokenize=False,
-             add_generation_prompt=True
        )

-         print(f"🔍 Formatted prompt: {formatted_prompt}")

    except Exception as e:
-         print(f"❌ Chat template error: {e}")
-         # Fallback to manual formatting if needed
-         formatted_prompt = ""
-         for msg in messages:
-             role = msg.get("role", "user")
-             content = msg.get("content", "")
-             if role == "system":
-                 formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
-             elif role == "user":
-                 formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
-             elif role == "assistant":
-                 formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
-         formatted_prompt += "<|im_start|>assistant\n"
-
-     inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
-
-     # ✅ Generate Response with optimized settings
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=512,  # Increased for better responses
-             temperature=0.7,
-             top_p=0.9,
-             do_sample=True,
-             pad_token_id=tokenizer.eos_token_id,
-             repetition_penalty=1.05,  # Slightly reduced
-             length_penalty=1.0,
-             early_stopping=True
        )

-     # FIXED: Better response extraction
-     full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     print(f"🔍 Full generated response: {full_response}")
-
-     # Extract only the new generated part (after the prompt)
-     if formatted_prompt in full_response:
-         generated_part = full_response.split(formatted_prompt)[-1].strip()
-     else:
-         # If we can't find the exact prompt, try to extract the assistant's response
-         assistant_marker = "<|im_start|>assistant\n"
-         if assistant_marker in full_response:
-             parts = full_response.split(assistant_marker)
-             generated_part = parts[-1].split("<|im_end|>")[0].strip() if len(parts) > 1 else full_response
-         else:
-             generated_part = full_response
-
-     print(f"🔍 Extracted generated part: {generated_part}")
-
-     # ✅ Clean the response but keep it intact
-     final_answer = clean_response(generated_part)
-
-     print(f"🔍 Final cleaned answer: {final_answer}")
-
-     # ✅ OpenAI-style Response
-     return {
-         "id": "chatcmpl-local-001",
-         "object": "chat.completion",
-         "model": "Qwen2.5-0.5B-Instruct-LoRA",
-         "choices": [
-             {
-                 "index": 0,
-                 "message": {
-                     "role": "assistant",
-                     "content": final_answer
-                 },
-                 "finish_reason": "stop"
-             }
-         ],
-         "usage": {
-             "prompt_tokens": len(inputs.input_ids[0]),
-             "completion_tokens": len(outputs[0]) - len(inputs.input_ids[0]),
-             "total_tokens": len(outputs[0])
        }
-     }
@@ -5,11 +5,12 @@ from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware
+ import re

# === Setup FastAPI ===
+ app = FastAPI(title="Apollo AI Backend", version="1.0.0")

+ # === CORS ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
@@ -18,190 +19,292 @@ app.add_middleware(
    allow_headers=["*"],
)

+ # === Configuration ===
+ API_KEY = os.getenv("API_KEY", "aigenapikey1234567890")
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"

+ # === Load Model ===
print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

+ print("🧠 Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
+     torch_dtype=torch.float32,
+     device_map="cpu"
+ )

print("🔗 Applying LoRA adapter...")
+ model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()

+ print("✅ Model ready!")

+ def extract_clean_answer(full_response: str, formatted_prompt: str, user_message: str) -> str:
    """
+     Extract only the AI's response, removing all template artifacts and system prompt leaks.
    """
+     if not full_response or len(full_response.strip()) < 5:
+         return "I apologize, but I couldn't generate a response. Please try again."
+
+     print(f"🔍 Input full_response length: {len(full_response)}")
+     print(f"🔍 Input full_response preview: {full_response[:200]}...")
+
+     # Step 1: Remove the input prompt to isolate the generated part
+     generated_text = full_response
+     if formatted_prompt in full_response:
+         generated_text = full_response.split(formatted_prompt)[-1]
+
+     # Step 2: Extract content between assistant tags
+     assistant_pattern = r'<\|im_start\|>assistant\n(.*?)(?:<\|im_end\|>|$)'
+     assistant_matches = re.findall(assistant_pattern, generated_text, re.DOTALL)
+
+     if assistant_matches:
+         generated_text = assistant_matches[-1]  # Get the last (newest) assistant response
+
+     # Step 3: Remove common template artifacts
+     artifacts_to_remove = [
+         r'<\|im_start\|>.*?<\|im_end\|>',
+         r'<\|im_start\|>.*',
+         r'<\|im_end\|>.*',
+         r'^(system|user|assistant):\s*',
+         r'^\s*(system|user|assistant)\s*\n',
+     ]

+     for pattern in artifacts_to_remove:
+         generated_text = re.sub(pattern, '', generated_text, flags=re.MULTILINE | re.IGNORECASE)

+     # Step 4: Aggressive system prompt leak removal
+     system_leaks = [
+         r'You are.*?(?=\n\n|\n[A-Z]|\.|$)',
+         r'Guidelines:.*?(?=\n\n|\n[A-Z]|$)',
+         r'Response format:.*?(?=\n\n|\n[A-Z]|$)',
+         r'- Provide.*?(?=\n\n|\n[A-Z]|$)',
+         r'- Use.*?(?=\n\n|\n[A-Z]|$)',
+         r'NEVER include.*?(?=\n\n|\n[A-Z]|$)',
+         r'VS Code Context:.*?(?=\n\n|\n[A-Z]|$)',
+         r'\[VS Code Context:.*?\]',
    ]

+     for leak_pattern in system_leaks:
+         generated_text = re.sub(leak_pattern, '', generated_text, flags=re.DOTALL | re.IGNORECASE)

+     # Step 5: Clean up whitespace and format
+     lines = generated_text.split('\n')
+     clean_lines = []

    for line in lines:
+         line = line.strip()

+         # Skip empty lines at the start
+         if not line and not clean_lines:
            continue

+         # Skip lines that are obviously system prompts
+         skip_patterns = [
+             'you are a helpful', 'guidelines', 'response format', 'provide clear',
+             'use markdown', 'never include', 'vs code context', 'current request'
+         ]
+
+         if any(pattern in line.lower() for pattern in skip_patterns):
            continue

+         clean_lines.append(line)
+
+     # Step 6: Reconstruct the response
+     final_answer = '\n'.join(clean_lines).strip()

+     # Step 7: Handle edge cases
+     if not final_answer or len(final_answer) < 10:
+         return "I understand your question. Could you please rephrase it for a clearer answer?"

+     # Step 8: Remove any remaining question echoes
+     if user_message and len(user_message) > 10:
+         user_words = set(user_message.lower().split())
+         first_sentence = final_answer.split('.')[0]
+         if len(set(first_sentence.lower().split()) & user_words) > len(user_words) * 0.7:
+             # First sentence likely echoes the question, remove it
+             remaining = '.'.join(final_answer.split('.')[1:]).strip()
+             if remaining and len(remaining) > 20:
+                 final_answer = remaining

+     print(f"🧹 Final cleaned answer: {final_answer}")
+     return final_answer

+ def generate_response(messages: list, max_tokens: int = 300, temperature: float = 0.7) -> str:
+     """
+     Generate response using the model with proper chat formatting.
+     """
+     try:
+         # Build the conversation using tokenizer's chat template
+         formatted_prompt = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
+
+         print(f"🔍 Formatted prompt: {formatted_prompt}")
+
+         # Tokenize
+         inputs = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=2048)
+
+         # Generate
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=0.9,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 repetition_penalty=1.05,
+                 length_penalty=1.0,
+                 early_stopping=True
+             )
+
+         # Decode the full response
+         full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+         # Extract user message for cleaning
+         user_message = ""
+         for msg in messages:
+             if msg.get("role") == "user":
+                 user_message = msg.get("content", "")
+
+         # Clean and extract the answer
+         clean_answer = extract_clean_answer(full_response, formatted_prompt, user_message)
+
+         return clean_answer
+
+     except Exception as e:
+         print(f"❌ Generation error: {e}")
+         return f"I encountered an error while processing your request. Please try again."
+
+ # === Routes ===
@app.get("/")
def root():
+     return {
+         "message": "🤖 Apollo AI Backend is running!",
+         "model": "Qwen2-0.5B-Instruct with LoRA",
+         "status": "ready"
+     }
+
+ @app.get("/health")
+ def health():
+     return {"status": "healthy", "model_loaded": True}

@app.post("/v1/chat/completions")
+ async def chat_completions(request: Request):
+     # Validate API key
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
+         return JSONResponse(
+             status_code=401,
+             content={"error": "Missing or invalid Authorization header"}
+         )

    token = auth_header.replace("Bearer ", "").strip()
    if token != API_KEY:
+         return JSONResponse(
+             status_code=401,
+             content={"error": "Invalid API key"}
+         )

+     # Parse request body
    try:
        body = await request.json()
        messages = body.get("messages", [])
+         max_tokens = body.get("max_tokens", 300)
+         temperature = body.get("temperature", 0.7)
+
        if not messages or not isinstance(messages, list):
+             raise ValueError("Messages field is required and must be a list")
+
    except Exception as e:
+         return JSONResponse(
+             status_code=400,
+             content={"error": f"Invalid request body: {str(e)}"}
+         )
+
+     # Validate messages format
+     for i, msg in enumerate(messages):
+         if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
+             return JSONResponse(
+                 status_code=400,
+                 content={"error": f"Invalid message format at index {i}"}
+             )

    try:
+         # Generate response
+         print(f"📥 Processing {len(messages)} messages")
+         response_content = generate_response(
+             messages=messages,
+             max_tokens=min(max_tokens, 500),  # Cap max tokens
+             temperature=max(0.1, min(temperature, 1.0))  # Clamp temperature
        )

+         # Return OpenAI-compatible response
+         return {
+             "id": f"chatcmpl-apollo-{hash(str(messages)) % 10000}",
+             "object": "chat.completion",
+             "created": int(torch.tensor(0).item()),  # Simple timestamp
+             "model": "qwen2-0.5b-instruct-lora",
+             "choices": [
+                 {
+                     "index": 0,
+                     "message": {
+                         "role": "assistant",
+                         "content": response_content
+                     },
+                     "finish_reason": "stop"
+                 }
+             ],
+             "usage": {
+                 "prompt_tokens": len(str(messages)),  # Approximate
+                 "completion_tokens": len(response_content),  # Approximate
+                 "total_tokens": len(str(messages)) + len(response_content)
+             }
+         }

    except Exception as e:
+         print(f"❌ Chat completion error: {e}")
+         return JSONResponse(
+             status_code=500,
+             content={"error": f"Internal server error: {str(e)}"}
        )

+ # === Test endpoint for debugging ===
+ @app.post("/test")
+ async def test_generation(request: Request):
+     """Test endpoint for debugging the model directly"""
+     try:
+         body = await request.json()
+         prompt = body.get("prompt", "Hello, how are you?")
+
+         messages = [
+             {"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": prompt}
+         ]
+
+         response = generate_response(messages, max_tokens=200, temperature=0.7)
+
+         return {
+             "prompt": prompt,
+             "response": response,
+             "status": "success"
        }
+
+     except Exception as e:
+         return JSONResponse(
+             status_code=500,
+             content={"error": str(e)}
+         )
+
+ if __name__ == "__main__":
+     import uvicorn
+     print("🚀 Starting Apollo AI Backend...")
+     uvicorn.run(app, host="0.0.0.0", port=7860)
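
For reference, a minimal client sketch against the routes added in this commit could look like the following. It assumes the app is running locally on the port passed to uvicorn.run (7860) and that the API key matches the API_KEY value read from the environment; both the payload fields (messages, max_tokens, temperature) and the response shape (choices[0].message.content) come from the handler above.

# Minimal client sketch (assumed local deployment at http://localhost:7860;
# the API key must match the API_KEY environment variable used by app/main.py).
import requests

API_URL = "http://localhost:7860/v1/chat/completions"  # assumed base URL
API_KEY = "aigenapikey1234567890"                       # fallback value from os.getenv above

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain what a LoRA adapter is in one sentence."}
    ],
    "max_tokens": 200,
    "temperature": 0.7,
}

# Call the OpenAI-compatible chat completions route with Bearer auth
resp = requests.post(
    API_URL,
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

# The /test debugging route added in this commit takes a bare prompt and needs no auth header
test = requests.post(
    "http://localhost:7860/test",
    json={"prompt": "Say hello in French."},
    timeout=120,
)
print(test.json())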