File size: 5,101 Bytes
d0ffe9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import logging
from os import PathLike
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm.rich import tqdm
from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir
from animatediff.utils.util import path_from_cwd
logger = logging.getLogger(name=__name__)

# Glob patterns for repository files that are only needed by the TensorFlow
# (``*.h5``, ``tf_*``) or Flax (``flax_*``, ``*.msgpack``) backends, plus git
# metadata. The combined list is passed as ``ignore_patterns`` to
# ``snapshot_download`` in ``get_hf_repo``.
IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
IGNORE_TF_FLAX = [*IGNORE_TF, *IGNORE_FLAX]

# Package data root (resolved by ``get_dir``) and the model folders below it.
data_dir = get_dir("data")
checkpoint_dir, pipeline_dir = (
    data_dir.joinpath(relative)
    for relative in ("models/sd", "models/huggingface")
)
class DownloadTqdm(tqdm):
    """``tqdm.rich`` progress bar used as ``tqdm_class`` for snapshot downloads.

    Whatever width/visibility options the caller supplies, this bar always
    uses a fixed 100-column layout without dynamic resizing.
    """

    def __init__(self, *args, **kwargs):
        # These settings win over any caller-provided values with the same key.
        # NOTE(review): per tqdm's docs ``disable=None`` means "disable when not
        # writing to a TTY" - confirm that is the intended behaviour here.
        forced_options = {
            "ncols": 100,
            "dynamic_ncols": False,
            "disable": None,
        }
        super().__init__(*args, **{**kwargs, **forced_options})
