Ais committed
Update app/main.py
app/main.py  CHANGED  (+52 -79)
@@ -1,89 +1,62 @@
-import os
-import torch
-import gdown
-import re
-import shutil
-from fastapi import FastAPI, Request
 from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-#
-BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
-#
 app = FastAPI()
-
     prompt: str
-
-def download_latest_adapter():
-    gdown.download_folder(url=DRIVE_FOLDER_URL, output=TEMP_DIR, quiet=False, use_cookies=False)
-
-    all_versions = sorted(
-        [d for d in os.listdir(TEMP_DIR) if re.match(r"version \d+", d)],
-        key=lambda x: int(x.split()[-1])
-    )
-    if not all_versions:
-        raise ValueError("❌ No adapter versions found.")
-
-    latest = all_versions[-1]
-    src = os.path.join(TEMP_DIR, latest)
-
-    os.makedirs(ADAPTER_DIR, exist_ok=True)
-    for f in os.listdir(ADAPTER_DIR):
-        os.remove(os.path.join(ADAPTER_DIR, f))
-
-    for f in os.listdir(src):
-        shutil.copy(os.path.join(src, f), os.path.join(ADAPTER_DIR, f))
-
-    print(f"✅ Adapter '{latest}' copied to '{ADAPTER_DIR}'")
-
-# ====== LOAD MODEL ======
-def load_model():
-    print("🔧 Loading base model...")
-    base_model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        device_map="auto"
-    )
-
-    print("🔗 Loading LoRA adapter...")
-    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
-
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    return model, tokenizer
-
-# ====== RUN ======
-download_latest_adapter()
-model, tokenizer = load_model()

 @app.post("/chat")
-def chat(
-        top_p=0.9
-    )
-
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    response = response.replace(text, "").strip()
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+import os

+# === CONFIG ===
+HF_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+ADAPTER_PATH = "adapter"  # folder where your LoRA is saved
+API_KEY = os.getenv("API_KEY", "your-secret-key")  # Set in HF Space secrets

+# === FastAPI Setup ===
 app = FastAPI()

+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # adjust if needed
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# === Load Model & Tokenizer (CPU only) ===
+print("🔧 Loading model on CPU...")
+tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(HF_MODEL, torch_dtype=torch.float32, trust_remote_code=True)
+model = PeftModel.from_pretrained(model, ADAPTER_PATH)
+model = model.to("cpu")
+model.eval()
+print("✅ Model ready on CPU.")
+
+# === Request Schema ===
+class ChatRequest(BaseModel):
     prompt: str
+    api_key: str

+@app.get("/")
+def root():
+    return {"message": "✅ Qwen2.5 Chat API running."}

 @app.post("/chat")
+def chat(req: ChatRequest):
+    if req.api_key != API_KEY:
+        raise HTTPException(status_code=401, detail="Invalid API Key")
+
+    input_text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{req.prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
+    )
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

+    # Extract response after assistant tag
+    final_resp = response.split("<|im_start|>assistant\n")[-1].strip()
+    return {"response": final_resp}
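
For reference, a minimal client-side sketch for calling the updated /chat endpoint. The base URL below is a placeholder for the deployed Space (hypothetical), and the key must match the API_KEY secret configured on the server:

import requests

BASE_URL = "https://your-space.hf.space"  # hypothetical Space URL, replace with the real one

payload = {
    "prompt": "Hello, who are you?",
    "api_key": "your-secret-key",  # must match the server-side API_KEY secret
}

# POST the prompt and print the model's reply from the JSON response
resp = requests.post(f"{BASE_URL}/chat", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["response"])

One caveat on response handling: if the <|im_start|>/<|im_end|> markers are registered as special tokens in the Qwen tokenizer, skip_special_tokens=True strips them from the decoded string, so splitting on "<|im_start|>assistant\n" may return the whole conversation rather than just the reply. A common alternative, shown only as a sketch, is to decode just the newly generated tokens:

# Sketch: slice off the prompt tokens and decode only what the model generated
prompt_len = inputs["input_ids"].shape[1]
final_resp = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()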