# handler.py  --  place this file in the root of the model repository
from typing import Dict, Any
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class EndpointHandler:
    """
    Custom entry point expected by Hugging Face Inference Endpoints:
      • __init__(model_dir, **kwargs)   -- loads the model
      • __call__(inputs: Dict) -> Dict  -- handles a single request
    """

    def __init__(self, model_dir: str, **kwargs):
        # 1️⃣ Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True
        )

        # 2️⃣ Build an empty "shell" model on the meta device (no GPU memory yet).
        #    Using from_config here avoids materializing the checkpoint twice.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        with init_empty_weights():
            base_model = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
        base_model.tie_weights()  # re-tie embeddings that from_config leaves untied

        # 3️⃣ Load the checkpoint, sharding the weights across the two GPUs
        self.model = load_checkpoint_and_dispatch(
            base_model,
            checkpoint=model_dir,
            device_map="auto",                # split layers across cuda:0 / cuda:1
            dtype=torch.float16,
        )

        # 4️⃣ Default generation parameters
        self.generation_kwargs = dict(
            max_new_tokens=2048,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        prompt = data["inputs"]
    
        # ① Pick the device holding the first parameters (the embedding layer)
        first_device = next(self.model.parameters()).device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(first_device)
    
        # ② Generate the completion
        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                **self.generation_kwargs,
            )
    
        return {
            "generated_text": self.tokenizer.decode(
                output_ids[0], skip_special_tokens=True
            )
        }
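

# ----------------------------------------------------------------------
# Optional local smoke test. This block is not part of the Inference
# Endpoints contract; it is a minimal sketch that assumes the model
# weights sit in the current directory and that the GPUs referenced
# above are visible. Adjust the directory and prompt to your setup.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(model_dir=".")  # hypothetical local path
    response = handler({"inputs": "Write a short poem about GPUs."})
    print(response["generated_text"])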