import os
import re
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from urllib.parse import unquote, urlparse

import fal
from fal.toolkit.utils.download_utils import (
    FAL_MODEL_WEIGHTS_DIR,
    DownloadError,
    _hash_url,
)

FAL_VERSION = getattr(fal, "__version__", "<1.0.0")
_REQUEST_HEADERS = {"User-Agent": f"fal-client ({FAL_VERSION}/python)"}

# Patterns for the ``filename`` parameter of a Content-Disposition header.
# The quoted form stops at the closing quote; the bare form stops at the next
# ``;`` so trailing parameters are not swallowed into the file name.
_CD_QUOTED_FILENAME = re.compile(r'filename="([^"]+)"')
_CD_BARE_FILENAME = re.compile(r"filename=([^;]+)")


def _bearer_headers_from_env(env_var: str) -> dict[str, str]:
    """Build an ``Authorization: Bearer`` header from the token in *env_var*.

    Returns an empty dict (and prints a notice) when the variable is unset
    or empty.
    """
    token = os.getenv(env_var, None)
    if not token:
        print(f"{env_var} is not set in the environment variables.")
        return {}
    return {"Authorization": f"Bearer {token}"}


def get_civitai_headers() -> dict[str, str]:
    """Return request headers carrying the ``CIVITAI_TOKEN`` bearer token, if set."""
    return _bearer_headers_from_env("CIVITAI_TOKEN")


def get_huggingface_headers() -> dict[str, str]:
    """Return request headers carrying the ``HF_TOKEN`` bearer token, if set."""
    return _bearer_headers_from_env("HF_TOKEN")


def get_local_file_content_length(file_path: Path) -> int:
    """Return the size in bytes of the file at *file_path*."""
    return file_path.stat().st_size


def _parse_content_length(value: str | None) -> int | None:
    """Parse a ``Content-Length`` header value; return ``None`` if absent or malformed."""
    if not value:
        return None
    try:
        return int(value)
    except ValueError:
        return None


def download_url_to_file(
    url: str,
    dst: str | Path,
    progress: bool = True,
    headers: dict[str, str] | None = None,
    chunk_size_in_mb=16,
    file_integrity_check_callback=None,
) -> Path:
    """Download object at the given URL to a local path.

    Args:
        url (str): URL of the object to download
        dst (str): Full path where object will be saved, e.g.
            ``/tmp/temporary_file``
        progress (bool, optional): whether or not to display a progress bar
            to stderr. Default: True
        headers (dict, optional): HTTP headers to include with the request.
            Default: None
        chunk_size_in_mb (int, optional): size of each chunk in MB.
            Default: 16
        file_integrity_check_callback (callable, optional): callback function
            to check file integrity; called with ``dst`` after the download.
            Default: None

    Returns:
        Path: ``dst`` as a ``Path``.
    """
    from tqdm import tqdm

    request_headers = {
        **_REQUEST_HEADERS,
        **(headers or {}),
    }

    url = url.strip()

    if url.startswith("data:"):
        return _download_data_url_to_file(url, dst)

    import requests

    # The response is used as a context manager so the streamed connection is
    # released even when the download fails part-way.
    with requests.get(
        url, headers=request_headers, stream=True, allow_redirects=True
    ) as req:
        req.raise_for_status()

        # Parse the whole header value (not just its first character) so the
        # progress bar total reflects the real size.
        file_size = _parse_content_length(req.headers.get("Content-Length", None))

        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            file_path = temp_file.name

        try:
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar, open(file_path, "wb") as f:
                for chunk in req.iter_content(
                    chunk_size=chunk_size_in_mb * 1024 * 1024
                ):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))

            # Move the file when the file is downloaded completely. Since the
            # file used is temporary, in a case of an interruption, the
            # downloaded content will be lost. So, it is safe to redownload
            # the file in such cases.
            shutil.move(file_path, dst)
        finally:
            # No-op after a successful move; removes the partial file otherwise.
            Path(file_path).unlink(missing_ok=True)

    if file_integrity_check_callback:
        file_integrity_check_callback(dst)

    return Path(dst)


