Ais committed on
Commit 90ddcea · verified · 1 Parent(s): 363fc6a

Update app/main.py

Files changed (1)
  1. app/main.py +78 -69
app/main.py CHANGED
@@ -1,80 +1,89 @@
  # app/main.py
- from fastapi import FastAPI, Form
- from fastapi.responses import HTMLResponse
- from fastapi.middleware.cors import CORSMiddleware
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from peft import PeftModel
- import torch
  import os

- from app.download_adapter import download_latest_adapter

- # === Step 1: Download Adapter ===
- download_latest_adapter()

- # === Step 2: Load Model and Tokenizer ===
- BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
- ADAPTER_FOLDER = "adapter"
- HF_TOKEN = os.environ.get("HF_TOKEN", None)

- print("🚀 Loading base model...")
- base_model = AutoModelForCausalLM.from_pretrained(
-     BASE_MODEL,
-     torch_dtype=torch.float16,
-     device_map="auto",
-     token=HF_TOKEN,
-     trust_remote_code=True
- )

- print("🔧 Applying LoRA adapter...")
- model = PeftModel.from_pretrained(base_model, ADAPTER_FOLDER)

- print("🧠 Loading tokenizer...")
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

- # === Step 3: FastAPI App ===
- app = FastAPI()

- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # Allow all origins for testing
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- @app.get("/", response_class=HTMLResponse)
- async def form():
-     return """
-     <html>
-       <head><title>Qwen Chat</title></head>
-       <body>
-         <h2>Ask something:</h2>
-         <form method="post">
-           <textarea name="prompt" rows="4" cols="60"></textarea><br>
-           <input type="submit" value="Generate">
-         </form>
-       </body>
-     </html>
-     """
-
- @app.post("/", response_class=HTMLResponse)
- async def generate(prompt: str = Form(...)):
-     full_prompt = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
-     output = pipe(full_prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
-     response = output[0]["generated_text"].split("<|im_start|>assistant\n")[-1].strip()
-
-     return f"""
-     <html>
-       <head><title>Qwen Chat</title></head>
-       <body>
-         <h2>Your Prompt:</h2>
-         <p>{prompt}</p>
-         <h2>Response:</h2>
-         <p>{response}</p>
-         <a href="/">Ask again</a>
-       </body>
-     </html>
-     """
  # app/main.py
  import os
+ import torch
+ import gdown
+ import re
+ import shutil
+ from fastapi import FastAPI, Request
+ from pydantic import BaseModel
+ from peft import PeftModel, PeftConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

+ # ====== CONFIG ======
+ DRIVE_FOLDER_URL = "https://drive.google.com/drive/folders/1S9xT92Zm9rZ4RSCxAe_DLld8vu78mqW4"
+ ADAPTER_DIR = "adapter"
+ TEMP_DIR = "gdrive_tmp"
+ BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

+ # ====== FASTAPI SETUP ======
+ app = FastAPI()

+ class Message(BaseModel):
+     prompt: str

+ # ====== DOWNLOAD LATEST ADAPTER ======
+ def download_latest_adapter():
+     print("🔽 Downloading adapter folder from Google Drive...")
+     gdown.download_folder(url=DRIVE_FOLDER_URL, output=TEMP_DIR, quiet=False, use_cookies=False)

+     all_versions = sorted(
+         [d for d in os.listdir(TEMP_DIR) if re.match(r"version \d+", d)],
+         key=lambda x: int(x.split()[-1])
+     )
+     if not all_versions:
+         raise ValueError("❌ No adapter versions found.")

+     latest = all_versions[-1]
+     src = os.path.join(TEMP_DIR, latest)

+     os.makedirs(ADAPTER_DIR, exist_ok=True)
+     for f in os.listdir(ADAPTER_DIR):
+         os.remove(os.path.join(ADAPTER_DIR, f))

+     for f in os.listdir(src):
+         shutil.copy(os.path.join(src, f), os.path.join(ADAPTER_DIR, f))
+
+     print(f"✅ Adapter '{latest}' copied to '{ADAPTER_DIR}'")
+
+ # ====== LOAD MODEL ======
+ def load_model():
+     print("🔧 Loading base model...")
+     base_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto"
+     )
+
+     print("🔗 Loading LoRA adapter...")
+     model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
+
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+     return model, tokenizer
+
+ # ====== RUN ======
+ download_latest_adapter()
+ model, tokenizer = load_model()
+
+ @app.post("/chat")
+ def chat(msg: Message):
+     prompt = msg.prompt.strip()
+
+     messages = [
+         {"role": "user", "content": prompt}
+     ]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     inputs = tokenizer(text, return_tensors="pt").to(model.device)
+     with torch.no_grad():
+         output = model.generate(
+             **inputs,
+             max_new_tokens=512,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9
+         )
+
+     response = tokenizer.decode(output[0], skip_special_tokens=True)
+     response = response.replace(text, "").strip()

+     return {"response": response}
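For reference, a minimal client-side sketch of calling the new /chat endpoint added by this commit. The host, port, uvicorn invocation, and prompt text below are illustrative assumptions, not part of the commit; only the route, the Message model's "prompt" field, and the {"response": ...} return shape come from the diff above.

import requests

# Assumes the app is served locally, e.g.:
#     uvicorn app.main:app --host 0.0.0.0 --port 8000
# URL and port are placeholders; adjust to your deployment.
resp = requests.post(
    "http://localhost:8000/chat",
    json={"prompt": "Summarize what a LoRA adapter does."},  # field name matches the Message model
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])  # the endpoint returns {"response": ...}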