Ais committed on
Commit 158ce9c · verified · 1 Parent(s): 90ddcea

Update app/main.py

Files changed (1)
  1. app/main.py +52 -79
app/main.py CHANGED
@@ -1,89 +1,62 @@
-# app/main.py
-import os
-import torch
-import gdown
-import re
-import shutil
-from fastapi import FastAPI, Request
 from pydantic import BaseModel
-from peft import PeftModel, PeftConfig
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-
-# ====== CONFIG ======
-DRIVE_FOLDER_URL = "https://drive.google.com/drive/folders/1S9xT92Zm9rZ4RSCxAe_DLld8vu78mqW4"
-ADAPTER_DIR = "adapter"
-TEMP_DIR = "gdrive_tmp"
-BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
-
-# ====== FASTAPI SETUP ======
 app = FastAPI()
-
-class Message(BaseModel):
     prompt: str
-
-# ====== DOWNLOAD LATEST ADAPTER ======
-def download_latest_adapter():
-    print("🔽 Downloading adapter folder from Google Drive...")
-    gdown.download_folder(url=DRIVE_FOLDER_URL, output=TEMP_DIR, quiet=False, use_cookies=False)
-
-    all_versions = sorted(
-        [d for d in os.listdir(TEMP_DIR) if re.match(r"version \d+", d)],
-        key=lambda x: int(x.split()[-1])
-    )
-    if not all_versions:
-        raise ValueError("❌ No adapter versions found.")
-
-    latest = all_versions[-1]
-    src = os.path.join(TEMP_DIR, latest)
-
-    os.makedirs(ADAPTER_DIR, exist_ok=True)
-    for f in os.listdir(ADAPTER_DIR):
-        os.remove(os.path.join(ADAPTER_DIR, f))
-
-    for f in os.listdir(src):
-        shutil.copy(os.path.join(src, f), os.path.join(ADAPTER_DIR, f))
-
-    print(f"✅ Adapter '{latest}' copied to '{ADAPTER_DIR}'")
-
-# ====== LOAD MODEL ======
-def load_model():
-    print("🔧 Loading base model...")
-    base_model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-        device_map="auto"
-    )
-
-    print("🔗 Loading LoRA adapter...")
-    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
-
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    return model, tokenizer
-
-# ====== RUN ======
-download_latest_adapter()
-model, tokenizer = load_model()
 
 @app.post("/chat")
-def chat(msg: Message):
-    prompt = msg.prompt.strip()
-
-    messages = [
-        {"role": "user", "content": prompt}
-    ]
-    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
-    inputs = tokenizer(text, return_tensors="pt").to(model.device)
-    with torch.no_grad():
-        output = model.generate(
-            **inputs,
-            max_new_tokens=512,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9
-        )
-
-    response = tokenizer.decode(output[0], skip_special_tokens=True)
-    response = response.replace(text, "").strip()
-
-    return {"response": response}
 
 
 
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import PeftModel
+import os
+
+# === CONFIG ===
+HF_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+ADAPTER_PATH = "adapter"  # folder where your LoRA is saved
+API_KEY = os.getenv("API_KEY", "your-secret-key")  # Set in HF Space secrets
+
+# === FastAPI Setup ===
 app = FastAPI()
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # adjust if needed
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# === Load Model & Tokenizer (CPU only) ===
+print("🔧 Loading model on CPU...")
+tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(HF_MODEL, torch_dtype=torch.float32, trust_remote_code=True)
+model = PeftModel.from_pretrained(model, ADAPTER_PATH)
+model = model.to("cpu")
+model.eval()
+print("✅ Model ready on CPU.")
+
+# === Request Schema ===
+class ChatRequest(BaseModel):
     prompt: str
+    api_key: str
+
+@app.get("/")
+def root():
+    return {"message": "✅ Qwen2.5 Chat API running."}
 
 @app.post("/chat")
+def chat(req: ChatRequest):
+    if req.api_key != API_KEY:
+        raise HTTPException(status_code=401, detail="Invalid API Key")
+
+    # Build a ChatML prompt by hand (Qwen2.5 chat format)
+    input_text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{req.prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        do_sample=True,
+        pad_token_id=tokenizer.eos_token_id
+    )
+
+    # Decode only the newly generated tokens; skip_special_tokens=True strips
+    # the <|im_start|> markers, so splitting the full decoded text on the
+    # assistant tag would return the whole prompt instead of just the reply.
+    gen_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
+    final_resp = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
+    return {"response": final_resp}
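With this change the API key travels in the JSON body rather than a header, so a client only needs a plain POST. Below is a minimal client sketch, not part of the commit: BASE_URL is a hypothetical address (replace it with your Space's URL) and the key must match the API_KEY configured in the HF Space secrets.

# client_example.py — hypothetical client for the /chat endpoint above
import requests

BASE_URL = "http://localhost:7860"  # hypothetical; use your Space URL
API_KEY = "your-secret-key"         # must match the server's API_KEY secret

resp = requests.post(
    f"{BASE_URL}/chat",
    json={"prompt": "Hello, who are you?", "api_key": API_KEY},
    timeout=120,  # CPU generation of up to 512 tokens can take a while
)
resp.raise_for_status()
print(resp.json()["response"])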