# app.py — Hugging Face Space by "mikeee" (commit f45f68b, 1.69 kB).
# (Header reconstructed from file-viewer chrome accidentally captured in the paste.)
from loguru import logger
import rich
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download

# Repo id of the pre-quantized (4-bit GPTQ) Baichuan2 13B chat checkpoint.
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
# Download the full repo snapshot into ./model; returns the local directory path.
# NOTE(review): blocking network I/O at import time — the whole checkpoint
# (~10 GB) is fetched before anything else runs.
loc = snapshot_download(repo_id=model_name, local_dir="model")
# Pin the process timezone to Asia/Shanghai (Linux containers default to UTC).
os.environ["TZ"] = "Asia/Shanghai"
try:
    # Re-read TZ into the C runtime. time.tzset only exists on Unix, so on
    # Windows this raises AttributeError — catch exactly that, not Exception,
    # so real failures are not silently swallowed.
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    logger.warning("Windows, cant run time.tzset()")

# Drop any previously loaded model and collect garbage before the large
# allocation in the load step below.
model = None
gc.collect()
logger.info("start")
# Load the checkpoint snapshot downloaded into ./model above.
# FIX: the "-4bits" checkpoint ships its own 4-bit GPTQ quantization config;
# passing load_in_8bit=True on top of it requests a second, conflicting
# bitsandbytes quantization and fails/misloads. The model card loads it with
# only device_map + dtype + trust_remote_code, which is what we do here.
model = AutoModelForCausalLM.from_pretrained(
    "model",  # local snapshot dir (same path as `loc`)
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # Baichuan2 uses custom modeling code from the repo
)
rich.print(f"{model=}")
logger.info("done")
# ========
# Tokenizer and generation settings for the same checkpoint.
# (AutoModelForCausalLM and AutoTokenizer are already imported at the top of
# the file; the duplicate import that used to sit here was redundant.)
from transformers.generation.utils import GenerationConfig

# Baichuan2 requires the slow (sentencepiece) tokenizer and custom repo code.
tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat-4bits",
    use_fast=False,
    trust_remote_code=True,
)
# Attach the model card's recommended sampling parameters to the loaded model.
model.generation_config = GenerationConfig.from_pretrained(
    "baichuan-inc/Baichuan2-13B-Chat-4bits"
)
# Smoke test: one single-turn Chinese prompt through the chat interface.
messages = [{"role": "user", "content": "解释一下“温故而知新”"}]
response = model.chat(tokenizer, messages)
rich.print(response)
logger.info(f"{response=}")