# app/main.py
import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware
# === Setup FastAPI ===
app = FastAPI()
# === CORS (optional for frontend access) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
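# Note: allow_origins=["*"] combined with allow_credentials=True is very permissive;
# for a production deployment, listing the specific frontend origins is safer.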
# === Load API Key from Hugging Face Secrets ===
API_KEY = os.getenv("API_KEY", "undefined") # Add API_KEY in your HF Space Secrets
# === Model Settings ===
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"
print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32
).cpu()
print("🔗 Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()
print("✅ Model and adapter loaded successfully.")
# === Root Route ===
@app.get("/")
def root():
return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
# === Chat Completion API ===
@app.post("/v1/chat/completions")
async def chat(request: Request):
# ✅ API Key Authorization
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
token = auth_header.replace("Bearer ", "").strip()
if token != API_KEY:
return JSONResponse(status_code=401, content={"error": "Invalid API key."})
    # ✅ Parse Request
    try:
        body = await request.json()
        messages = body.get("messages", [])
        if not messages or not isinstance(messages, list):
            raise ValueError("Invalid or missing 'messages' field.")
        temperature = body.get("temperature", 0.7)
        max_tokens = body.get("max_tokens", 512)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
    # ✅ Build the full conversation prompt with history (ChatML format used by Qwen)
    formatted_prompt = ""
    for message in messages:
        role = message.get("role", "")
        content = message.get("content", "")
        if role == "system":
            formatted_prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
        elif role == "user":
            formatted_prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
        elif role == "assistant":
            formatted_prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
    # Add the assistant start tag so the model generates the next reply
    formatted_prompt += "<|im_start|>assistant\n"
print(f"🤖 Processing conversation with {len(messages)} messages")
print(f"📝 Full prompt preview: {formatted_prompt[:200]}...")
inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
# ✅ Generate Response
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=0.9,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
    # ✅ Decode only the newly generated tokens. (Decoding the full sequence with
    # skip_special_tokens=True strips the ChatML markers, so splitting on
    # "<|im_start|>assistant" would return the whole conversation, not just the new reply.)
    prompt_length = inputs["input_ids"].shape[1]
    final_answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    # Remove any residual end tags or artifacts
    if "<|im_end|>" in final_answer:
        final_answer = final_answer.split("<|im_end|>")[0].strip()
    print(f"✅ Generated response: {final_answer[:100]}...")
    # ✅ OpenAI-style Response
    return {
        "id": "chatcmpl-local-001",
        "object": "chat.completion",
        "model": "Qwen2-0.5B-Instruct-LoRA",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": final_answer
                },
                "finish_reason": "stop"
            }
        ]
    }
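# --- Example request (illustrative sketch; assumes the app is served on port 7860,
# --- the Hugging Face Spaces default, and that API_KEY is set to "my-secret-key") ---
#
# curl -X POST http://localhost:7860/v1/chat/completions \
#   -H "Authorization: Bearer my-secret-key" \
#   -H "Content-Type: application/json" \
#   -d '{
#         "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"}
#         ],
#         "temperature": 0.7,
#         "max_tokens": 256
#       }'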