Ais committed on
Commit 0f91c7c · verified · 1 Parent(s): 7e53eac

Update app/main.py

Files changed (1)
  1. app/main.py +45 -33
app/main.py CHANGED
@@ -1,45 +1,57 @@
- from fastapi import FastAPI
- from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForCausalLM
- from peft import PeftModel
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+ from peft import PeftModel
+ import json
+ import os
  
- app = FastAPI()
- 
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
- tokenizer.pad_token = tokenizer.eos_token
- 
+ # Load tokenizer and base model
+ base_model = "Qwen/Qwen2-0.5B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
  model = AutoModelForCausalLM.from_pretrained(
-     "Qwen/Qwen2.5-0.5B-Instruct",
-     torch_dtype=torch.float32,
+     base_model,
+     device_map="cuda" if torch.cuda.is_available() else "cpu",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
      trust_remote_code=True
  )
+ 
+ # Clean up adapter_config.json before loading adapter
+ adapter_config_path = "./adapter/adapter_config.json"
+ if os.path.exists(adapter_config_path):
+     with open(adapter_config_path, "r") as f:
+         adapter_config = json.load(f)
+     for key in ["corda_config", "eva_config", "megatron_config"]:
+         adapter_config.pop(key, None)
+     with open(adapter_config_path, "w") as f:
+         json.dump(adapter_config, f)
+ 
+ # Load adapter
  model = PeftModel.from_pretrained(model, "./adapter", is_trainable=False)
  model.eval()
  
- def build_prompt(messages):
-     prompt = ""
-     for msg in messages:
-         role = "User" if msg["role"] == "user" else "Assistant"
-         prompt += f"### {role}:\n{msg['content']}\n"
-     prompt += "### Assistant:\n"
-     return prompt
- 
- class ChatRequest(BaseModel):
-     messages: list # [{"role": "user", "content": "..."}]
- 
- @app.post("/chat")
- async def chat(req: ChatRequest):
-     prompt = build_prompt(req.messages)
-     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=256,
+ # Simple chat function
+ def chat(prompt):
+     messages = [
+         {"role": "system", "content": "You are a helpful assistant."},
+         {"role": "user", "content": prompt}
+     ]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+ 
+     streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     generated_ids = model.generate(
+         **model_inputs,
+         max_new_tokens=512,
          do_sample=True,
          temperature=0.7,
-         top_p=0.95,
-         eos_token_id=tokenizer.eos_token_id
+         streamer=streamer
      )
-     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-     reply = output_text.split("### Assistant:")[-1].strip()
-     return {"response": reply}
+     output = tokenizer.decode(generated_ids[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
+     return output
+ 
+ # Example
+ if __name__ == "__main__":
+     while True:
+         prompt = input("User: ")
+         if prompt.lower() in ["exit", "quit"]:
+             break
+         print("AI:", chat(prompt))
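
A minimal smoke test for the updated script (a sketch, not part of this commit; it assumes app/ is an importable package with an __init__.py and that ./adapter exists relative to the working directory):

# Hypothetical usage sketch: importing app.main loads the base model and
# the LoRA adapter at import time, so the first call below is slow.
from app.main import chat

reply = chat("Summarize what a LoRA adapter does in one sentence.")
print(reply)

Because chat() wires a TextStreamer into model.generate, the reply streams to stdout as it is generated; the decoded text is also returned, so this sketch prints it a second time at the end.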