Spaces:
Sleeping
Sleeping
| import os | |
| import spaces | |
| import argparse | |
| from pathlib import Path | |
| import os | |
| import torch | |
| from diffusers import (DiffusionPipeline, AutoencoderKL, FlowMatchEulerDiscreteScheduler, StableDiffusionXLPipeline, StableDiffusionPipeline, | |
| FluxPipeline, FluxTransformer2DModel, SD3Transformer2DModel, StableDiffusion3Pipeline) | |
| from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPFeatureExtractor, AutoTokenizer, T5EncoderModel, BitsAndBytesConfig as TFBitsAndBytesConfig | |
| from huggingface_hub import save_torch_state_dict, snapshot_download | |
| from diffusers.loaders.single_file_utils import (convert_flux_transformer_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, | |
| convert_sd3_t5_checkpoint_to_diffusers) | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
| import safetensors.torch | |
| import gradio as gr | |
| import shutil | |
| import gc | |
| import tempfile | |
| # also requires aria, gdown, peft, huggingface_hub, safetensors, transformers, accelerate, pytorch_lightning | |
| from utils import (get_token, set_token, is_repo_exists, is_repo_name, get_download_file, upload_repo, gate_repo) | |
| from sdutils import (SCHEDULER_CONFIG_MAP, get_scheduler_config, fuse_loras, DTYPE_DEFAULT, get_dtype, get_dtypes, get_model_type_from_key, get_process_dtype) | |
def fake_gpu():
    """Placeholder callable that does nothing and returns None."""
# Optional diffusers-side bitsandbytes config; NF4 output is only enabled when this import succeeds.
try:
    from diffusers import BitsAndBytesConfig
except Exception:
    is_nf4 = False
else:
    is_nf4 = True

# Candidate FLUX base repos (not referenced in this part of the file; presumably used by the UI).
FLUX_BASE_REPOS = [
    "camenduru/FLUX.1-dev-diffusers",
    "black-forest-labs/FLUX.1-schnell",
    "John6666/flux1-dev-fp8-flux",
    "John6666/flux1-schnell-fp8-flux",
]
# fp8 T5 encoder downloaded by the FLUX fp8 conversion path.
FLUX_T5_URL = "https://huggingface.co/camenduru/FLUX.1-dev/blob/main/t5xxl_fp8_e4m3fn.safetensors"
# Candidate SD 3.5 base repos (not referenced in this part of the file; presumably used by the UI).
SD35_BASE_REPOS = [
    "adamo1139/stable-diffusion-3.5-large-ungated",
    "adamo1139/stable-diffusion-3.5-large-turbo-ungated",
]
# fp8 T5 encoder downloaded by the SD 3.5 fp8 conversion path.
SD35_T5_URL = "https://huggingface.co/adamo1139/stable-diffusion-3.5-large-turbo-ungated/blob/main/text_encoders/t5xxl_fp8_e4m3fn.safetensors"
# Scratch directory (created at import time) for downloads during conversion.
TEMP_DIR = tempfile.mkdtemp()
# True when running on a ZeroGPU Space (signalled by the presence of this environment variable).
IS_ZERO = "SPACES_ZERO_GPU" in os.environ
IS_CUDA = torch.cuda.is_available()
def safe_clean(path: str):
    """Best-effort deletion of a file or directory tree; never raises.

    A falsy ``path`` (``None`` or ``""``) is ignored: ``Path("")`` means the
    current working directory, which ``shutil.rmtree`` would otherwise wipe.
    Symlinks are unlinked rather than followed.

    Args:
        path: File, directory, or symlink to delete.
    """
    if not path:
        print(f"Nothing to delete: {path!r}")
        return
    try:
        target = Path(path)
        if not target.exists() and not target.is_symlink():
            print(f"File not found: {path}")
            return
        if target.is_dir() and not target.is_symlink():
            shutil.rmtree(target)
        else:
            # Regular files and symlinks (including links to directories).
            target.unlink()
        print(f"Deleted: {path}")
    except Exception as e:
        # Deliberately best-effort: report and carry on.
        print(f"Failed to delete: {path} {e}")
