Ais committed on
Commit
6df15e3
·
verified ·
1 Parent(s): 158ce9c

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +88 -46
app/main.py CHANGED
@@ -1,62 +1,104 @@
1
- from fastapi import FastAPI, Request, HTTPException
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
  import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
6
  from peft import PeftModel
7
- import os
8
-
9
# === CONFIG ===
HF_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
ADAPTER_PATH = "adapter"  # folder where your LoRA is saved
API_KEY = os.getenv("API_KEY", "your-secret-key")  # Set in HF Space secrets

# === FastAPI Setup ===
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust if needed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Load Model & Tokenizer (CPU only) ===
print("🔧 Loading model on CPU...")
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    HF_MODEL, torch_dtype=torch.float32, trust_remote_code=True
)
# Attach the LoRA adapter on top of the frozen base weights, then pin to CPU.
model = PeftModel.from_pretrained(base, ADAPTER_PATH).to("cpu")
model.eval()
print("✅ Model ready on CPU.")
33
 
34
# === Request Schema ===
class ChatRequest(BaseModel):
    """Body schema for /chat: the user's prompt plus the caller's API key."""

    prompt: str
    api_key: str
38
 
 
 
39
@app.get("/")
def root():
    """Simple liveness message for the service root."""
    reply = {"message": " Qwen2.5 Chat API running."}
    return reply
42
-
43
@app.post("/chat")
def chat(req: ChatRequest):
    """Generate a chat completion for ``req.prompt``.

    Auth is a shared secret carried in the request body (``req.api_key``).
    NOTE(review): prefer an Authorization header so the key is not stored in
    request-body logs.
    """
    if req.api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")

    # ChatML-style prompt for Qwen instruct models
    input_text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{req.prompt}<|im_end|>\n<|im_start|>assistant\n"

    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    # no_grad: pure inference — avoids building an autograd graph
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # BUGFIX: decoding with skip_special_tokens=True strips the "<|im_start|>"
    # markers, so splitting on "<|im_start|>assistant\n" never matched and the
    # entire prompt was echoed back to the caller. Decode only the tokens
    # generated after the prompt instead.
    prompt_len = inputs["input_ids"].shape[-1]
    final_resp = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return {"response": final_resp}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
 
 
2
  import torch
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.responses import JSONResponse
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
  from peft import PeftModel
7
+ from starlette.middleware.cors import CORSMiddleware
 
 
 
 
 
8
 
9
# === Setup FastAPI ===
app = FastAPI()

# === CORS for frontend testing (optional) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide open — tighten before production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# === Load Secret API Key from Hugging Face Secrets ===
# SECURITY FIX: the previous fallback os.getenv("API_KEY", "undefined") meant
# that when the secret was not configured, any client sending
# "Authorization: Bearer undefined" authenticated successfully. With no
# default, API_KEY is None and no Bearer token can ever match it (fail closed).
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    print("⚠️ API_KEY secret is not set; all authenticated requests will be rejected.")

# === Load Model and Adapter (CPU only) ===
# NOTE(review): base id says "Qwen2" while user-facing strings say "Qwen2.5" —
# confirm which checkpoint the LoRA adapter was actually trained against.
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
ADAPTER_PATH = "adapter"  # local folder containing the LoRA weights

print("🔧 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

print("🧠 Loading base model on CPU...")
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    trust_remote_code=True,
    torch_dtype=torch.float32  # CPU only: fp32 avoids unsupported half-precision ops
).cpu()

print("🔗 Applying LoRA adapter...")
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()  # inference mode: disables dropout etc.

print("✅ Model and adapter loaded.")
 
 
 
43
 
44
+
45
# === Root route for test ===
@app.get("/")
def read_root():
    """Health-check endpoint confirming the API process is up."""
    status = {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
    return status
49
+
50
+
51
# === POST /v1/chat/completions (OpenAI-style) ===
@app.post("/v1/chat/completions")
async def chat(request: Request):
    """OpenAI-compatible chat-completions endpoint.

    Expects an ``Authorization: Bearer <key>`` header and a JSON body with a
    ``messages`` list (OpenAI chat format); the last message's ``content`` is
    used as the user prompt. Returns a ``chat.completion``-shaped JSON object.
    """
    # ✅ Check API key from headers
    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})

    # FIX: strip only the leading prefix; str.replace("Bearer ", "") would
    # also mangle a token that happens to contain the substring "Bearer ".
    token = auth[len("Bearer "):].strip()
    if token != API_KEY:
        return JSONResponse(status_code=401, content={"error": "Invalid API key."})

    # ✅ Parse user prompt — guard against a non-JSON body, which would
    # otherwise surface to the client as an unhandled 500.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON."})

    messages = body.get("messages", [])
    if not messages or not isinstance(messages, list):
        return JSONResponse(status_code=400, content={"error": "No messages provided."})

    # FIX: validate shape before indexing — a malformed last message
    # previously raised KeyError/TypeError (HTTP 500).
    last = messages[-1]
    if not isinstance(last, dict) or "content" not in last:
        return JSONResponse(status_code=400, content={"error": "Last message must contain a 'content' field."})
    user_prompt = last["content"]

    # ✅ Format prompt for Qwen chat model (ChatML template)
    prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

    # ✅ Generate (no_grad: inference only, no autograd graph)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # BUGFIX: decoding with skip_special_tokens=True removes the
    # "<|im_start|>" markers, so splitting on "<|im_start|>assistant\n" never
    # matched and the whole prompt was echoed back to the caller. Decode only
    # the newly generated tokens instead.
    prompt_len = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    # ✅ Return in OpenAI-style format
    return {
        "id": "chatcmpl-custom-001",
        "object": "chat.completion",
        "model": "Qwen2.5-0.5B-Instruct-LoRA",
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": answer
                },
                "finish_reason": "stop"
            }
        ]
    }