# Download and chat with Baichuan2-13B-Chat-4bits (quantized 13B chat model).
from loguru import logger
import rich
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
# Fetch the complete model repository into the local "model" directory so
# all later from_pretrained() calls can resolve files without the network.
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
loc = snapshot_download(repo_id=model_name, local_dir="model")
# Fix the timezone on Linux so log timestamps come out in local time.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore[attr-defined]  # pylint: disable=no-member
except AttributeError:
    # time.tzset() is Unix-only; it does not exist on Windows, where the
    # missing attribute raises AttributeError (narrower than Exception).
    logger.warning("Windows, cant run time.tzset()")

# Drop any previous model reference and collect garbage so a re-run of this
# script (e.g. in a notebook) frees RAM/VRAM before the big load below.
model = None
gc.collect()
logger.info("start")
# Load the quantized chat model from the local snapshot directory ("model",
# i.e. `loc` from snapshot_download above).
# NOTE: the "-4bits" checkpoint already carries its own 4-bit quantization
# config; also passing load_in_8bit=True requests a second, conflicting
# quantization scheme (and clashes with torch_dtype=bfloat16). The official
# Baichuan2 4-bit usage passes no load_in_8bit, so it is removed here.
model = AutoModelForCausalLM.from_pretrained(
    "model",  # local snapshot dir (== loc)
    device_map="auto",  # let accelerate place layers on available devices
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # Baichuan ships custom modeling code in the repo
    # use_ram_optimized_load=False,
    # offload_folder="offload_folder",
)
rich.print(f"{model=}")
logger.info("done")
# ======== single-turn chat demo ========
# (AutoModelForCausalLM/AutoTokenizer are already imported at the top of the
# file; the re-import is kept so this section also runs standalone.)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Load the tokenizer and generation config from the SAME local snapshot the
# model was loaded from ("model"), instead of re-resolving the hub repo id —
# keeps all artifacts consistent and avoids a second network round-trip.
tokenizer = AutoTokenizer.from_pretrained("model", use_fast=False, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained("model")

messages = []
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
# chat() is Baichuan's custom helper provided via trust_remote_code.
response = model.chat(tokenizer, messages)
rich.print(response)
logger.info(f"{response=}")