def _download_data_url_to_file(url: str, dst: str | Path) -> Path:
    """Decode the base64 payload of a ``data:`` URL and write it to *dst*."""
    import base64

    payload = url.split(",")[1]
    decoded = base64.b64decode(payload)

    with open(dst, "wb") as fp:
        fp.write(decoded)

    return Path(dst)


def download_model_weights(url: str, force: bool = False) -> Path:
    """Download model weights, adding host-specific auth headers.

    Bearer tokens are attached for ``civitai.com`` and ``huggingface.co``
    hosts; the actual download is delegated to ``download_model_weights_fal``.
    """
    parsed_url = urlparse(url)
    headers: dict[str, str] = {}
    if parsed_url.netloc == "civitai.com":
        headers.update(get_civitai_headers())
    elif parsed_url.netloc == "huggingface.co":
        headers.update(get_huggingface_headers())

    return download_model_weights_fal(url, request_headers=headers, force=force)


def url_without_query(url) -> str:
    """Return *url* truncated at the first ``?``; unchanged if there is none."""
    return url.split("?", 1)[0]


def download_model_weights_fal(
    url: str, force: bool = False, request_headers: dict[str, str] | None = None
) -> Path:
    """Fetch model weights for *url* into the fal weights cache and return the path.

    Resolution order:
      1. Any file already present in the per-URL cache directory (unless
         *force*).
      2. A cached file whose name and size match the remote file (unless
         *force*).
      3. A same-sized copy found under ``CKPT_DOWNLOAD_DIR`` (if set).
      4. A network download.

    When ``CKPT_DOWNLOAD_WITHOUT_QUERY`` is ``"true"`` the query string is
    dropped from *url* before anything else happens.

    Raises:
        DownloadError: if remote file properties cannot be fetched or the
            download fails.
        ValueError: if a cached file fails the safetensors check.
    """
    without_query = os.environ.get("CKPT_DOWNLOAD_WITHOUT_QUERY", "false") == "true"
    if without_query:
        url = url_without_query(url)

    weights_dir = Path(FAL_MODEL_WEIGHTS_DIR / _hash_url(url))

    if weights_dir.exists() and not force:
        weights_path = next(weights_dir.glob("*"), None)
        # An empty directory means the weights still need to be downloaded.
        if weights_path is not None:
            is_safetensors_file(weights_path)
            return weights_path

    try:
        file_name, file_content_length = _get_remote_file_properties(
            url, request_headers=request_headers
        )
    except Exception as e:
        print(e)
        raise DownloadError(
            f"Failed to get remote file properties for {url}"
        ) from e

    target_path = weights_dir / file_name

    if (
        target_path.exists()
        and get_local_file_content_length(target_path) == file_content_length
        and not force
    ):
        is_safetensors_file(target_path)
        return target_path

    # Make sure the parent directory exists
    target_path.parent.mkdir(parents=True, exist_ok=True)

    # Try to copy from the network volume before hitting the network.
    ckpt_download_dir = os.environ.get("CKPT_DOWNLOAD_DIR", None)
    if ckpt_download_dir:
        candidates = [
            os.path.join(ckpt_download_dir, file_name),
            os.path.join(ckpt_download_dir, _hash_url(url), file_name),
            os.path.join(
                ckpt_download_dir, _hash_url(url_without_query(url)), file_name
            ),
        ]
        src_path = next((p for p in candidates if os.path.exists(p)), None)

        # If the file exists and has the expected size, copy it and return.
        if src_path is not None:
            src_file_content_length = get_local_file_content_length(Path(src_path))
            if src_file_content_length == file_content_length:
                print(f"copy start from:{src_path} to:{target_path}")
                shutil.copy(src_path, target_path)
                print(f"copy done from:{src_path} to:{target_path}")
                return target_path
            else:
                print(
                    f"cannot copy file length not same src_path:{src_path} "
                    f"len:{src_file_content_length} "
                    f"target_path:{target_path} len:{file_content_length}"
                )

    try:
        download_url_to_file(
            url,
            target_path,
            progress=True,
            headers=request_headers,
            file_integrity_check_callback=is_safetensors_file,
        )
    except Exception as e:
        print(e)
        raise DownloadError(f"Failed to download {url}") from e

    return target_path