def save_readme_md(dir, url):
    """Write a minimal diffusers model card (README.md) into ``dir``.

    When ``url`` is a repo ID or contains "http", a "Converted from" link to
    the source is appended after the YAML front matter.

    Args:
        dir: Directory receiving README.md.
        url: Repo ID or URL of the source model.
    """
    if is_repo_name(url):
        source_name, source_url = url, f"https://huggingface.co/{url}/"
    elif "http" in url:
        source_name, source_url = url, url
    else:
        source_name = source_url = ""

    card_lines = [
        "---",
        "license: other",
        "language:",
        "- en",
        "library_name: diffusers",
        "pipeline_tag: text-to-image",
        "tags:",
        "- text-to-image",
        "---",
    ]
    if source_name and source_url:
        card_lines.append(f"Converted from [{source_name}]({source_url}).")
    card = "\n".join(card_lines) + "\n"

    readme_path = str(Path(dir, "README.md"))
    with open(readme_path, mode='w', encoding="utf-8") as f:
        f.write(card)
def save_module(model, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)): # doesn't work
    """Save ``model``'s weights as sharded safetensors under ``<dir>/<name>``.

    NOTE: the original author marked this helper as not working, and it is not
    called in the visible part of this file.

    Args:
        model: Module whose ``state_dict()`` is saved; it is moved to CPU first.
        name: Component name. "vae"/"transformer"/"unet" use the diffusers file
            naming; "transformer"/"unet" shard at 10GB, everything else at 5GB.
        dir: Parent output directory.
        dtype: "fp8" casts every tensor to ``torch.float8_e4m3fn``; anything
            else keeps tensors unchanged.
        progress: Gradio progress tracker.
    """
    file_stem = "diffusion_pytorch_model" if name in ("vae", "transformer", "unet") else "model"
    pattern = file_stem + "{suffix}.safetensors"
    size = "10GB" if name in ("transformer", "unet") else "5GB"
    out_dir = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(out_dir, exist_ok=True)

    message = f"Saving {name} to {dir}..."
    progress(0, desc=message)
    print(message)

    model.to("cpu")
    source = dict(model.state_dict())
    to_fp8 = dtype == "fp8"
    converted = {}
    # Pop entries one by one so each source reference is dropped as soon as it is converted.
    for key in list(source):
        tensor = source.pop(key)
        if to_fp8 and tensor.dtype != torch.float8_e4m3fn:
            tensor = tensor.to(torch.float8_e4m3fn)
        converted[key] = tensor
    del source
    gc.collect()

    save_torch_state_dict(state_dict=converted, save_directory=out_dir, filename_pattern=pattern, max_shard_size=size)
    del converted
    gc.collect()
def save_module_sd(sd: dict, name: str, dir: str, dtype: str="fp8", progress=gr.Progress(track_tqdm=True)):
    """Save a raw state dict as sharded safetensors under ``<dir>/<name>``.

    Entries are popped from ``sd`` while converting, so the caller's dict is
    left empty afterwards (this keeps peak memory down).

    Args:
        sd: Mapping of parameter name to tensor; each tensor is moved to CPU.
        name: Component name. "vae"/"transformer"/"unet" use the diffusers file
            naming; "transformer"/"unet" shard at 10GB, everything else at 5GB.
        dir: Parent output directory.
        dtype: "fp8" casts every tensor to ``torch.float8_e4m3fn``; anything
            else keeps tensors unchanged.
        progress: Gradio progress tracker.
    """
    file_stem = "diffusion_pytorch_model" if name in ("vae", "transformer", "unet") else "model"
    pattern = file_stem + "{suffix}.safetensors"
    size = "10GB" if name in ("transformer", "unet") else "5GB"
    out_dir = str(Path(f"{dir.removesuffix('/')}/{name}"))
    os.makedirs(out_dir, exist_ok=True)

    message = f"Saving state_dict of {name} to {dir}..."
    progress(0, desc=message)
    print(message)

    to_fp8 = dtype == "fp8"
    converted = {}
    for key in list(sd):
        tensor = sd.pop(key).to("cpu")
        if to_fp8 and tensor.dtype != torch.float8_e4m3fn:
            tensor = tensor.to(torch.float8_e4m3fn)
        converted[key] = tensor

    save_torch_state_dict(state_dict=converted, save_directory=out_dir, filename_pattern=pattern, max_shard_size=size)
    del converted
    gc.collect()
