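"""Custom EndpointHandler for serving a causal LM sharded across multiple GPUs.

Builds an empty model skeleton with `init_empty_weights`, shards the
checkpoint across all visible GPUs via `load_checkpoint_and_dispatch`, and
answers requests of the form {"inputs": "<prompt>"}. An illustrative local
smoke test follows the class definition.
"""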
from typing import Any, Dict

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True
        )

        # ① Build an empty-weight (meta-device) model skeleton. Construct it
        #    from the config rather than via from_pretrained so no weights are
        #    materialized here; load_checkpoint_and_dispatch fills them below.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        with init_empty_weights():
            base = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )

        # ② Shard-load the checkpoint across the available GPUs
        self.model = load_checkpoint_and_dispatch(
            base,
            checkpoint=model_dir,
            device_map="auto",
            dtype=torch.float16,
        ).eval()

        # ③ Record the GPU holding the embedding layer, and also make it
        #    the *default* CUDA device
        self.first_device = next(self.model.parameters()).device
        torch.cuda.set_device(self.first_device)  # ← the crucial line

        # ④ Generation parameters
        self.generation_kwargs = dict(
            max_new_tokens=512,  # 🛈 2k tokens of output is very VRAM-hungry; cap at 512 first, then raise gradually
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # (Optional) Print the device map to the logs for easier troubleshooting
        print(">>> device_map =", self.model.hf_device_map)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        prompt = data["inputs"]

        # Move *all* input tensors to first_device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.first_device)

        with torch.inference_mode():
            output_ids = self.model.generate(**inputs, **self.generation_kwargs)

        return {
            "generated_text": self.tokenizer.decode(
                output_ids[0], skip_special_tokens=True
            )
        }
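

# --- Illustrative local smoke test (not part of the serving contract) ------
# Assumes a checkpoint directory and prompt; both values below are
# placeholders, not paths from the original deployment.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model")
    out = handler({"inputs": "Explain model sharding in one sentence."})
    print(out["generated_text"])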