import os

import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware

app = FastAPI()
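
# CORS is wide open here (all origins, methods, and headers); convenient for
# local development, but worth tightening before exposing the API publicly.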
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
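
# Expected bearer token, read from the environment. Note that the "undefined"
# fallback means a request sending "Bearer undefined" is accepted when the
# API_KEY variable is not set.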
API_KEY = os.getenv("API_KEY", "undefined")
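
# Base model from the Hugging Face Hub and the local LoRA adapter directory.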
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"

print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
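
# Load the base model in full precision (float32) and keep it on the CPU.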
print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32
).cpu()
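
# Wrap the base model with the fine-tuned LoRA weights and switch to eval mode.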
print("🔗 Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()

print("✅ Model and adapter loaded successfully.")
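
# Simple health-check endpoint.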
@app.get("/")
def root():
    return {"message": "🧠 Qwen2-0.5B-Instruct API is running on CPU!"}
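
# OpenAI-style /v1/chat/completions endpoint backed by the local model.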
@app.post("/v1/chat/completions")
async def chat(request: Request):
    auth_header = request.headers.get("Authorization", "")
    if not auth_header.startswith("Bearer "):
        return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})

    token = auth_header.replace("Bearer ", "").strip()
    if token != API_KEY:
        return JSONResponse(status_code=401, content={"error": "Invalid API key."})
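
    # Parse and validate the request body (OpenAI chat.completions field names).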
    try:
        body = await request.json()
        messages = body.get("messages", [])
        if not messages or not isinstance(messages, list):
            raise ValueError("Invalid or missing 'messages' field.")

        temperature = body.get("temperature", 0.7)
        max_tokens = body.get("max_tokens", 512)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
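
    # Keep only the last few turns so the prompt stays small for CPU inference.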
    recent_messages = messages[-4:] if len(messages) > 4 else messages
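
    # Build a ChatML prompt, the chat format Qwen instruct models expect.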
    formatted_prompt = ""
    for message in recent_messages:
        role = message.get("role", "")
        content = message.get("content", "")

        if role == "system":
            formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"

    # Open the assistant turn that the model will complete.
    formatted_prompt += "<|im_start|>assistant\n"

    print(f"🤖 Processing {len(recent_messages)} recent messages")

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
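
    # Sample a completion; pad and eos are both pinned to the tokenizer's EOS token.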
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the tokens generated after the prompt, so the reply is not
    # prefixed with an echo of the conversation. (skip_special_tokens=True also
    # strips the <|im_start|>/<|im_end|> markers, so splitting the full decode
    # on them would not isolate the answer.)
    prompt_length = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    final_answer = decoded.strip()
    if "<|im_end|>" in final_answer:
        final_answer = final_answer.split("<|im_end|>")[0].strip()
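
    # Heuristic cleanup: drop fragments of the system prompt or editor context
    # ("Guidelines:", "Response format:", "[VS Code Context:") that the small
    # model sometimes repeats in its answer.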
    if "Guidelines:" in final_answer:
        final_answer = final_answer.split("Guidelines:")[0].strip()

    if "Response format:" in final_answer:
        final_answer = final_answer.split("Response format:")[0].strip()

    if "[VS Code Context:" in final_answer:
        lines = final_answer.split("\n")
        cleaned_lines = [line for line in lines if not line.strip().startswith("[VS Code Context:")]
        final_answer = "\n".join(cleaned_lines).strip()

    print(f"✅ Clean response: {final_answer[:100]}...")
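
    # Return an OpenAI-compatible chat.completion payload.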
    return {
        "id": "chatcmpl-local-001",
        "object": "chat.completion",
        "model": "Qwen2-0.5B-Instruct-LoRA",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": final_answer
                },
                "finish_reason": "stop"
            }
        ]
    }