import os
import torch
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from starlette.middleware.cors import CORSMiddleware
# === Setup FastAPI ===
app = FastAPI()
# === CORS for frontend testing (optional) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
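# Note: wildcard origins combined with allow_credentials=True is very permissive;
# restrict allow_origins before exposing this beyond local testing.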
# === Load Secret API Key from Hugging Face Secrets ===
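# Note: if the API_KEY secret is unset, the fallback means a request bearing the
# literal token "undefined" would authenticate; always set the secret in production.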
API_KEY = os.getenv("API_KEY", "undefined")
# === Load Model and Adapter (CPU only) ===
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"
print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32,  # CPU only
).cpu()
print("🔗 Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
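# Optional (an assumption, not required here): merging the adapter into the base
# weights with `model = model.merge_and_unload()` can speed up CPU inference,
# at the cost of no longer being able to swap adapters.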
model.eval()
print("✅ Model and adapter loaded.")
# === Root route for test ===
@app.get("/")
def read_root():
return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
# === POST /v1/chat/completions (OpenAI-style) ===
@app.post("/v1/chat/completions")
async def chat(request: Request):
    # ✅ Check API key from headers
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
    token = auth.replace("Bearer ", "").strip()
    if token != API_KEY:
        return JSONResponse(status_code=401, content={"error": "Invalid API key."})

    # ✅ Parse the latest user message
    body = await request.json()
    messages = body.get("messages", [])
    if not messages or not isinstance(messages, list):
        return JSONResponse(status_code=400, content={"error": "No messages provided."})
    user_prompt = messages[-1].get("content", "")
    # ✅ Format prompt for Qwen chat model
    prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

    # ✅ Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens; skip_special_tokens=True strips the
    # ChatML markers, so splitting the full decode on "<|im_start|>" would fail.
    prompt_len = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    # ✅ Return in OpenAI-style format
    return {
        "id": "chatcmpl-custom-001",
        "object": "chat.completion",
        "model": "Qwen2-0.5B-Instruct-LoRA",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": answer,
                },
                "finish_reason": "stop",
            }
        ],
    }
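
# Example request against the running server (hypothetical host/port; use your own key):
#   curl -X POST http://localhost:7860/v1/chat/completions \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello!"}]}'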
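
# Minimal local entrypoint (a sketch: Hugging Face Spaces usually launches the app
# via its own start command, so this only matters when running on your own machine).
if __name__ == "__main__":
    import uvicorn  # assumption: uvicorn is installed alongside fastapi

    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the conventional Spaces port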