Ais
committed
Update app/main.py
app/main.py  +208 -187
app/main.py
CHANGED
@@ -1,190 +1,211 @@
- … (old app/main.py, lines 1–64: Apollo AI VS Code extension client, TypeScript; content not shown)
-   static getHistory() {
-     return [...this.conversationHistory];
-   }
- }
-
- export async function askAI(prompt: string, options: {
-   temperature?: number;
-   maxTokens?: number;
-   includeContext?: boolean;
-   retries?: number;
-   forceMode?: boolean;
- } = {}): Promise<string> {
-   const {
-     temperature = 0.7,
-     maxTokens = 1500,
-     includeContext = true,
-     retries = MAX_RETRIES,
-     forceMode = false
-   } = options;
-
-   console.log('🤖 Apollo AI: Starting request for prompt:', prompt.substring(0, 100) + '...');
-   console.log('🔧 Force mode:', forceMode);
-
-   // Build messages array for proper conversation
-   const messages = [];
-
-   // ✅ FIXED: Much simpler system messages
-   if (forceMode) {
-     messages.push({
-       role: 'system',
-       content: 'Give direct, brief answers only. No explanations.'
-     });
-   } else {
-     messages.push({
-       role: 'system',
-       content: 'You are a helpful assistant.'
-     });
-   }
-
-   // Add conversation history (only if includeContext is true and we have history)
-   if (includeContext && ApolloAI.getHistory().length > 0) {
-     const history = ApolloAI.getHistory().slice(-2); // Last 2 messages only
-     for (const msg of history) {
-       messages.push({
-         role: msg.role,
-         content: msg.content
-       });
-     }
-   }
-
-   // Add current user message
-   messages.push({
-     role: 'user',
-     content: prompt
-   });
-
-   // Add VS Code context if available (but not in conversation history)
-   const editor = vscode.window.activeTextEditor;
-   if (includeContext && editor && !forceMode) {
-     const fileName = editor.document.fileName.split(/[/\\]/).pop();
-     const language = editor.document.languageId;
-     messages[messages.length - 1].content += `\n\n[VS Code Context: Editing ${fileName} (${language})]`;
-   }
-
-   const headers = {
-     'Authorization': `Bearer ${API_KEY}`,
-     'Content-Type': 'application/json',
-     'User-Agent': 'Apollo-AI-VSCode-Extension/1.2.0'
-   };
-
-   const body = {
-     messages: messages,
-     temperature: forceMode ? 0.3 : temperature, // Lower temperature for force mode
-     max_tokens: forceMode ? 200 : maxTokens,    // Much shorter responses for force mode
-     stream: false
-   };
-
-   for (let attempt = 1; attempt <= retries; attempt++) {
-     try {
-       const fetchImpl = await getFetch();
-
-       console.log(`🚀 Apollo AI: Attempt ${attempt}/${retries}, sending request to API...`);
-       console.log('📤 Request body:', JSON.stringify(body, null, 2));
-
-       const controller = new AbortController();
-       const timeoutId = setTimeout(() => controller.abort(), TIMEOUT_MS);
-
-       const res = await fetchImpl(API_URL, {
-         method: 'POST',
-         headers,
-         body: JSON.stringify(body),
-         signal: controller.signal
-       });
-
-       clearTimeout(timeoutId);
-
-       console.log('📨 Apollo AI: Received response, status:', res.status);
-
-       if (!res.ok) {
-         const errorText = await res.text().catch(() => 'Unable to read error response');
-         console.error(`❌ Apollo AI: API Error ${res.status}: ${errorText}`);
-         // … (old lines 166–168 not shown)
-       } else if (res.status === 401) {
-         throw new Error('🔑 Authentication failed. Please check your API key.');
-       } else if (res.status >= 500) {
-         throw new Error('🔧 Server error. The AI service is temporarily unavailable.');
-       }
- … (old lines 174–190: remainder of askAI not shown)
+ import os
+ import torch
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ from starlette.middleware.cors import CORSMiddleware
+
+ # === Setup FastAPI ===
+ app = FastAPI()
+
+ # === CORS (optional for frontend access) ===
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # === Load API Key from Hugging Face Secrets ===
+ API_KEY = os.getenv("API_KEY", "undefined")
+
+ # === Model Settings ===
+ BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
+ ADAPTER_PATH = "adapter"
+
+ print("🔧 Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+
+ print("🧠 Loading base model on CPU...")
+ base_model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     trust_remote_code=True,
+     torch_dtype=torch.float32
+ ).cpu()
+
+ print("🔗 Applying LoRA adapter...")
+ model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
+ model.eval()
+
+ print("✅ Model and adapter loaded successfully.")
+
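Since inference runs entirely on CPU, one optional variation on the loading step above (a sketch of an alternative, not something this commit does) is to fold the LoRA weights into the base model once at startup with PEFT's merge_and_unload(), so each forward pass skips the adapter indirection:

    # Alternative sketch (not part of this commit): merge the LoRA weights into the
    # base model at startup. merge_and_unload() returns a plain model with the adapter
    # weights baked in, avoiding per-call adapter overhead on CPU.
    model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).merge_and_unload()
    model.eval()
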
+ # === Root Route ===
+ @app.get("/")
+ def root():
+     return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
+
+ # === Chat Completion API ===
+ @app.post("/v1/chat/completions")
+ async def chat(request: Request):
+     # ✅ API Key Authorization
+     auth_header = request.headers.get("Authorization", "")
+     if not auth_header.startswith("Bearer "):
+         return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
+
+     token = auth_header.replace("Bearer ", "").strip()
+     if token != API_KEY:
+         return JSONResponse(status_code=401, content={"error": "Invalid API key."})
+
+     # ✅ Parse Request
+     try:
+         body = await request.json()
+         messages = body.get("messages", [])
+         if not messages or not isinstance(messages, list):
+             raise ValueError("Invalid or missing 'messages' field.")
+
+         temperature = body.get("temperature", 0.7)
+         max_tokens = body.get("max_tokens", 512)
+
+     except Exception as e:
+         return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
+
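A small hardening option for the key check above (an assumption on my part, not part of this diff) is a constant-time comparison, so the comparison itself cannot leak key prefixes through timing:

    import secrets

    # Hardening sketch (not in this commit): constant-time API-key comparison.
    if not secrets.compare_digest(token, API_KEY):
        return JSONResponse(status_code=401, content={"error": "Invalid API key."})
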
+     # ✅ FIXED: Only use last 4 messages to prevent stacking
+     recent_messages = messages[-4:] if len(messages) > 4 else messages
+
+     # ✅ Build clean conversation prompt
+     formatted_prompt = ""
+
+     for message in recent_messages:
+         role = message.get("role", "")
+         content = message.get("content", "")
+
+         if role == "system":
+             formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
+         elif role == "user":
+             formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
+         elif role == "assistant":
+             formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
+
+     # Add the assistant start token for generation
+     formatted_prompt += "<|im_start|>assistant\n"
+
+     print(f"🤖 Processing {len(recent_messages)} recent messages")
+
+     inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
+
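The string building above hand-rolls Qwen2's ChatML format. If the tokenizer ships a chat template (Qwen instruct tokenizers normally do), the same prompt can be produced with transformers' apply_chat_template; a sketch of the equivalent call, not what this commit uses:

    # Sketch (not in this commit): let the tokenizer render the ChatML prompt.
    prompt_text = tokenizer.apply_chat_template(
        recent_messages,             # same list of {"role": ..., "content": ...} dicts
        tokenize=False,              # return a string instead of token ids
        add_generation_prompt=True   # appends the trailing "<|im_start|>assistant\n"
    )
    inputs = tokenizer(prompt_text, return_tensors="pt")
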
+     # ✅ Generate Response
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             top_p=0.9,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
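Because generate() returns the prompt tokens followed by the continuation, the decoded string still contains the ChatML scaffolding, which is what the cleanup below has to strip. A common alternative (a sketch, not part of this commit) is to decode only the newly generated tokens:

    # Sketch (not in this commit): decode only the tokens produced after the prompt,
    # so the prompt text and role markers never appear in the output string.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
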
+     # ✅ MUCH BETTER: Extract only the final assistant response
+     if "<|im_start|>assistant\n" in decoded:
+         # Get everything after the LAST assistant token
+         parts = decoded.split("<|im_start|>assistant\n")
+         final_answer = parts[-1].strip()
+     else:
+         # Fallback if no assistant token found
+         final_answer = decoded.strip()
+
+     # Remove end token
+     if "<|im_end|>" in final_answer:
+         final_answer = final_answer.split("<|im_end|>")[0].strip()
+
+     # ✅ CRITICAL: Remove conversation artifacts that leak through
+     # Remove user/assistant role labels that appear in content
+     final_answer = final_answer.replace("user\n", "").replace("assistant\n", "")
+
+     # Remove repeated questions and conversation artifacts
+     lines = final_answer.split('\n')
+     cleaned_lines = []
+     seen_content = set()
+     found_answer = False
+
+     for line in lines:
+         line = line.strip()
+
+         # Skip empty lines at the start
+         if not line and not found_answer:
+             continue
+
+         # Skip if this exact line was seen before (removes repeats)
+         if line in seen_content:
+             continue
+
+         # Skip lines that look like user prompts being repeated
+         if line.endswith('?') and len(line) < 100 and not found_answer:
+             print(f"🚫 Skipping repeated question: {line}")
+             continue
+
+         # Skip role indicators
+         if line in ['user', 'assistant', 'system']:
+             continue
+
+         # Skip conversation tokens
+         if '<|im_start|>' in line or '<|im_end|>' in line:
+             continue
+
+         # If we get here, this looks like actual content
+         found_answer = True
+         cleaned_lines.append(line)
+         seen_content.add(line)
+
+     final_answer = '\n'.join(cleaned_lines).strip()
+
+     # Remove VS Code context if it leaked through
+     if "[VS Code Context:" in final_answer:
+         context_lines = final_answer.split('\n')
+         cleaned_context_lines = [line for line in context_lines if not line.strip().startswith('[VS Code Context:')]
+         final_answer = '\n'.join(cleaned_context_lines).strip()
+
+     # Remove system prompts that leaked through
+     system_indicators = [
+         "Guidelines:",
+         "Response format:",
+         "You are a helpful",
+         "I'm here to help",
+         "system\n",
+         "assistant\n",
+         "user\n"
+     ]
+
+     for indicator in system_indicators:
+         if indicator in final_answer:
+             final_answer = final_answer.split(indicator)[0].strip()
+
+     # Clean up extra whitespace
+     final_answer = final_answer.replace('\n\n\n', '\n\n').strip()
+
+     # Ensure we have some content
+     if not final_answer or len(final_answer.strip()) < 3:
+         final_answer = "I apologize, but I couldn't generate a proper response. Please try again."
+
+     print(f"✅ Clean response: {final_answer[:100]}...")
+
+     # ✅ OpenAI-style Response
+     return {
+         "id": "chatcmpl-local-001",
+         "object": "chat.completion",
+         "model": "Qwen2.5-0.5B-Instruct-LoRA",
+         "choices": [
+             {
+                 "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": final_answer
+                 },
+                 "finish_reason": "stop"
+             }
+         ]
+     }
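For reference, the new endpoint accepts the same OpenAI-style shape the removed TypeScript client was already sending: a Bearer token plus a JSON body with messages, temperature and max_tokens. A minimal client sketch (the base URL and key below are placeholders, not values from the repo):

    import requests

    BASE_URL = "http://localhost:7860"   # placeholder: wherever the Space/app is served
    API_KEY = "your-api-key"             # placeholder: the real value of the API_KEY secret

    resp = requests.post(
        f"{BASE_URL}/v1/chat/completions",
        headers={"Authorization": f"Bearer {API_KEY}"},
        json={
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What does a LoRA adapter do?"}
            ],
            "temperature": 0.7,
            "max_tokens": 256,
        },
        timeout=120,
    )
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])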