yuzhe committed on
Commit
668e04f
·
verified ·
1 Parent(s): a55dc79

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -17
handler.py CHANGED
@@ -43,22 +43,21 @@ class EndpointHandler:
43
  )
44
 
45
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
    """Run text generation for one inference request.

    Expected payload::

        {"inputs": "your prompt here"}

    Args:
        data: Request payload; must contain an ``"inputs"`` key holding
            the prompt string.

    Returns:
        A dict with a single ``"generated_text"`` key containing the
        decoded model output.

    Raises:
        KeyError: If ``"inputs"`` is missing from the payload.
    """
    prompt = data["inputs"]

    # Derive the target device from the model's own parameters instead of
    # hard-coding "cuda:0": with multi-GPU sharding (device_map) the
    # embedding layer may not live on cuda:0, and on a CPU-only host the
    # hard-coded device would crash outright.
    first_device = next(self.model.parameters()).device
    inputs = self.tokenizer(prompt, return_tensors="pt").to(first_device)

    # inference_mode disables autograd bookkeeping for faster generation.
    with torch.inference_mode():
        output_ids = self.model.generate(**inputs, **self.generation_kwargs)

    generated_text = self.tokenizer.decode(
        output_ids[0], skip_special_tokens=True
    )
    return {"generated_text": generated_text}
 
 
 
 
 
43
  )
44
 
45
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
    """Generate text for the prompt found under ``data["inputs"]``.

    The tokenized input is moved to whichever device holds the model's
    first parameter (the embedding layer under device_map sharding),
    then generation runs inside ``torch.inference_mode``.
    """
    prompt = data["inputs"]

    # Automatically pick up the GPU (or CPU) hosting the embedding layer.
    target_device = next(self.model.parameters()).device
    encoded = self.tokenizer(prompt, return_tensors="pt").to(target_device)

    # Generation — logic otherwise unchanged.
    with torch.inference_mode():
        output_ids = self.model.generate(**encoded, **self.generation_kwargs)

    decoded = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": decoded}