Ais committed
Commit 1e414fd · verified · 1 Parent(s): c455b22

Update app/main.py

Files changed (1)
  1. app/main.py +71 -53
app/main.py CHANGED
@@ -1,57 +1,75 @@
+import os
+import gdown
+import re
 import torch
+from fastapi import FastAPI, Request
+from pydantic import BaseModel
+from peft import PeftModel, PeftConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
-from peft import PeftModel
-import json
-import os
 
-# Load tokenizer and base model
-base_model = "Qwen/Qwen2-0.5B-Instruct"
-tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    base_model,
-    device_map="cuda" if torch.cuda.is_available() else "cpu",
-    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-    trust_remote_code=True
-)
-
-# Clean up adapter_config.json before loading adapter
-adapter_config_path = "./adapter/adapter_config.json"
-if os.path.exists(adapter_config_path):
-    with open(adapter_config_path, "r") as f:
-        adapter_config = json.load(f)
-    for key in ["corda_config", "eva_config", "megatron_config"]:
-        adapter_config.pop(key, None)
-    with open(adapter_config_path, "w") as f:
-        json.dump(adapter_config, f)
-
-# Load adapter
-model = PeftModel.from_pretrained(model, "./adapter", is_trainable=False)
-model.eval()
-
-# Simple chat function
-def chat(prompt):
-    messages = [
-        {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt}
-    ]
-    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        streamer=streamer
+app = FastAPI()
+
+DRIVE_FOLDER_URL = "https://drive.google.com/drive/folders/1S9xT92Zm9rZ4RSCxAe_DLld8vu78mqW4"
+LOCAL_ADAPTER_DIR = "adapter"
+BASE_MODEL = "Qwen/Qwen2-0.5B-Instruct"
+
+class PromptRequest(BaseModel):
+    prompt: str
+
+def download_latest_adapter():
+    print("🔽 Downloading adapter folder from Google Drive...")
+    gdown.download_folder(url=DRIVE_FOLDER_URL, output="gdrive_tmp", quiet=False, use_cookies=False)
+
+    all_versions = sorted(
+        [d for d in os.listdir("gdrive_tmp") if re.match(r"version \d+", d)],
+        key=lambda x: int(x.split()[-1])
     )
-    output = tokenizer.decode(generated_ids[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
-    return output
-
-# Example
-if __name__ == "__main__":
-    while True:
-        prompt = input("User: ")
-        if prompt.lower() in ["exit", "quit"]:
-            break
-        print("AI:", chat(prompt))
+    if not all_versions:
+        raise ValueError("❌ No version folders found in Google Drive folder.")
+
+    latest = all_versions[-1]
+    src = os.path.join("gdrive_tmp", latest)
+    print(f"✅ Latest adapter found: {latest}")
+
+    os.makedirs(LOCAL_ADAPTER_DIR, exist_ok=True)
+    for file in os.listdir(src):
+        src_file = os.path.join(src, file)
+        dest_file = os.path.join(LOCAL_ADAPTER_DIR, file)
+        os.system(f"cp '{src_file}' '{dest_file}'")
+
+    print(f"✅ Adapter copied to: {LOCAL_ADAPTER_DIR}")
+
+def load_model():
+    print("🚀 Loading base model...")
+    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, device_map="auto", torch_dtype=torch.float16)
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+
+    print("🔗 Loading adapter...")
+    model = PeftModel.from_pretrained(model, LOCAL_ADAPTER_DIR)
+    model.eval()
+
+    return model, tokenizer
+
+# Step 1: Download latest adapter
+download_latest_adapter()
+
+# Step 2: Load model and tokenizer
+model, tokenizer = load_model()
+
+@app.post("/generate")
+async def generate_text(request: PromptRequest):
+    prompt = request.prompt.strip()
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids,
+            max_new_tokens=300,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.95,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+    result = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
+    return {"response": result.strip()}