File size: 5,101 Bytes
d0ffe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import logging
from os import PathLike
from pathlib import Path
from typing import Optional

from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download, snapshot_download
from tqdm.rich import tqdm

from animatediff import HF_HUB_CACHE, HF_LIB_NAME, HF_LIB_VER, get_dir
from animatediff.utils.util import path_from_cwd

logger = logging.getLogger(__name__)

# Local storage roots, all located under the project's "data" directory.
data_dir = get_dir("data")
checkpoint_dir = data_dir / "models" / "sd"
pipeline_dir = data_dir / "models" / "huggingface"

# Glob patterns for git metadata plus TensorFlow / Flax artifacts. The combined
# list is what get_hf_repo hands to snapshot_download as ``ignore_patterns``.
IGNORE_TF = ["*.git*", "*.h5", "tf_*"]
IGNORE_FLAX = ["*.git*", "flax_*", "*.msgpack"]
IGNORE_TF_FLAX = [*IGNORE_TF, *IGNORE_FLAX]


class DownloadTqdm(tqdm):
    """Rich tqdm progress bar used as ``tqdm_class`` for repository downloads.

    Whatever the caller passes for ``ncols``, ``dynamic_ncols`` or ``disable``
    is overridden with a fixed 100-column, non-dynamic layout and
    ``disable=None``; every other argument is forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        forced_options = dict(ncols=100, dynamic_ncols=False, disable=None)
        super().__init__(*args, **{**kwargs, **forced_options})


def get_hf_file(
    repo_id: Path,
    filename: str,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a single file from a Hugging Face Hub repo into ``target_dir``.

    Args:
        repo_id: Hub repository id (converted to ``str`` for the request).
        filename: Name of the file inside the repo (or inside ``subfolder``).
        target_dir: Local directory the file is written into; created if missing.
        subfolder: Optional folder within the repo that contains ``filename``.
        revision: Branch, tag or commit to fetch; defaults to ``"main"``.
        force: Download even if the local copy already exists.

    Returns:
        The local path reported by ``hf_hub_download``.

    Raises:
        FileExistsError: If the local copy already exists and ``force`` is false.
    """
    # With ``local_dir`` set, hf_hub_download writes to
    # ``local_dir/<subfolder>/<filename>``, so check that same location;
    # checking ``target_dir/filename`` would miss files fetched from a subfolder.
    if subfolder:
        target_path = target_dir.joinpath(subfolder, filename)
    else:
        target_path = target_dir.joinpath(filename)
    if target_path.exists() and not force:
        raise FileExistsError(
            f"File {path_from_cwd(target_path)} already exists! Pass force=True to overwrite"
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = hf_hub_download(
        repo_id=str(repo_id),
        filename=filename,
        revision=revision or "main",
        subfolder=subfolder,
        # Identify the client the same way get_hf_repo does.
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    return Path(save_path)


def get_hf_repo(
    repo_id: Path,
    target_dir: Path,
    subfolder: Optional[PathLike] = None,
    revision: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Download a snapshot of a Hub repo (optionally one subfolder) into ``target_dir``.

    Files matching ``IGNORE_TF_FLAX`` are skipped.

    Args:
        repo_id: Hub repository id (converted to ``str`` for the request).
        target_dir: Local directory the snapshot is written into; created if missing.
        subfolder: Optional repo folder to restrict the download to.
        revision: Branch, tag or commit to fetch; defaults to ``"main"``.
        force: Download even if ``target_dir`` already exists.

    Returns:
        The snapshot directory reported by ``snapshot_download``, or the
        ``subfolder`` inside it when one was requested.

    Raises:
        FileExistsError: If ``target_dir`` already exists and ``force`` is false.
    """
    if target_dir.exists() and not force:
        raise FileExistsError(
            f"Target dir {path_from_cwd(target_dir)} already exists! Pass force=True to overwrite"
        )

    # snapshot_download has no ``subfolder`` argument (passing one raises
    # TypeError), so express the restriction as an allow-pattern instead.
    # Repo paths always use "/", so normalise Windows separators first.
    subfolder_str = str(subfolder).replace("\\", "/").strip("/") if subfolder else ""
    allow_patterns = [f"{subfolder_str}/*"] if subfolder_str else None

    target_dir.mkdir(exist_ok=True, parents=True)
    save_path = snapshot_download(
        repo_id=str(repo_id),
        revision=revision or "main",
        library_name=HF_LIB_NAME,
        library_version=HF_LIB_VER,
        local_dir=target_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
        ignore_patterns=IGNORE_TF_FLAX,
        cache_dir=HF_HUB_CACHE,
        tqdm_class=DownloadTqdm,
        max_workers=2,
        resume_download=True,
    )
    snapshot_root = Path(save_path)
    return snapshot_root.joinpath(subfolder_str) if subfolder_str else snapshot_root


def get_hf_pipeline(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionPipeline:
    """Return a ``StableDiffusionPipeline``, preferring a saved copy in ``target_dir``.

    If ``target_dir/model_index.json`` exists and ``force_download`` is false,
    the pipeline is loaded from ``target_dir`` with ``local_files_only=True``.
    Otherwise it is fetched with ``from_pretrained`` using ``HF_HUB_CACHE``.
    When ``save`` is true, it is then written to ``target_dir`` as safetensors.

    Args:
        repo_id: Hub repository id; backslashes are normalised to ``/``.
        target_dir: Directory holding (or to receive) the saved pipeline.
        save: Whether to write a freshly fetched pipeline to ``target_dir``.
        force_download: Ignore an existing copy in ``target_dir`` and fetch again.
            Note that ``force_download`` is not forwarded to ``from_pretrained``,
            so files already present in ``HF_HUB_CACHE`` may be reused.

    Returns:
        The loaded pipeline.
    """
    pipeline_exists = target_dir.joinpath("model_index.json").exists()
    if pipeline_exists and not force_download:
        return StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    # Hub ids use "/" separators (str(Path) yields "\\" on Windows). Strip only a
    # literal "./" prefix: lstrip("./") would also eat the leading "/" of an
    # absolute path or the dots of "../".
    hub_id = str(repo_id).replace("\\", "/")
    if hub_id.startswith("./"):
        hub_id = hub_id[2:]
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path=hub_id,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
    )
    if save:
        # Only warn about overwriting when a saved copy actually exists.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        pipeline.save_pretrained(target_dir, safe_serialization=True)
    return pipeline

def get_hf_pipeline_sdxl(
    repo_id: Path,
    target_dir: Path,
    save: bool = True,
    force_download: bool = False,
) -> StableDiffusionXLPipeline:
    """Return an fp16 ``StableDiffusionXLPipeline``, preferring a saved copy in ``target_dir``.

    Both load paths request ``torch.float16`` weights from the ``"fp16"``
    safetensors variant. If ``target_dir/model_index.json`` exists and
    ``force_download`` is false, the pipeline is loaded from ``target_dir`` with
    ``local_files_only=True``. Otherwise it is fetched with ``from_pretrained``
    using ``HF_HUB_CACHE``. When ``save`` is true, it is then written to
    ``target_dir`` as safetensors under the same variant.

    Args:
        repo_id: Hub repository id; backslashes are normalised to ``/``.
        target_dir: Directory holding (or to receive) the saved pipeline.
        save: Whether to write a freshly fetched pipeline to ``target_dir``.
        force_download: Ignore an existing copy in ``target_dir`` and fetch again.
            Note that ``force_download`` is not forwarded to ``from_pretrained``,
            so files already present in ``HF_HUB_CACHE`` may be reused.

    Returns:
        The loaded pipeline.
    """
    import torch

    variant = "fp16"
    load_kwargs = dict(torch_dtype=torch.float16, use_safetensors=True, variant=variant)

    pipeline_exists = target_dir.joinpath("model_index.json").exists()
    if pipeline_exists and not force_download:
        return StableDiffusionXLPipeline.from_pretrained(
            pretrained_model_name_or_path=target_dir,
            local_files_only=True,
            **load_kwargs,
        )

    target_dir.mkdir(exist_ok=True, parents=True)
    # Hub ids use "/" separators (str(Path) yields "\\" on Windows). Strip only a
    # literal "./" prefix: lstrip("./") would also eat the leading "/" of an
    # absolute path or the dots of "../".
    hub_id = str(repo_id).replace("\\", "/")
    if hub_id.startswith("./"):
        hub_id = hub_id[2:]
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        pretrained_model_name_or_path=hub_id,
        cache_dir=HF_HUB_CACHE,
        resume_download=True,
        **load_kwargs,
    )
    if save:
        # Only warn about overwriting when a saved copy actually exists.
        if pipeline_exists:
            logger.warning(f"Pipeline already exists at {path_from_cwd(target_dir)}. Overwriting!")
        else:
            logger.info(f"Saving pipeline to {path_from_cwd(target_dir)}")
        # Save under the same variant the local-load branch above requests;
        # otherwise the files are written without the variant suffix and the
        # local load cannot find them.
        pipeline.save_pretrained(target_dir, safe_serialization=True, variant=variant)
    return pipeline