|
from loguru import logger |
|
import rich |
|
import os |
|
import time |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import gc |
|
|
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
# Hub repository id of the pre-quantized (4-bit) Baichuan2 chat model.
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"

# Mirror the entire repo into ./model so every later load is purely local.
loc = snapshot_download(local_dir="model", repo_id=model_name)
|
|
|
|
|
# Pin the process timezone so log timestamps are consistent.
os.environ["TZ"] = "Asia/Shanghai"
try:
    # Re-read TZ into the C library's timezone state (POSIX-only call).
    time.tzset()
except AttributeError:
    # time.tzset() does not exist on Windows; any other failure should surface.
    logger.warning("Windows, cant run time.tzset()")
|
|
|
# Drop any stale reference to a previously loaded model, then ask the
# garbage collector to reclaim it before the large allocation below.
model = None
gc.collect()

# Mark the beginning of the (slow) model-loading phase in the logs.
logger.info("start")
|
|
|
# Load the model from the local snapshot downloaded above.
# NOTE: this repo is already GPTQ 4-bit quantized ("-4bits"); the original
# code also passed load_in_8bit=True, which requests a second, conflicting
# bitsandbytes 8-bit quantization — transformers rejects a model that
# already carries a quantization config, so that flag is removed here.
model = AutoModelForCausalLM.from_pretrained(
    "model",                      # local dir populated by snapshot_download
    device_map="auto",            # shard/offload across available devices
    torch_dtype=torch.bfloat16,   # dtype for the non-quantized tensors
    trust_remote_code=True,       # Baichuan2 ships custom modeling code
)

# Show the loaded architecture for a quick sanity check.
rich.print(f"{model=}")

logger.info("done")
|
|
|
|
|
# GenerationConfig is the only name not already imported at the top of the
# file (the original re-imported AutoModelForCausalLM/AutoTokenizer here
# redundantly — AutoTokenizer is imported at module top and
# AutoModelForCausalLM is unused in this section).
from transformers.generation.utils import GenerationConfig

# Baichuan2 ships a custom (slow) tokenizer; use_fast=False is required.
tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat-4bits",
    use_fast=False,
    trust_remote_code=True,
)

# Attach the repo's recommended sampling/generation defaults to the model.
model.generation_config = GenerationConfig.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat-4bits"
)

# Single-turn chat demo: ask the model to explain a Chinese idiom.
messages = []
messages.append({"role": "user", "content": "解释一下“温故而知新”"})
response = model.chat(tokenizer, messages)

rich.print(response)

logger.info(f"{response=}")