import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware
# === Setup FastAPI ===
app = FastAPI()
# === CORS for frontend testing (optional) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Load Secret API Key from Hugging Face Secrets ===
API_KEY = os.getenv("API_KEY", "undefined")
# === Load Model and Adapter (CPU only) ===
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"
print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32  # CPU only
).cpu()
print("🔗 Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()
print("✅ Model and adapter loaded.")
# === Root route for test ===
@app.get("/")
def read_root():
    return {"message": "🧠 Qwen2-0.5B-Instruct API is running on CPU!"}
# === POST /v1/chat/completions (OpenAI-style) ===
@app.post("/v1/chat/completions")
async def chat(request: Request):
    # ✅ Check API key from headers
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
    token = auth.replace("Bearer ", "").strip()
    if token != API_KEY:
        return JSONResponse(status_code=401, content={"error": "Invalid API key."})

    # ✅ Parse user prompt
    body = await request.json()
    messages = body.get("messages", [])
    if not messages or not isinstance(messages, list):
        return JSONResponse(status_code=400, content={"error": "No messages provided."})
    user_prompt = messages[-1].get("content", "")

    # ✅ Format prompt for the Qwen chat model (ChatML format)
    prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

    # ✅ Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # ✅ Decode only the newly generated tokens (slice off the prompt tokens;
    # splitting on "<|im_start|>assistant" would fail since special tokens are skipped)
    generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # ✅ Return in OpenAI-style format
    return {
        "id": "chatcmpl-custom-001",
        "object": "chat.completion",
        "model": "Qwen2-0.5B-Instruct-LoRA",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": answer
                },
                "finish_reason": "stop"
            }
        ]
    }
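
# --- Example client call (sketch, not part of the app) ---
# Assumes the app is served with uvicorn on http://localhost:7860 (the usual
# port for a Hugging Face Space; adjust the URL to your deployment) and that
# the key passed below matches the API_KEY secret configured on the server.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       headers={"Authorization": "Bearer YOUR_API_KEY"},
#       json={"messages": [{"role": "user", "content": "Hello!"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])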