yuzhe committed on
Commit
77c95fb
·
verified ·
1 Parent(s): ebb4c7f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -4
handler.py CHANGED
@@ -11,7 +11,19 @@ class EndpointHandler:
11
  model, checkpoint=model_dir, device_map="auto"
12
  ) # 自动跨 GPU 切层
13
  def __call__(self, data):
14
- prompt = data["inputs"]
15
- inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
16
- out_ids = self.model.generate(**inputs, max_new_tokens=256)
17
- return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  model, checkpoint=model_dir, device_map="auto"
12
  ) # 自动跨 GPU 切层
13
  def __call__(self, data):
14
+ prompt = data["inputs"]
15
+
16
+ inputs = self.tokenizer(
17
+ prompt, return_tensors="pt"
18
+ ).to("cuda:0") # 👈 把 input_ids/attention_mask 都放到 0 号卡
19
+
20
+ out_ids = self.model.generate(
21
+ **inputs,
22
+ max_new_tokens=256,
23
+ )
24
+ return {
25
+ "generated_text": self.tokenizer.decode(
26
+ out_ids[0], skip_special_tokens=True
27
+ )
28
+ }
29
+