Ais committed
Update app/main.py
app/main.py +186 -145
app/main.py
CHANGED
@@ -1,149 +1,190 @@
- import os  # required by os.getenv below
- import torch
- from fastapi import FastAPI, Request
- from fastapi.responses import JSONResponse
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from peft import PeftModel
- from starlette.middleware.cors import CORSMiddleware
-
- # === Setup FastAPI ===
- app = FastAPI()
-
- # === CORS (optional for frontend access) ===
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # === Load API Key from Hugging Face Secrets ===
- API_KEY = os.getenv("API_KEY", "undefined")
-
- # === Model Settings ===
- BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
- ADAPTER_PATH = "adapter"
-
- print("🔧 Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
-
- print("🧠 Loading base model on CPU...")
- base_model = AutoModelForCausalLM.from_pretrained(
-     BASE_MODEL,
-     trust_remote_code=True,
-     torch_dtype=torch.float32
- ).cpu()
-
- print("🔗 Applying LoRA adapter...")
- model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
- model.eval()
-
- print("✅ Model and adapter loaded successfully.")
-
- # === Root Route ===
- @app.get("/")
- def root():
-     return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
-
- # === Chat Completion API ===
- @app.post("/v1/chat/completions")
- async def chat(request: Request):
-     # ✅ API Key Authorization
-     auth_header = request.headers.get("Authorization", "")
-     if not auth_header.startswith("Bearer "):
-         return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
-
-     token = auth_header.replace("Bearer ", "").strip()
-     if token != API_KEY:
-         return JSONResponse(status_code=401, content={"error": "Invalid API key."})
-
-     # ✅ Parse Request
-     try:
-         body = await request.json()
-         messages = body.get("messages", [])
-         if not messages or not isinstance(messages, list):
-             raise ValueError("Invalid or missing 'messages' field.")
-
-         temperature = body.get("temperature", 0.7)
-         max_tokens = body.get("max_tokens", 512)
-
-     except Exception as e:
-         return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
-
-     if
-
-     # Add the assistant start token for generation
-     formatted_prompt += "<|im_start|>assistant\n"
-
-     print(f"🤖 Processing {len(recent_messages)} recent messages")
-
-     inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
-
-     # ✅ Generate Response
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=max_tokens,
-             temperature=temperature,
-             top_p=0.9,
-             do_sample=True,
-             pad_token_id=tokenizer.eos_token_id,
-             eos_token_id=tokenizer.eos_token_id
-         )
-
-     decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # ✅ FIXED: Extract ONLY the new assistant response
-     final_answer = decoded.split("<|im_start|>assistant\n")[-1].strip()
-
-     # Remove any end tokens or artifacts
-     if "<|im_end|>" in final_answer:
-         final_answer = final_answer.split("<|im_end|>")[0].strip()
-
-     # Remove any repeated system prompts or guidelines that leaked through
-     if "Guidelines:" in final_answer:
-         final_answer = final_answer.split("Guidelines:")[0].strip()
-
-     if "Response format:" in final_answer:
-         final_answer = final_answer.split("Response format:")[0].strip()
-
-                 "message": {
-                     "role": "assistant",
-                     "content": final_answer
-                 },
-                 "finish_reason": "stop"
-             }
-         ]
-     }
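For reference, the handler removed above speaks an OpenAI-style chat protocol: a Bearer key in the Authorization header, a non-empty messages list, optional temperature and max_tokens, and a choices array with a message object in the reply. A minimal client sketch follows; BASE_URL and MY_API_KEY are placeholders, not values from this commit:

// client-sketch.ts - hypothetical caller for the endpoint removed above.
const BASE_URL = 'https://your-space.hf.space';  // placeholder deployment URL
const MY_API_KEY = 'YOUR_API_KEY';               // placeholder key

async function chatOnce(question: string): Promise<string> {
    const res = await fetch(`${BASE_URL}/v1/chat/completions`, {
        method: 'POST',
        headers: {
            'Authorization': `Bearer ${MY_API_KEY}`, // compared to API_KEY by the handler
            'Content-Type': 'application/json'
        },
        body: JSON.stringify({
            messages: [{ role: 'user', content: question }], // required, non-empty list
            temperature: 0.7, // optional, server defaults to 0.7
            max_tokens: 256   // optional, server defaults to 512
        })
    });
    if (!res.ok) {
        throw new Error(`API error ${res.status}`);
    }
    const json = await res.json();
    // The handler returns an OpenAI-style body: choices[0].message.content
    return json.choices?.[0]?.message?.content ?? '';
}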
+ // chat.ts - Fixed Apollo AI Chat Module
+
+ import * as vscode from 'vscode';
+
+ // Use global fetch if available (VS Code >= 1.74), otherwise fall back to require('node-fetch')
+ async function getFetch(): Promise<any> {
+     if (typeof (globalThis as any).fetch !== 'undefined') {
+         return (globalThis as any).fetch;
+     } else {
+         try {
+             const fetch = require('node-fetch');
+             return fetch.default || fetch;
+         } catch (error) {
+             throw new Error('Unable to load fetch. Please ensure node-fetch is installed.');
+         }
+     }
+ }
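A note on the helper above: hosts with a global fetch (VS Code 1.74+) use it directly; older hosts need node-fetch installed as a dependency. A minimal, hypothetical usage sketch:

// Hypothetical usage of getFetch(); example.com is a placeholder URL.
async function probeEndpoint(): Promise<void> {
    const fetchImpl = await getFetch(); // resolves to global fetch or node-fetch
    const res = await fetchImpl('https://example.com');
    console.log('status:', res.status);
}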
+
+ // Configuration
+ const API_URL = 'https://ais0909-aigen.hf.space/v1/chat/completions';
+ const API_KEY = 'aigenapikey1234567890';
+ const MAX_RETRIES = 3;
+ const TIMEOUT_MS = 300000; // 5 minutes
+
+ interface APIResponse {
+     choices?: Array<{
+         message?: { content?: string };
+         text?: string;
+     }>;
+     generated_text?: string;
+     error?: string;
+     id?: string;
+     object?: string;
+     model?: string;
+ }
+
+ interface ChatContext {
+     currentFile?: string;
+     language?: string;
+     workspaceFolder?: string;
+     selectedText?: string;
+ }
+
+ export class ApolloAI {
+     private static conversationHistory: Array<{role: string, content: string}> = [];
+     private static context: ChatContext = {};
+
+     static setContext(context: ChatContext) {
+         this.context = context;
+     }
+
+     static addToHistory(role: 'user' | 'assistant', content: string) {
+         this.conversationHistory.push({ role, content });
+
+         // Keep only last 2 messages to prevent conversation stacking
+         if (this.conversationHistory.length > 2) {
+             this.conversationHistory = this.conversationHistory.slice(-2);
+         }
+     }
+
+     static clearHistory() {
+         this.conversationHistory = [];
+     }
+
+     static getHistory() {
+         return [...this.conversationHistory];
+     }
+ }
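The class above deliberately caps stored history at the last two turns. A short, hypothetical call site showing the effect:

// Hypothetical call site for the ApolloAI history API above.
ApolloAI.addToHistory('user', 'What does this function do?');
ApolloAI.addToHistory('assistant', 'It fetches a chat completion.');
ApolloAI.addToHistory('user', 'Now refactor it.'); // oldest turn is dropped
console.log(ApolloAI.getHistory().length);         // 2 - capped by slice(-2)
ApolloAI.clearHistory();
console.log(ApolloAI.getHistory().length);         // 0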
+
+ export async function askAI(prompt: string, options: {
+     temperature?: number;
+     maxTokens?: number;
+     includeContext?: boolean;
+     retries?: number;
+     forceMode?: boolean;
+ } = {}): Promise<string> {
+     const {
+         temperature = 0.7,
+         maxTokens = 1500,
+         includeContext = true,
+         retries = MAX_RETRIES,
+         forceMode = false
+     } = options;
+
+     console.log('🤖 Apollo AI: Starting request for prompt:', prompt.substring(0, 100) + '...');
+     console.log('🔧 Force mode:', forceMode);
+
+     // Build messages array for proper conversation (typed so the pushes below type-check)
+     const messages: Array<{role: string, content: string}> = [];
+
+     // ✅ FIXED: Much simpler system messages
+     if (forceMode) {
+         messages.push({
+             role: 'system',
+             content: 'Give direct, brief answers only. No explanations.'
+         });
+     } else {
+         messages.push({
+             role: 'system',
+             content: 'You are a helpful assistant.'
+         });
+     }
+
+     // Add conversation history (only if includeContext is true and we have history)
+     if (includeContext && ApolloAI.getHistory().length > 0) {
+         const history = ApolloAI.getHistory().slice(-2); // Last 2 messages only
+         for (const msg of history) {
+             messages.push({
+                 role: msg.role,
+                 content: msg.content
+             });
+         }
+     }
+
+     // Add current user message
+     messages.push({
+         role: 'user',
+         content: prompt
+     });
+
+     // Add VS Code context if available (but not in conversation history)
+     const editor = vscode.window.activeTextEditor;
+     if (includeContext && editor && !forceMode) {
+         const fileName = editor.document.fileName.split(/[/\\]/).pop();
+         const language = editor.document.languageId;
+         messages[messages.length - 1].content += `\n\n[VS Code Context: Editing ${fileName} (${language})]`;
+     }
+
+     const headers = {
+         'Authorization': `Bearer ${API_KEY}`,
+         'Content-Type': 'application/json',
+         'User-Agent': 'Apollo-AI-VSCode-Extension/1.2.0'
+     };
+
+     const body = {
+         messages: messages,
+         temperature: forceMode ? 0.3 : temperature, // Lower temperature for force mode
+         max_tokens: forceMode ? 200 : maxTokens, // Much shorter responses for force mode
+         stream: false
+     };
+
+     for (let attempt = 1; attempt <= retries; attempt++) {
+         try {
+             const fetchImpl = await getFetch();
+
+             console.log(`🚀 Apollo AI: Attempt ${attempt}/${retries}, sending request to API...`);
+             console.log('📤 Request body:', JSON.stringify(body, null, 2));
+
+             const controller = new AbortController();
+             const timeoutId = setTimeout(() => controller.abort(), TIMEOUT_MS);
+
+             const res = await fetchImpl(API_URL, {
+                 method: 'POST',
+                 headers,
+                 body: JSON.stringify(body),
+                 signal: controller.signal
+             });
+
+             clearTimeout(timeoutId);
+
+             console.log('📨 Apollo AI: Received response, status:', res.status);
+
+             if (!res.ok) {
+                 const errorText = await res.text().catch(() => 'Unable to read error response');
+                 console.error(`❌ Apollo AI: API Error ${res.status}: ${errorText}`);
+
+                 if (res.status === 429) {
+                     throw new Error('⏱️ Rate limit exceeded. Please wait a moment and try again.');
+                 } else if (res.status === 401) {
+                     throw new Error('🔑 Authentication failed. Please check your API key.');
+                 } else if (res.status >= 500) {
+                     throw new Error('🔧 Server error. The AI service is temporarily unavailable.');
+                 }
+
+                 throw new Error(`API Error (${res.status}): ${res.statusText}`);
+             }
+
+             const json: APIResponse = await res.json();
+             console.log('📦 Apollo AI: Raw JSON response:', JSON.stringify(json, null, 2));
+
+             // ✅ FIXED: Extract response from proper JSON structure
+             let responseText = '';
+
+             // Handle the actual API response format
+             if (json.choices && json.choices[0] && json.choices[0].message) {
+                 responseText = json.choices[0].message.content || '';
+                 console.log('✅ Extracted content from JSON response:', responseText.substring(0, 100) + '...');
+             } else if (json.generated_text) {
+                 responseText = json.generated_text;
+                 console.log('✅ Extracted generated_text from response');
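The rendered diff ends mid-function above. For orientation, a minimal sketch of wiring the exported askAI into an extension command; the command ID 'apollo.ask' and the extension.ts entry point are assumptions, not part of this commit:

// extension.ts (sketch) - 'apollo.ask' is a hypothetical command ID.
import * as vscode from 'vscode';
import { askAI, ApolloAI } from './chat';

export function activate(context: vscode.ExtensionContext) {
    context.subscriptions.push(
        vscode.commands.registerCommand('apollo.ask', async () => {
            const prompt = await vscode.window.showInputBox({ prompt: 'Ask Apollo AI' });
            if (!prompt) {
                return;
            }
            const answer = await askAI(prompt, { includeContext: true });
            ApolloAI.addToHistory('user', prompt);
            ApolloAI.addToHistory('assistant', answer);
            vscode.window.showInformationMessage(answer);
        })
    );
}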