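"""Download baichuan-inc/Baichuan2-13B-Chat-4bits via huggingface_hub and run a one-turn chat."""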
from loguru import logger
import rich
import os
import time
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
from huggingface_hub import snapshot_download

model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# Download the whole repo once; later from_pretrained calls read from the local dir.
loc = snapshot_download(repo_id=model_name, local_dir="model")
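
# If only a single file is needed, hf_hub_download fetches it without pulling the
# whole repo -- a minimal sketch (the filename here is just an example):
#   from huggingface_hub import hf_hub_download
#   cfg_path = hf_hub_download(repo_id=model_name, filename="config.json")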

# Fix the timezone on Linux; time.tzset() does not exist on Windows.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    logger.warning("time.tzset() is unavailable on Windows, skipping")

# Drop any previously loaded model before (re)loading, e.g. when re-running in a notebook.
model = None
gc.collect()
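if torch.cuda.is_available():
    # Also return cached allocator blocks to the driver; gc.collect() alone
    # does not release CUDA memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()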

logger.info("start")

model = AutoModelForCausalLM.from_pretrained(
    "model",  # the local dir populated by snapshot_download above (same as loc)
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    # NB: this repo is already 4-bit quantized; do not also pass load_in_8bit=True,
    # as the two quantization settings conflict.
    # offload_folder="offload_folder",  # uncomment to spill weights to disk if VRAM runs out
)

rich.print(f"{model=}")
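
# Optional sanity check: transformers exposes get_memory_footprint() (in bytes) on
# loaded models, handy for confirming how much memory the quantized model takes.
logger.info(f"memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")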

logger.info("done")

# ========
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)

# model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan2-13B-Chat-4bits", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)

model.generation_config = GenerationConfig.from_pretrained(model_name)
messages = []
# Prompt: "Explain '温故而知新'" (an Analects idiom: gain new insight by revisiting the old).
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
response = model.chat(tokenizer, messages)

rich.print(response)

logger.info(f"{response=}")
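
# A sketch of a second turn, assuming model.chat returns the reply as a string
# (as in the Baichuan2 README); the follow-up question is just an example.
messages.append({"role": "assistant", "content": response})
messages.append({"role": "user", "content": "请举一个例子。"})  # "Please give an example."
response = model.chat(tokenizer, messages)
rich.print(response)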