def convert_flux_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file FLUX checkpoint into a diffusers folder on CPU.

    The transformer is converted directly from its state dict. The T5 encoder
    (``text_encoder_2``) is taken from ``FLUX_T5_URL``. The remaining
    components and the non-tensor files (configs, tokenizers) come from
    ``base_repo``.

    Args:
        new_file: Local ``.safetensors`` FLUX checkpoint.
        new_dir: Output directory.
        dtype: Forwarded to ``save_module_sd`` ("fp8" casts to float8_e4m3fn).
        base_repo: Repo ID supplying the other components and configs.
        civitai_key: Civitai API key for downloads.
        kwargs: Component overrides for ``FluxPipeline.from_pretrained``. Any
            ``transformer`` / ``text_encoder_2`` entries are ignored because
            those components are produced here.
        progress: Gradio progress tracker.

    Raises:
        Exception: if the T5 file cannot be downloaded.
    """
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    try:
        progress(0.25, desc=f"Loading {new_file}...")
        orig_sd = safetensors.torch.load_file(new_file)
        progress(0.3, desc=f"Converting {new_file}...")
        conv_sd = convert_flux_transformer_checkpoint_to_diffusers(orig_sd)
        del orig_sd
        gc.collect()
        progress(0.35, desc=f"Saving {new_file}...")
        save_module_sd(conv_sd, "transformer", new_dir, dtype)
        del conv_sd
        gc.collect()

        progress(0.5, desc=f"Loading text_encoder_2 from {FLUX_T5_URL}...")
        t5_file = get_download_file(temp_dir, FLUX_T5_URL, civitai_key)
        if not t5_file: raise Exception(f"Safetensors file not found: {FLUX_T5_URL}")
        t5_sd = safetensors.torch.load_file(t5_file)
        safe_clean(t5_file)
        save_module_sd(t5_sd, "text_encoder_2", new_dir, dtype)
        del t5_sd
        gc.collect()

        progress(0.6, desc=f"Loading other components from {base_repo}...")
        if kwargs.get("text_encoder_2") is not None:
            print("Custom text_encoder_2 is ignored in fp8 conversion; the fp8 T5 is used instead.")
        # transformer / text_encoder_2 were written above. Override any caller-supplied entries
        # instead of passing the keywords twice, which would raise TypeError.
        pipe_kwargs = {**kwargs, "transformer": None, "text_encoder_2": None}
        pipe = FluxPipeline.from_pretrained(base_repo, use_safetensors=True, **pipe_kwargs,
                                            torch_dtype=torch.bfloat16, token=hf_token)
        pipe.save_pretrained(new_dir)
        del pipe
        gc.collect()

        progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
        snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                          ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
        shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
    finally:
        # Remove the scratch snapshot even when a step above fails.
        safe_clean(down_dir)
def convert_sd35_fp8_cpu(new_file: str, new_dir: str, dtype: str, base_repo: str, civitai_key: str, kwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Convert a single-file SD 3.5 checkpoint into a diffusers folder on CPU.

    The transformer is converted directly from its state dict. The T5 encoder
    (``text_encoder_3``) is taken from ``SD35_T5_URL`` and converted to the
    diffusers layout. The remaining components and the non-tensor files come
    from ``base_repo``.

    Args:
        new_file: Local ``.safetensors`` SD 3.5 checkpoint.
        new_dir: Output directory.
        dtype: Forwarded to ``save_module_sd`` ("fp8" casts to float8_e4m3fn).
        base_repo: Repo ID supplying the other components and configs.
        civitai_key: Civitai API key for downloads.
        kwargs: Component overrides for ``StableDiffusion3Pipeline.from_pretrained``.
            Any ``transformer`` / ``text_encoder_3`` entries are ignored because
            those components are produced here.
        progress: Gradio progress tracker.

    Raises:
        Exception: if the T5 file cannot be downloaded.
    """
    temp_dir = TEMP_DIR
    down_dir = str(Path(f"{TEMP_DIR}/down"))
    os.makedirs(down_dir, exist_ok=True)
    hf_token = get_token()
    try:
        progress(0.25, desc=f"Loading {new_file}...")
        orig_sd = safetensors.torch.load_file(new_file)
        progress(0.3, desc=f"Converting {new_file}...")
        conv_sd = convert_sd3_transformer_checkpoint_to_diffusers(orig_sd)
        del orig_sd
        gc.collect()
        progress(0.35, desc=f"Saving {new_file}...")
        save_module_sd(conv_sd, "transformer", new_dir, dtype)
        del conv_sd
        gc.collect()

        progress(0.5, desc=f"Loading text_encoder_3 from {SD35_T5_URL}...")
        t5_file = get_download_file(temp_dir, SD35_T5_URL, civitai_key)
        if not t5_file: raise Exception(f"Safetensors file not found: {SD35_T5_URL}")
        t5_sd = safetensors.torch.load_file(t5_file)
        safe_clean(t5_file)
        conv_t5_sd = convert_sd3_t5_checkpoint_to_diffusers(t5_sd)
        del t5_sd
        gc.collect()
        save_module_sd(conv_t5_sd, "text_encoder_3", new_dir, dtype)
        del conv_t5_sd
        gc.collect()

        progress(0.6, desc=f"Loading other components from {base_repo}...")
        if kwargs.get("text_encoder_3") is not None:
            print("Custom text_encoder_3 is ignored in fp8 conversion; the fp8 T5 is used instead.")
        # transformer / text_encoder_3 were written above. Override any caller-supplied entries
        # instead of passing the keywords twice, which would raise TypeError.
        pipe_kwargs = {**kwargs, "transformer": None, "text_encoder_3": None}
        pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, use_safetensors=True, **pipe_kwargs,
                                                        torch_dtype=torch.bfloat16, token=hf_token)
        pipe.save_pretrained(new_dir)
        del pipe
        gc.collect()

        progress(0.75, desc=f"Loading nontensor files from {base_repo}...")
        snapshot_download(repo_id=base_repo, local_dir=down_dir, token=hf_token, force_download=True,
                          ignore_patterns=["*.safetensors", "*.sft", ".*", "README*", "*.md", "*.index", "*.jpg", "*.jpeg", "*.png", "*.webp"])
        shutil.copytree(down_dir, new_dir, ignore=shutil.ignore_patterns(".*", "README*", "*.md", "*.jpg", "*.jpeg", "*.png", "*.webp"), dirs_exist_ok=True)
    finally:
        # Remove the scratch snapshot even when a step above fails.
        safe_clean(down_dir)
