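# app.py — Hugging Face Space by mikeee (commit a90ba7a, "Create app.py", 981 bytes).
# Downloads the baichuan-inc/Baichuan2-13B-Chat-4bits checkpoint and loads it with transformers.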
from loguru import logger
import rich
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
# Mirror the Baichuan2 13B chat checkpoint (the "-4bits" variant) from the
# Hugging Face Hub into the local ./model directory; `loc` keeps the path
# that snapshot_download reports for that local copy.
model_name = "baichuan-inc/Baichuan2-13B-Chat-4bits"
loc = snapshot_download(local_dir="model", repo_id=model_name)
# Make the process use the Asia/Shanghai time zone: set TZ, then ask the C
# runtime to re-read it. time.tzset() only exists on Unix, so on platforms
# without it (e.g. Windows) accessing the attribute raises AttributeError.
# Only that case is tolerated; any other error still propagates.
os.environ["TZ"] = "Asia/Shanghai"
try:
    time.tzset()  # type: ignore # pylint: disable=no-member
except AttributeError:
    # Windows: no time.tzset(); TZ stays set in os.environ only.
    logger.warning("Windows, cant run time.tzset()")
# Release any model object left from an earlier run in the same interpreter
# before loading a fresh one; when nothing was loaded yet, these two lines
# have no effect.
model = None
gc.collect()
logger.info("start")
# Load from `loc`, the directory snapshot_download returned, instead of
# repeating the "model" literal, so the download and load paths cannot
# drift apart.
model = AutoModelForCausalLM.from_pretrained(
    loc,
    device_map="auto",  # let transformers decide device placement
    torch_dtype=torch.bfloat16,
    # NOTE(review): the checkpoint is the "-4bits" variant, yet 8-bit
    # loading is requested here; confirm which quantization actually
    # takes effect (the repo's own quantization_config may override it).
    load_in_8bit=True,
    # Allow executing model code from the Hub repo (presumably required
    # for Baichuan2 — confirm).
    trust_remote_code=True,
)
rich.print(f"{model=}")
logger.info("done")