yuzhe committed on
Commit c27e6f6 · verified · 1 Parent(s): 668e04f

Update handler.py

Files changed (1)
  1. handler.py +22 -28
handler.py CHANGED
@@ -1,4 +1,3 @@
-# handler.py: place this at the root of the model repo
 from typing import Dict, Any
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -6,58 +5,53 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 
 class EndpointHandler:
-    """
-    Custom entry point required by Hugging Face Inference Endpoints:
-      • __init__(model_dir, **kwargs): load the model
-      • __call__(inputs: Dict) -> Dict: handle one request
-    """
-
     def __init__(self, model_dir: str, **kwargs):
-        # 1️⃣ Tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_dir, trust_remote_code=True
         )
 
-        # 2️⃣ Build an "empty-shell" model (uses no GPU memory)
+        # Build the empty-shell model
         with init_empty_weights():
-            base_model = AutoModelForCausalLM.from_pretrained(
+            base = AutoModelForCausalLM.from_pretrained(
                 model_dir,
                 torch_dtype=torch.float16,
                 trust_remote_code=True,
             )
 
-        # 3️⃣ Shard the weights across the two GPUs
+        # Shard-load onto multiple GPUs
         self.model = load_checkpoint_and_dispatch(
-            base_model,
+            base,
             checkpoint=model_dir,
-            device_map="auto",  # automatically layered onto cuda:0 / cuda:1
+            device_map="auto",
             dtype=torch.float16,
-        )
+        ).eval()
 
-        # 4️⃣ Common generation parameters
+        # Record the GPU holding the embeddings, and switch the **default GPU** over to it as well
+        self.first_device = next(self.model.parameters()).device
+        torch.cuda.set_device(self.first_device)  # ← the key line
+
+        # ④ Generation parameters
         self.generation_kwargs = dict(
-            max_new_tokens=2048,
+            max_new_tokens=512,  # 🛈 2k tokens is very VRAM-hungry; start at 512 and raise it gradually
             do_sample=True,
             temperature=0.7,
             top_p=0.9,
         )
 
+        # (Optional) print the device map to the logs for easier debugging later
+        print(">>> device_map =", self.model.hf_device_map)
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         prompt = data["inputs"]
-
-        # Automatically grab the GPU holding the embeddings
-        first_device = next(self.model.parameters()).device
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(first_device)
-
-        # ② Generate (the rest of the logic is unchanged)
+
+        # Put *all* input tensors on first_device
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.first_device)
+
         with torch.inference_mode():
-            output_ids = self.model.generate(
-                **inputs,
-                **self.generation_kwargs,
-            )
-
+            output_ids = self.model.generate(**inputs, **self.generation_kwargs)
+
         return {
             "generated_text": self.tokenizer.decode(
                 output_ids[0], skip_special_tokens=True
            )
-        }
+        }
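
For quick verification, a minimal local smoke test of the updated handler might look like the sketch below. The `MODEL_DIR` path is a placeholder (not from this commit), and it assumes a machine with enough GPU memory to hold the fp16 checkpoint.

# smoke_test.py: hypothetical local check for the handler above.
# Requires handler.py to be importable; MODEL_DIR must point at a
# local clone of the model repo (placeholder path).
from handler import EndpointHandler

MODEL_DIR = "/path/to/model-repo"  # placeholder

handler = EndpointHandler(MODEL_DIR)

# Inference Endpoints pass the parsed JSON body as a dict;
# this handler reads the prompt from data["inputs"].
result = handler({"inputs": "Hello, world!"})
print(result["generated_text"])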
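Once deployed, the same payload shape is sent over HTTP. A hedged client sketch, with the endpoint URL and token as placeholders, assuming the response body is the dict returned by __call__ serialized as JSON:

import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_xxx"  # placeholder

resp = requests.post(
    ENDPOINT_URL,
    headers={
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
    },
    json={"inputs": "Hello, world!"},
)
resp.raise_for_status()
print(resp.json()["generated_text"])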