Update handler.py
Browse files- handler.py +16 -4
handler.py
CHANGED
@@ -11,7 +11,19 @@ class EndpointHandler:
|
|
11 |
model, checkpoint=model_dir, device_map="auto"
|
12 |
) # 自动跨 GPU 切层
|
13 |
def __call__(self, data):
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
model, checkpoint=model_dir, device_map="auto"
|
12 |
) # 自动跨 GPU 切层
|
13 |
def __call__(self, data):
    """Generate a text completion for the prompt carried in ``data``.

    Args:
        data: Request payload; the prompt is read from ``data["inputs"]``.

    Returns:
        A dict with the single key ``"generated_text"``, holding the first
        generated sequence decoded with special tokens skipped.

    Raises:
        KeyError: If ``data`` is a dict without an ``"inputs"`` entry.
    """
    text = data["inputs"]

    # Tokenize to PyTorch tensors, then move input_ids / attention_mask
    # onto GPU 0 before generation.
    encoded = self.tokenizer(text, return_tensors="pt")
    encoded = encoded.to("cuda:0")

    # Generation is capped at 256 newly produced tokens.
    generated = self.model.generate(**encoded, max_new_tokens=256)

    # Only the first sequence of the batch is decoded.
    # NOTE(review): the decoded text presumably still contains the prompt
    # when the model is decoder-only -- confirm this is intended.
    decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"generated_text": decoded}
|
29 |
+
|