Ais committed on
Commit
4ca2587
·
verified ·
1 Parent(s): 6df15e3

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +38 -34
app/main.py CHANGED
@@ -9,7 +9,7 @@ from starlette.middleware.cors import CORSMiddleware
9
  # === Setup FastAPI ===
10
  app = FastAPI()
11
 
12
- # === CORS for frontend testing (optional) ===
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
@@ -18,10 +18,10 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
21
- # === Load Secret API Key from Hugging Face Secrets ===
22
- API_KEY = os.getenv("API_KEY", "undefined")
23
 
24
- # === Load Model and Adapter (CPU only) ===
25
  BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
26
  ADAPTER_PATH = "adapter"
27
 
@@ -32,47 +32,51 @@ print("🧠 Loading base model on CPU...")
32
  base_model = AutoModelForCausalLM.from_pretrained(
33
  BASE_MODEL,
34
  trust_remote_code=True,
35
- torch_dtype=torch.float32 # CPU only
36
  ).cpu()
37
 
38
  print("🔗 Applying LoRA adapter...")
39
  model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
40
  model.eval()
41
 
42
- print("✅ Model and adapter loaded.")
43
 
44
-
45
- # === Root route for test ===
46
  @app.get("/")
47
- def read_root():
48
  return {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
49
 
50
-
51
- # === POST /v1/chat/completions (OpenAI-style) ===
52
  @app.post("/v1/chat/completions")
53
  async def chat(request: Request):
54
- # ✅ Check API key from headers
55
- auth = request.headers.get("Authorization", "")
56
- if not auth.startswith("Bearer "):
57
  return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
58
 
59
- token = auth.replace("Bearer ", "").strip()
60
  if token != API_KEY:
61
  return JSONResponse(status_code=401, content={"error": "Invalid API key."})
62
 
63
- # ✅ Parse user prompt
64
- body = await request.json()
65
- messages = body.get("messages", [])
66
- if not messages or not isinstance(messages, list):
67
- return JSONResponse(status_code=400, content={"error": "No messages provided."})
68
-
69
- user_prompt = messages[-1]["content"]
70
-
71
- # Format prompt for Qwen chat model
72
- prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
73
- inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
74
-
75
- # Generate
 
 
 
 
 
 
76
  with torch.no_grad():
77
  outputs = model.generate(
78
  **inputs,
@@ -83,12 +87,12 @@ async def chat(request: Request):
83
  pad_token_id=tokenizer.eos_token_id
84
  )
85
 
86
- full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
87
- answer = full_output.split("<|im_start|>assistant\n")[-1].strip()
88
 
89
- # ✅ Return in OpenAI-style format
90
  return {
91
- "id": "chatcmpl-custom-001",
92
  "object": "chat.completion",
93
  "model": "Qwen2.5-0.5B-Instruct-LoRA",
94
  "choices": [
@@ -96,9 +100,9 @@ async def chat(request: Request):
96
  "index": 0,
97
  "message": {
98
  "role": "assistant",
99
- "content": answer
100
  },
101
  "finish_reason": "stop"
102
  }
103
  ]
104
- }
 
9
  # === Setup FastAPI ===
10
  app = FastAPI()
11
 
12
+ # === CORS (optional for frontend access) ===
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
 
18
  allow_headers=["*"],
19
  )
20
 
21
# === Load API Key from Hugging Face Secrets ===
# NOTE(review): set API_KEY in the Space's Secrets; falls back to the literal
# string "undefined" when the variable is absent.
API_KEY = os.getenv("API_KEY", "undefined")  # Add API_KEY in your HF Space Secrets

# === Model Settings ===
BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"  # base checkpoint pulled from the Hub
ADAPTER_PATH = "adapter"                 # local directory holding the LoRA adapter
27
 
 
32
# Load the base checkpoint on CPU. float32 is used for the weights; the
# model is pinned to CPU explicitly via .cpu().
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float32,
    trust_remote_code=True,
).cpu()

print("🔗 Applying LoRA adapter...")
# Wrap the base model with the LoRA weights from ADAPTER_PATH and switch to
# inference mode (disables dropout etc.).
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH).cpu()
model.eval()

print("✅ Model and adapter loaded successfully.")
43
 
44
# === Root Route ===
@app.get("/")
def root():
    """Liveness check: confirms the API process is up and serving."""
    payload = {"message": "🧠 Qwen2.5-0.5B-Instruct API is running on CPU!"}
    return payload
48
 
49
+ # === Chat Completion API ===
 
50
  @app.post("/v1/chat/completions")
51
  async def chat(request: Request):
52
+ # ✅ API Key Authorization
53
+ auth_header = request.headers.get("Authorization", "")
54
+ if not auth_header.startswith("Bearer "):
55
  return JSONResponse(status_code=401, content={"error": "Missing Bearer token in Authorization header."})
56
 
57
+ token = auth_header.replace("Bearer ", "").strip()
58
  if token != API_KEY:
59
  return JSONResponse(status_code=401, content={"error": "Invalid API key."})
60
 
61
+ # ✅ Parse Request
62
+ try:
63
+ body = await request.json()
64
+ messages = body.get("messages", [])
65
+ if not messages or not isinstance(messages, list):
66
+ raise ValueError("Invalid or missing 'messages' field.")
67
+ user_prompt = messages[-1]["content"]
68
+ except Exception as e:
69
+ return JSONResponse(status_code=400, content={"error": f"Bad request: {str(e)}"})
70
+
71
+ # Format Prompt for Qwen
72
+ formatted_prompt = (
73
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
74
+ f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
75
+ "<|im_start|>assistant\n"
76
+ )
77
+ inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
78
+
79
+ # ✅ Generate Response
80
  with torch.no_grad():
81
  outputs = model.generate(
82
  **inputs,
 
87
  pad_token_id=tokenizer.eos_token_id
88
  )
89
 
90
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
91
+ final_answer = decoded.split("<|im_start|>assistant\n")[-1].strip()
92
 
93
+ # ✅ OpenAI-style Response
94
  return {
95
+ "id": "chatcmpl-local-001",
96
  "object": "chat.completion",
97
  "model": "Qwen2.5-0.5B-Instruct-LoRA",
98
  "choices": [
 
100
  "index": 0,
101
  "message": {
102
  "role": "assistant",
103
+ "content": final_answer
104
  },
105
  "finish_reason": "stop"
106
  }
107
  ]
108
+ }