yuzhe committed · Commit bc2a0b5 · verified · 1 Parent(s): c27e6f6

Update handler.py

Files changed (1):
  1. handler.py +17 -36
handler.py CHANGED
@@ -1,57 +1,38 @@
+# handler.py
 from typing import Dict, Any
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-
 class EndpointHandler:
-    def __init__(self, model_dir: str, **kwargs):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_dir, trust_remote_code=True
-        )
+    def __init__(self, model_dir: str, **kw):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 
-        # ① Build the empty-shell model
+        # ① Empty-shell model
         with init_empty_weights():
             base = AutoModelForCausalLM.from_pretrained(
-                model_dir,
-                torch_dtype=torch.float16,
-                trust_remote_code=True,
+                model_dir, torch_dtype=torch.float16, trust_remote_code=True
             )
 
-        # ② Shard-load across multiple GPUs
+        # ② Sharded loading
         self.model = load_checkpoint_and_dispatch(
-            base,
-            checkpoint=model_dir,
-            device_map="auto",
-            dtype=torch.float16,
+            base, checkpoint=model_dir, device_map="auto", dtype=torch.float16
         ).eval()
 
-        # ③ Record the GPU holding the embeddings and switch the **default GPU** over to it
-        self.first_device = next(self.model.parameters()).device
-        torch.cuda.set_device(self.first_device)  # ← the key line
-
-        # ④ Generation parameters
-        self.generation_kwargs = dict(
-            max_new_tokens=512,  # 🛈 2k tokens are extremely VRAM-hungry; cap at 512 first, then tune upward
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-        )
+        # ③ Pin the "default GPU" to the GPU holding the input embeddings
+        self.embed_device = self.model.get_input_embeddings().weight.device
+        torch.cuda.set_device(self.embed_device)  # ← key point 1
+        print(">>> embedding on", self.embed_device)
 
-        # (Optional) print the device map in the logs for easier troubleshooting later
-        print(">>> device_map =", self.model.hf_device_map)
+        # Generation parameters
+        self.gen_kwargs = dict(max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True)
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
         prompt = data["inputs"]
 
-        # Move *all* input tensors to first_device
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.first_device)
-
+        # Move *all* input tensors to embed_device
+        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.embed_device)  # ← key point 2
         with torch.inference_mode():
-            output_ids = self.model.generate(**inputs, **self.generation_kwargs)
+            out_ids = self.model.generate(**inputs, **self.gen_kwargs)
 
-        return {
-            "generated_text": self.tokenizer.decode(
-                output_ids[0], skip_special_tokens=True
-            )
-        }
+        return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}