def get_hf_file(
    repo_id: Path,
    filename: str,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a single file from a Hugging Face Hub repo into ``target_dir``.

    Args:
        repo_id: Hub repository id (e.g. ``"org/name"``); passed on via ``str()``.
        filename: Name of the file inside the repo (inside ``subfolder`` if given).
        target_dir: Local directory to place the file in; created if missing.
        subfolder: Optional folder within the repo that contains ``filename``.
            The file is then written to ``target_dir/<subfolder>/<filename>``.
        revision: Branch, tag or commit hash; defaults to ``"main"``.
        force: Proceed even if the local file already exists.

    Returns:
        The local path reported by ``hf_hub_download``.

    Raises:
        FileExistsError: If the local file already exists and ``force`` is False.
    """
    # hf_hub_download stores the file at local_dir/<subfolder>/<filename>, so
    # the existence check must include the subfolder as well.
    relative_path = Path(subfolder, filename) if subfolder else Path(filename)
    target_path = target_dir.joinpath(relative_path)
    if target_path.exists() and not force:
        raise FileExistsError(
            f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = hf_hub_download(
        repo_id=str(repo_id),
        filename=filename,
        revision=revision or "main",
        subfolder=subfolder,
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    return Path(save_path)
def get_hf_repo(
    repo_id: Path,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a snapshot of a Hugging Face Hub repo into ``target_dir``.

    TensorFlow/Flax weights and git metadata are skipped (``IGNORE_TF_FLAX``).

    Args:
        repo_id: Hub repository id (e.g. ``"org/name"``); passed on via ``str()``.
        target_dir: Local directory to mirror the repo into; created if missing.
        subfolder: Optional repo folder to restrict the download to. Files keep
            their repo-relative paths, so they land under
            ``target_dir/<subfolder>/``.
        revision: Branch, tag or commit hash; defaults to ``"main"``.
        force: Proceed even if ``target_dir`` already exists.

    Returns:
        The path returned by ``snapshot_download``.

    Raises:
        FileExistsError: If ``target_dir`` exists and ``force`` is False.
    """
    if target_dir.exists() and not force:
        raise FileExistsError(
            f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
        )

    # snapshot_download has no ``subfolder`` argument (passing one raises
    # TypeError), so express the restriction as an allow-pattern instead.
    # Hub paths always use forward slashes, hence as_posix().
    allow_patterns = None
    if subfolder is not None:
        prefix = Path(subfolder).as_posix().strip("/")
        if prefix not in ("", "."):
            allow_patterns = [f"{prefix}/*"]

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = snapshot_download(
        repo_id=str(repo_id),
        revision=revision or "main",
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
        ignore_patterns=IGNORE_TF_FLAX,
        cache_dir=HF_HUB_CACHE,
        tqdm_class=DownloadTqdm,
        max_workers=2,
        resume_download=True,
    )
    return Path(save_path)
def get_hf_pipeline(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionPipeline:
    """Load a Stable Diffusion pipeline, preferring a local copy in ``target_dir``.

    If ``target_dir/model_index.json`` exists and ``force_download`` is False,
    the pipeline is loaded from ``target_dir`` with ``local_files_only=True``.
    Otherwise it is fetched from the Hub (via ``HF_HUB_CACHE``). When ``save``
    is True it is then written to ``target_dir`` as safetensors, replacing any
    existing copy.

    NOTE: ``force_download`` only bypasses the local copy. It is not forwarded
    to ``from_pretrained``, so the Hub fetch may still be served from the cache.

    Args:
        repo_id: Hub repository id of the pipeline.
        target_dir: Local directory holding (or receiving) the saved pipeline.
        save: Whether to write a freshly fetched pipeline to ``target_dir``.
        force_download: Ignore an existing local copy and fetch from the Hub.

    Returns:
        The loaded ``StableDiffusionPipeline``.
    """
    pipeline_exists = target_dir.joinpath("model_index.json").exists()
    if pipeline_exists and not force_download:
        return StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    # Hub ids use forward slashes, so backslashes (e.g. from a Windows Path) are
    # converted. NOTE: lstrip("./") strips any run of leading '.' and '/'
    # characters, not just a literal "./" prefix.
    hub_id = str(repo_id).lstrip("./").replace("\\", "/")
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path=hub_id,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    if save:
        # Only warn about overwriting when a saved pipeline was actually present.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline
def get_hf_pipeline_sdxl(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionXLPipeline:
    """Load an SDXL pipeline in fp16, preferring a local copy in ``target_dir``.

    Every load requests ``torch.float16`` weights from the ``"fp16"``
    safetensors variant. If ``target_dir/model_index.json`` exists and
    ``force_download`` is False, the pipeline is loaded from ``target_dir`` with
    ``local_files_only=True``. Otherwise it is fetched from the Hub (via
    ``HF_HUB_CACHE``). When ``save`` is True it is then written to
    ``target_dir`` under the same ``"fp16"`` variant, replacing any existing
    copy.

    NOTE: ``force_download`` only bypasses the local copy. It is not forwarded
    to ``from_pretrained``, so the Hub fetch may still be served from the cache.

    Args:
        repo_id: Hub repository id of the SDXL pipeline.
        target_dir: Local directory holding (or receiving) the saved pipeline.
        save: Whether to write a freshly fetched pipeline to ``target_dir``.
        force_download: Ignore an existing local copy and fetch from the Hub.

    Returns:
        The loaded ``StableDiffusionXLPipeline``.
    """
    import torch

    # Shared by both load paths so the local and Hub loads request the same files.
    variant = "fp16"
    load_kwargs = dict(torch_dtype=torch.float16, use_safetensors=True, variant=variant)

    pipeline_exists = target_dir.joinpath("model_index.json").exists()
    if pipeline_exists and not force_download:
        return StableDiffusionXLPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
            **load_kwargs,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    # Hub ids use forward slashes, so backslashes (e.g. from a Windows Path) are
    # converted. NOTE: lstrip("./") strips any run of leading '.' and '/'
    # characters, not just a literal "./" prefix.
    hub_id = str(repo_id).lstrip("./").replace("\\", "/")
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_name_or_path=hub_id,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
        **load_kwargs,
    )
    if save:
        # Only warn about overwriting when a saved pipeline was actually present.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        # Save under the same variant the local-load branch requests. Without
        # it the weight files get unsuffixed names, and a later
        # variant="fp16" load from target_dir cannot find them.
        pipeline.save_pretrained(target_dir, safe_serialization=True, variant=variant)
    return pipeline
|