#@spaces.GPU(duration=60)
def load_and_save_pipeline(pipe, model_type: str, url: str, new_file: str, new_dir: str, dtype: str,
                           scheduler: str, ema: bool, image_size: str, is_safety_checker: bool, base_repo: str,
                           civitai_key: str, lora_dict: dict,
                           my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder,
                           kwargs: dict, dkwargs: dict, progress=gr.Progress(track_tqdm=True)):
    """Load a pipeline of ``model_type``, fuse LoRAs, and save it in diffusers format to ``new_dir``.

    Args:
        pipe: Returned unchanged when no pipeline is built here (the FLUX /
            SD 3.5 fp8 paths convert on CPU without one).
        model_type: "SDXL", "SD 1.5", "FLUX", "SD 3.5"; anything else uses
            the generic ``DiffusionPipeline``.
        url: Repo ID or download URL of the source model.
        new_file: Local single-file checkpoint (used when ``url`` is not a repo ID).
        new_dir: Output directory.
        dtype: Output dtype name. "fp8" and "NF4" select special code paths.
        scheduler: Scheduler name resolved via ``get_scheduler_config``
            (SDXL / SD 1.5 only).
        ema: Passed as ``extract_ema`` for SD 1.5.
        image_size: For SD 1.5, written as the VAE ``sample_size`` when not "512".
        is_safety_checker: Attach the CompVis safety checker (SD 1.5 only).
        base_repo: Repo ID providing configs / missing components for FLUX and SD 3.5.
        civitai_key: Civitai API key used for downloads.
        lora_dict: LoRA source -> scale mapping handed to ``fuse_loras``.
        my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder:
            Custom components. They are not used directly here; the caller
            forwards them through ``kwargs``.
        kwargs: Component overrides for the pipeline constructors (mutated for SD 1.5).
        dkwargs: dtype-related loading kwargs (e.g. ``torch_dtype``).
        progress: Gradio progress tracker (unused here).

    Returns:
        The built pipeline, or the incoming ``pipe`` for the fp8 paths.

    Raises:
        Exception: "Failed to load pipeline.", chained from the underlying error.
    """
    try:
        hf_token = get_token()
        temp_dir = TEMP_DIR
        # Transformer quantization kwargs; only built when NF4 output is requested and
        # diffusers exposes BitsAndBytesConfig, so other dtypes never construct it.
        qkwargs = {}
        if dtype == "NF4" and is_nf4:
            qkwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16)

        if model_type == "SDXL":
            if is_repo_name(url):
                pipe = StableDiffusionXLPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = StableDiffusionXLPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            pipe.save_pretrained(new_dir)
        elif model_type == "SD 1.5":
            if is_safety_checker:
                safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
                feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
                kwargs["requires_safety_checker"] = True
                kwargs["safety_checker"] = safety_checker
                kwargs["feature_extractor"] = feature_extractor
            else:
                kwargs["requires_safety_checker"] = False
            if is_repo_name(url):
                pipe = StableDiffusionPipeline.from_pretrained(url, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = StableDiffusionPipeline.from_single_file(new_file, extract_ema=ema, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            sconf = get_scheduler_config(scheduler)
            pipe.scheduler = sconf[0].from_config(pipe.scheduler.config, **sconf[1])
            if image_size != "512":
                # Only update the stored config. AutoencoderKL.from_config would build a new,
                # randomly initialised VAE and throw away the loaded weights.
                pipe.vae.register_to_config(sample_size=int(image_size))
            pipe.save_pretrained(new_dir)
        elif model_type == "FLUX":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = FluxTransformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = FluxTransformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = FluxPipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url):
                convert_flux_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
            else:
                # The fp8 path converts a single checkpoint file; a repo ID would silently produce nothing.
                raise ValueError(f"fp8 conversion of FLUX requires a single-file checkpoint, not a repo: {url}")
        elif model_type == "SD 3.5":
            if dtype != "fp8":
                if is_repo_name(url):
                    transformer = SD3Transformer2DModel.from_pretrained(url, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(url, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                else:
                    transformer = SD3Transformer2DModel.from_single_file(new_file, subfolder="transformer", config=base_repo, **dkwargs, **qkwargs)
                    pipe = StableDiffusion3Pipeline.from_pretrained(base_repo, transformer=transformer, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
                pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
                pipe.save_pretrained(new_dir)
            elif not is_repo_name(url):
                convert_sd35_fp8_cpu(new_file, new_dir, dtype, base_repo, civitai_key, kwargs)
            else:
                # The fp8 path converts a single checkpoint file; a repo ID would silently produce nothing.
                raise ValueError(f"fp8 conversion of SD 3.5 requires a single-file checkpoint, not a repo: {url}")
        else:  # unknown model type
            if is_repo_name(url):
                pipe = DiffusionPipeline.from_pretrained(url, use_safetensors=True, **kwargs, **dkwargs, token=hf_token)
            else:
                pipe = DiffusionPipeline.from_single_file(new_file, use_safetensors=True, **kwargs, **dkwargs)
            pipe = fuse_loras(pipe, lora_dict, temp_dir, civitai_key, dkwargs)
            pipe.save_pretrained(new_dir)
    except Exception as e:
        print(f"Failed to load pipeline. {e}")
        raise Exception("Failed to load pipeline.") from e
    # Return outside of `finally`: a `return` there would silently discard the exception above.
    return pipe
def convert_url_to_diffusers(url: str, civitai_key: str="", is_upload_sf: bool=False, dtype: str="fp16", vae: str="", clip: str="", t5: str="",
                             scheduler: str="Euler a", ema: bool=True, image_size: str="768", safety_checker: bool=False,
                             base_repo: str="", mtype: str="", lora_dict: dict=None, is_local: bool=True, progress=gr.Progress(track_tqdm=True)):
    """Fetch a model (repo ID or URL), convert it to diffusers format, and return the output directory.

    Args:
        url: Hugging Face repo ID or download URL of a single-file checkpoint.
        civitai_key: Civitai API key for downloads.
        is_upload_sf: When not ``is_local``, move the downloaded checkpoint into
            the output directory instead of deleting it.
        dtype: Output dtype name (see ``get_dtypes``).
        vae, clip, t5: Optional repo IDs / URLs of replacement components.
        scheduler: Scheduler name (SDXL / SD 1.5).
        ema: Extract EMA weights (SD 1.5).
        image_size: VAE sample size for SD 1.5.
        safety_checker: Attach the safety checker (SD 1.5).
        base_repo: Base repo for FLUX / SD 3.5 components.
        mtype: Model type to assume when ``url`` is an existing repo ID.
        lora_dict: LoRA source -> scale mapping. ``None`` means no LoRAs.
        is_local: When False, the downloaded checkpoint is cleaned up (or moved).
        progress: Gradio progress tracker.

    Returns:
        Name of the local output directory, derived from the checkpoint/repo name.

    Raises:
        Exception: "Failed to convert.", chained from the underlying error.
    """
    if lora_dict is None:
        lora_dict = {}
    # Bound before `try` so the `finally` clause can always release it, even when an
    # early step such as the download fails.
    pipe = None
    try:
        hf_token = get_token()
        progress(0, desc="Start converting...")
        temp_dir = TEMP_DIR
        if is_repo_name(url) and is_repo_exists(url):
            new_file = url
            model_type = mtype
        else:
            new_file = get_download_file(temp_dir, url, civitai_key)
            if not new_file: raise Exception(f"Safetensors file not found: {url}")
            model_type = get_model_type_from_key(new_file)
        new_dir = Path(new_file).stem.replace(" ", "_").replace(",", "_").replace(".", "_")
        kwargs = {}
        dkwargs = {}
        if dtype != DTYPE_DEFAULT: dkwargs["torch_dtype"] = get_process_dtype(dtype, model_type)
        print(f"Model type: {model_type} / VAE: {vae} / CLIP: {clip} / T5: {t5} / Scheduler: {scheduler} / dtype: {dtype} / EMA: {ema} / Base repo: {base_repo} / LoRAs: {lora_dict}")

        my_vae = None
        if vae:
            progress(0, desc=f"Loading VAE: {vae}...")
            if is_repo_name(vae):
                my_vae = AutoencoderKL.from_pretrained(vae, **dkwargs, token=hf_token)
            else:
                new_vae_file = get_download_file(temp_dir, vae, civitai_key)
                if new_vae_file:
                    my_vae = AutoencoderKL.from_single_file(new_vae_file, **dkwargs)
                    safe_clean(new_vae_file)
                else:
                    print(f"VAE file not found, keeping the model's VAE: {vae}")
            if my_vae is not None: kwargs["vae"] = my_vae

        my_clip_tokenizer = None
        my_clip_encoder = None
        if clip:
            progress(0, desc=f"Loading CLIP: {clip}...")
            clip_cls = CLIPTextModelWithProjection if model_type == "SD 3.5" else CLIPTextModel
            if is_repo_name(clip):
                my_clip_tokenizer = CLIPTokenizer.from_pretrained(clip, token=hf_token)
                my_clip_encoder = clip_cls.from_pretrained(clip, **dkwargs, token=hf_token)
            else:
                new_clip_file = get_download_file(temp_dir, clip, civitai_key)
                if new_clip_file:
                    # NOTE(review): transformers models may not provide `from_single_file`;
                    # confirm this path works before relying on it.
                    my_clip_encoder = clip_cls.from_single_file(new_clip_file, **dkwargs)
                    safe_clean(new_clip_file)
                else:
                    print(f"CLIP file not found, keeping the model's CLIP: {clip}")
            if model_type == "SD 3.5":
                # The same CLIP is used for both CLIP slots of SD 3.5.
                if my_clip_tokenizer is not None:
                    kwargs["tokenizer"] = my_clip_tokenizer
                    kwargs["tokenizer_2"] = my_clip_tokenizer
                if my_clip_encoder is not None:
                    kwargs["text_encoder"] = my_clip_encoder
                    kwargs["text_encoder_2"] = my_clip_encoder
            else:
                if my_clip_tokenizer is not None: kwargs["tokenizer"] = my_clip_tokenizer
                if my_clip_encoder is not None: kwargs["text_encoder"] = my_clip_encoder

        my_t5_tokenizer = None
        my_t5_encoder = None
        if t5:
            progress(0, desc=f"Loading T5: {t5}...")
            if is_repo_name(t5):
                my_t5_tokenizer = AutoTokenizer.from_pretrained(t5, token=hf_token)
                my_t5_encoder = T5EncoderModel.from_pretrained(t5, **dkwargs, token=hf_token)
            else:
                new_t5_file = get_download_file(temp_dir, t5, civitai_key)
                if new_t5_file:
                    # NOTE(review): transformers models may not provide `from_single_file`;
                    # confirm this path works before relying on it.
                    my_t5_encoder = T5EncoderModel.from_single_file(new_t5_file, **dkwargs)
                    safe_clean(new_t5_file)
                else:
                    print(f"T5 file not found, keeping the model's T5: {t5}")
            # SD 3.5 keeps T5 in slot 3; other models (FLUX) in slot 2.
            t5_slot = "3" if model_type == "SD 3.5" else "2"
            if my_t5_tokenizer is not None: kwargs[f"tokenizer_{t5_slot}"] = my_t5_tokenizer
            if my_t5_encoder is not None: kwargs[f"text_encoder_{t5_slot}"] = my_t5_encoder

        pipe = load_and_save_pipeline(pipe, model_type, url, new_file, new_dir, dtype, scheduler, ema, image_size, safety_checker, base_repo, civitai_key, lora_dict,
                                      my_vae, my_clip_tokenizer, my_clip_encoder, my_t5_tokenizer, my_t5_encoder, kwargs, dkwargs)
        if Path(new_dir).exists(): save_readme_md(new_dir, url)
        # A repo ID is not a local download, so there is nothing to move or delete.
        if not is_local and not is_repo_name(new_file):
            if is_upload_sf:
                shutil.move(str(Path(new_file).resolve()), str(Path(new_dir, Path(new_file).name).resolve()))
            else:
                safe_clean(new_file)
        progress(1, desc="Converted.")
        return new_dir
    except Exception as e:
        print(f"Failed to convert. {e}")
        raise Exception("Failed to convert.") from e
    finally:
        del pipe
        torch.cuda.empty_cache()
        gc.collect()
def convert_url_to_diffusers_repo(dl_url: str, hf_user: str, hf_repo: str, hf_token: str, civitai_key="", is_private: bool=True,
                                  gated: str="False", is_overwrite: bool=False, is_pr: bool=False,
                                  is_upload_sf: bool=False, urls: list=None, dtype: str="fp16", vae: str="", clip: str="", t5: str="", scheduler: str="Euler a",
                                  ema: bool=True, image_size: str="768", safety_checker: bool=False,
                                  base_repo: str="", mtype: str="", lora1: str="", lora1s=1.0, lora2: str="", lora2s=1.0, lora3: str="", lora3s=1.0,
                                  lora4: str="", lora4s=1.0, lora5: str="", lora5s=1.0, args: str="", progress=gr.Progress(track_tqdm=True)):
    """Convert a model to diffusers format and upload it to ``<hf_user>/<repo>`` on the Hugging Face Hub.

    Missing ``civitai_key`` / ``hf_token`` fall back to the ``CIVITAI_API_KEY`` /
    ``HF_TOKEN`` environment variables.

    Args:
        dl_url: Repo ID or download URL of the source model.
        hf_user: Target namespace on the Hub.
        hf_repo: Target repo name. An empty string derives it from the converted folder name.
        hf_token: Hub write token.
        urls: Previously created repo URLs. It is not mutated; a new list is returned.
        lora1..lora5, lora1s..lora5s: LoRA sources and their scales.
        args: Unused.
        Other parameters are forwarded to ``convert_url_to_diffusers`` / ``upload_repo`` / ``gate_repo``.

    Returns:
        ``(gr.update for the URL list, gr.update for the markdown summary)``,
        or ``""`` if conversion produced no path.

    Raises:
        gr.Error: on any failure, with the underlying message.
    """
    try:
        is_local = False
        if not civitai_key and os.environ.get("CIVITAI_API_KEY"): civitai_key = os.environ.get("CIVITAI_API_KEY") # default Civitai API key
        if not hf_token and os.environ.get("HF_TOKEN"): hf_token = os.environ.get("HF_TOKEN") # default HF write token
        if not hf_user: raise gr.Error(f"Invalid user name: {hf_user}")
        if gated != "False" and is_private: raise gr.Error("Gated repo must be public")
        set_token(hf_token)
        lora_dict = {lora1: lora1s, lora2: lora2s, lora3: lora3s, lora4: lora4s, lora5: lora5s}
        new_path = convert_url_to_diffusers(dl_url, civitai_key, is_upload_sf, dtype, vae, clip, t5, scheduler, ema, image_size, safety_checker, base_repo, mtype, lora_dict, is_local)
        if not new_path: return ""
        new_repo_id = f"{hf_user}/{hf_repo}" if hf_repo != "" else f"{hf_user}/{Path(new_path).stem}"
        if not is_repo_name(new_repo_id): raise gr.Error(f"Invalid repo name: {new_repo_id}")
        if not is_overwrite and is_repo_exists(new_repo_id) and not is_pr: raise gr.Error(f"Repo already exists: {new_repo_id}")
        repo_url = upload_repo(new_repo_id, new_path, is_private, is_pr)
        gate_repo(new_repo_id, gated)
        safe_clean(new_path)
        # Work on a copy so neither the caller's list nor a shared default is mutated.
        urls = list(urls) if urls else []
        urls.append(repo_url)
        md = "### Your new repo:\n"
        for u in urls:
            parts = str(u).split('/')
            md += f"[{parts[-2]}/{parts[-1]}]({str(u)})<br>"
        return gr.update(value=urls, choices=urls), gr.update(value=md)
    except Exception as e:
        print(f"Error occured. {e}")
        raise gr.Error(f"Error occured. {e}") from e
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, required=True, help="URL of the model to convert.")
    parser.add_argument("--dtype", default="fp16", type=str, choices=get_dtypes(), help='Output data type. (Default: "fp16")')
    parser.add_argument("--scheduler", default="Euler a", type=str, choices=list(SCHEDULER_CONFIG_MAP.keys()), required=False, help="Scheduler name to use.")
    parser.add_argument("--vae", default="", type=str, required=False, help="URL or Repo ID of the VAE to use.")
    parser.add_argument("--clip", default="", type=str, required=False, help="URL or Repo ID of the CLIP to use.")
    parser.add_argument("--t5", default="", type=str, required=False, help="URL or Repo ID of the T5 to use.")
    parser.add_argument("--base", default="", type=str, required=False, help="Repo ID of the base repo.")
    parser.add_argument("--nonema", action="store_true", default=False, help="Don't extract EMA (for SD 1.5).")
    parser.add_argument("--civitai_key", default="", type=str, required=False, help="Civitai API Key (If you want to download file from Civitai).")
    for i in range(1, 6):
        parser.add_argument(f"--lora{i}", default="", type=str, required=False, help="URL of the LoRA to use.")
        parser.add_argument(f"--lora{i}s", default=1.0, type=float, required=False, help=f"LoRA weight scale of --lora{i}.")
    parser.add_argument("--loras", default="", type=str, required=False, help="Folder of the LoRA to use.")
    parser.add_argument("--mtype", default="SDXL", type=str, required=False,
                        help='Model type assumed when --url is a repo ID, e.g. "SDXL", "SD 1.5", "FLUX", "SD 3.5". (Default: "SDXL")')
    args = parser.parse_args()

    lora_dict = {args.lora1: args.lora1s, args.lora2: args.lora2s, args.lora3: args.lora3s, args.lora4: args.lora4s, args.lora5: args.lora5s}
    if args.loras and Path(args.loras).exists():
        for p in Path(args.loras).glob('**/*.safetensors'):
            lora_dict[str(p)] = 1.0

    # Pass everything by keyword: the signature is long and positional calls
    # silently shift arguments into the wrong parameters.
    convert_url_to_diffusers(url=args.url, civitai_key=args.civitai_key, dtype=args.dtype,
                             vae=args.vae, clip=args.clip, t5=args.t5, scheduler=args.scheduler,
                             ema=not args.nonema, base_repo=args.base, mtype=args.mtype,
                             lora_dict=lora_dict, is_local=True)