yuzhe commited on
Commit
ebb4c7f
·
verified ·
1 Parent(s): 74ee6ee

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -0
handler.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ class EndpointHandler:
4
+ def __init__(self, model_dir: str, **kw):
5
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
6
+ with init_empty_weights():
7
+ model = AutoModelForCausalLM.from_pretrained(
8
+ model_dir, torch_dtype="auto", trust_remote_code=True
9
+ )
10
+ self.model = load_checkpoint_and_dispatch(
11
+ model, checkpoint=model_dir, device_map="auto"
12
+ ) # 自动跨 GPU 切层
13
+ def __call__(self, data):
14
+ prompt = data["inputs"]
15
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
16
+ out_ids = self.model.generate(**inputs, max_new_tokens=256)
17
+ return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}