def _get_filename_from_content_disposition(cd: str | None) -> str | None:
    """Extract the URL-decoded ``filename`` from a Content-Disposition value.

    Returns ``None`` when *cd* is empty or has no ``filename`` parameter.
    """
    if not cd:
        return None

    match = _CD_QUOTED_FILENAME.search(cd) or _CD_BARE_FILENAME.search(cd)
    if match is None:
        return None

    return unquote(match.group(1).strip())


def _safe_basename(name: str | None) -> str:
    """Return the final path component of *name*, or ``""`` if unusable.

    Directory parts are dropped so a remote-supplied name cannot point
    outside the directory it is later joined to.
    """
    if not name:
        return ""
    base = Path(name).name
    return "" if base in (".", "..") else base


def _parse_filename(url: str, cd: str | None) -> str:
    """Choose a local file name for *url*.

    Preference order: Content-Disposition file name, last component of the
    URL path, hash of the URL. For ``data:`` URLs an extension guessed from
    the MIME type is appended when one is known.
    """
    url = url.strip()

    # The header value comes from the remote server; keep only its basename.
    file_name = _safe_basename(_get_filename_from_content_disposition(cd))

    if not file_name:
        parsed_url = urlparse(url)

        if parsed_url.scheme == "data":
            file_name = _hash_url(url)
        else:
            file_name = _safe_basename(parsed_url.path) or _hash_url(url)

    if url.startswith("data:"):
        import mimetypes

        mime_type = url.split(",")[0].split(":")[1].split(";")[0]
        extension = mimetypes.guess_extension(mime_type)
        if extension:
            file_name += extension

    return file_name  # type: ignore


def _get_remote_file_properties(
    url: str, request_headers: dict[str, str] | None = None
) -> tuple[str, int]:
    """Return ``(file_name, content_length)`` for the object at *url*.

    ``content_length`` is ``-1`` when the server sends no ``Content-Length``.
    The body is not read: the request is streamed and closed as soon as the
    headers have been inspected.
    """
    import requests

    headers = {
        **_REQUEST_HEADERS,
        **(request_headers or {}),
    }

    # NOTE(review): verify=False disables TLS certificate verification while
    # Authorization headers may be attached. Kept as-is because it may be
    # deliberate for this deployment -- confirm and remove if not required.
    with requests.get(
        url, headers=headers, stream=True, allow_redirects=True, verify=False
    ) as req:
        req.raise_for_status()

        response_headers = req.headers
        content_disposition = response_headers.get("Content-Disposition", None)
        file_name = _parse_filename(url, content_disposition)
        content_length = int(response_headers.get("Content-Length", -1))

    return file_name, content_length


def is_safetensors_file(path: str | Path):
    """Validate that *path* is a readable ``.safetensors`` file.

    Raises:
        ValueError: if the name does not end in ``.safetensors`` or the
            header cannot be deserialized (``HeaderTooLarge``).
        Exception: any other error from ``safe_open`` is re-raised unchanged.
    """
    from safetensors import safe_open

    path = str(path)

    if not path.endswith(".safetensors"):
        raise ValueError(f"File {path} is not a .safetensors file")

    try:
        with safe_open(path, framework="pt"):
            pass
    except Exception as e:
        print(e)
        # str(e) rather than e.args[0]: an exception raised without
        # arguments must not turn into an IndexError here.
        if str(e) == "Error while deserializing header: HeaderTooLarge":
            raise ValueError(f"File {path} is not a .safetensors file") from e
        raise


@contextmanager
def download_file_temp(
    url: str,
    progress: bool = True,
    headers: dict[str, str] | None = None,
    chunk_size_in_mb=16,
    file_integrity_check_callback=None,
):
    """Download *url* into a temporary directory and yield the file path.

    The directory, and the file in it, is removed when the ``with`` block
    exits. Arguments are forwarded to ``download_url_to_file``.
    """
    file_name = _parse_filename(url, None)

    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = download_url_to_file(
            url,
            f"{temp_dir}/{file_name}",
            progress=progress,
            headers=headers,
            chunk_size_in_mb=chunk_size_in_mb,
            file_integrity_check_callback=file_integrity_check_callback,
        )
        yield file_path