Update handler.py (+16 −17): move tokenized inputs onto the device that holds the model's embedding layer before calling generate.
handler.py
CHANGED
@@ -43,22 +43,21 @@ class EndpointHandler:
|
|
43 |
)
|
44 |
|
45 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
|
46 |
-
"""
|
47 |
-
data 格式:
|
48 |
-
{
|
49 |
-
"inputs": "your prompt here"
|
50 |
-
}
|
51 |
-
"""
|
52 |
prompt = data["inputs"]
|
53 |
-
|
54 |
-
#
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
58 |
with torch.inference_mode():
|
59 |
-
output_ids = self.model.generate(
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
return {
|
|
|
|
|
|
|
|
|
|
43 |
)
|
44 |
|
45 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
    """Handle one inference request.

    Expected request payload::

        {"inputs": "your prompt here"}

    Returns a dict with a single ``generated_text`` key containing the
    decoded model output.
    """
    prompt = data["inputs"]

    # Pick up the device of the model's first parameter (where the
    # embedding lives) so the tokenized inputs land on the right GPU
    # even under multi-device / device_map placement.
    target_device = next(self.model.parameters()).device
    model_inputs = self.tokenizer(prompt, return_tensors="pt").to(target_device)

    # Generation — disable autograd bookkeeping for inference.
    with torch.inference_mode():
        output_ids = self.model.generate(**model_inputs, **self.generation_kwargs)

    decoded = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"generated_text": decoded}
|