Spaces:
Running
Running
File size: 7,769 Bytes
d8ca2a9 3997fa3 d8ca2a9 cbabf63 d8ca2a9 cbabf63 d8ca2a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import os
import subprocess
from huggingface_hub import HfApi, upload_folder
import gradio as gr
import hf_utils
import utils
# Fetch the diffusers repository so its conversion script
# (./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py) is
# available to convert_and_push. Skip the clone when the checkout already
# exists (e.g. the app is re-run in the same environment), where `git clone`
# would fail with "destination path already exists". Only one script is
# needed, so a shallow clone avoids downloading the full history.
if not os.path.isdir("diffs"):
    subprocess.run(["git", "clone", "--depth=1", "https://github.com/huggingface/diffusers", "diffs"])
def error_str(error, title="Error"):
    """Render *error* as a small Markdown section headed by *title*.

    A falsy *error* (None, empty string, ...) yields an empty string, so the
    result can be sent directly to a Markdown output to clear it.
    """
    if not error:
        return ""
    heading = f"#### {title}"
    return f"{heading}\n{error}"
def on_token_change(token):
    """Refresh the model picker after the Hugging Face token changes.

    Fetches the token owner's model names via ``hf_utils.get_my_model_names``;
    when any are found, an extra "Other" choice is appended so an arbitrary
    repo can be typed in instead.

    Returns four Gradio updates, in order: visibility of the model group,
    choices/selection of the model radio, visibility of the "Load" button,
    and the error Markdown text.
    """
    names, err = hf_utils.get_my_model_names(token)
    has_models = bool(names)
    if has_models:
        names.append("Other")
    first_choice = names[0] if has_models else None
    return (
        gr.update(visible=has_models),
        gr.update(choices=names, value=first_choice),
        gr.update(visible=has_models),
        gr.update(value=error_str(err)),
    )
def url_to_model_id(model_id_str):
    """Normalize a model reference to a Hub repo id.

    Accepts either a repo id (``"user/model"``) or a huggingface.co model URL
    such as ``https://huggingface.co/user/model``,
    ``https://huggingface.co/user/model/`` or
    ``https://huggingface.co/user/model/tree/main`` and returns the repo id
    (``"user/model"``). A single-segment URL (``https://huggingface.co/gpt2``)
    yields that segment. Surrounding whitespace is removed; any other string
    is returned otherwise unchanged.
    """
    model_id_str = model_id_str.strip()
    prefix = "https://huggingface.co/"
    if not model_id_str.startswith(prefix):
        return model_id_str
    # Drop any query string / fragment, then keep only the leading
    # "<owner>/<name>" segments: trailing slashes and sub-paths such as
    # /tree/main or /blob/main/model.ckpt must not leak into the repo id.
    path = model_id_str[len(prefix):].split("?", 1)[0].split("#", 1)[0]
    segments = [seg for seg in path.split("/") if seg]
    return "/".join(segments[:2])
def get_ckpt_names(token, radio_model_names, input_model):
    """List the ``.ckpt`` files of the selected model repo.

    The repo is the radio selection, or the free-text ``input_model`` (a repo
    id or huggingface.co URL) when "Other" is selected.

    Returns a 3-tuple of updates for (error Markdown, checkpoint radio,
    convert column). On success the radio is filled with the checkpoint names
    and the convert column is shown; on any failure an error message is
    returned and the convert column is hidden.
    """
    model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
    # `not ...` also rejects None (e.g. no model chosen yet), which a plain
    # `== ""` comparison would let through to the Hub API call.
    if not token or not model_id:
        return error_str("Please enter both a token and a model name.", title="Invalid input"), gr.update(choices=[]), gr.update(visible=False)

    try:
        api = HfApi(token=token)
        ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
    except Exception as e:
        # Hide the convert column, consistent with the other failure paths.
        return error_str(e), gr.update(choices=[]), gr.update(visible=False)

    if not ckpt_files:
        return error_str("No checkpoint files found in the model repo."), gr.update(choices=[]), gr.update(visible=False)
    return None, gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True), gr.update(visible=True)
def _cleanup_conversion_artifacts(revision, dump_dir):
    """Best-effort removal of local files left behind by one conversion run.

    Deletes the downloaded checkpoint (via ``hf_utils.delete_file``) when a
    download happened, the converted-model directory when one was created,
    and every ``*.yaml*`` entry in the working directory. Failures are
    printed, never raised, so cleanup cannot mask the conversion result.
    """
    import glob
    import shutil

    if revision is not None:
        try:
            hf_utils.delete_file(revision)
        except Exception as e:
            print(f"Could not delete downloaded checkpoint (revision {revision}): {e}")
    if dump_dir is not None:
        shutil.rmtree(dump_dir, ignore_errors=True)
    # NOTE(review): presumably config files fetched into the working directory
    # during conversion; the trailing `*` also matches suffixed copies such as
    # `x.yaml.1` — confirm.
    for path in glob.glob("*.yaml*"):
        if os.path.isdir(path):
            shutil.rmtree(path, ignore_errors=True)
            continue
        try:
            os.remove(path)
        except OSError as e:
            print(f"Could not delete {path}: {e}")


def convert_and_push(radio_model_names, input_model, ckpt_name, sd_version, token, path_in_repo):
    """Convert a CompVis checkpoint to Diffusers and open a PR with the result.

    Downloads ``ckpt_name`` from the selected repo, runs the diffusers
    conversion script on it, uploads the converted folder to ``path_in_repo``
    of the same repo as a pull request, and finally removes the local
    artifacts (on failure as well as on success).

    Returns a Markdown string: a success message with the PR link, or an
    error section built by ``error_str``. Exceptions are not propagated.
    """
    import tempfile

    # The UI exposes no EMA option, so --extract_ema is never passed.
    extract_ema = False

    if not sd_version:
        return error_str("You must select a stable diffusion version.", title="Invalid input")
    # NOTE(review): sd_version is only validated here and is not forwarded to
    # the conversion script — confirm the script infers the version itself.

    model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names

    revision = None
    dump_dir = None
    try:
        model_id = url_to_model_id(model_id)

        # 1. Download the checkpoint file
        ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)

        # 2. Run the conversion script into a fresh, uniquely named directory.
        # Unlike a folder named after the user-supplied model id, it cannot
        # collide with leftovers from earlier runs, and cleanup never has to
        # delete a path derived from user input.
        dump_dir = tempfile.mkdtemp(prefix="diffusers-", dir=".")
        run_command = [
            "python3",
            "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
            "--checkpoint_path",
            ckpt_path,
            "--dump_path",
            dump_dir,
        ]
        if extract_ema:
            run_command.append("--extract_ema")
        # check=True: a failed conversion must not be uploaded as a PR; the
        # CalledProcessError is reported through the except clause below.
        subprocess.run(run_command, check=True)

        # 3. Push to the model repo as a pull request
        commit_message = "Add Diffusers weights"
        upload_folder(
            folder_path=dump_dir,
            repo_id=model_id,
            path_in_repo=path_in_repo,
            token=token,
            create_pr=True,
            commit_message=commit_message,
            commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
        )
        pr_url = hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)
        return f"""Successfully converted the checkpoint and opened a PR to add the weights to the model repo.
You can view and merge the PR [here]({pr_url})."""
    except Exception as e:
        return error_str(e)
    finally:
        # 4. Delete the downloaded checkpoint file, yaml files and the
        # converted model folder, whether or not the conversion succeeded.
        _cleanup_conversion_artifacts(revision, dump_dir)
# Markdown shown at the top of the app. The Colab link carries the standard
# badge image; an empty `[](...)` link would render as nothing clickable.
DESCRIPTION = """### Convert a stable diffusion checkpoint to Diffusers🧨
With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.
You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.
You can skip the queue by running the app in the colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/qunash/f0f3152c5851c0c477b68b7b98d547fe/convert-sd-to-diffusers.ipynb)"""
# Build the two-step UI: (1) token + model selection, (2) conversion options.
# Most controls start hidden (visible=False) and are revealed by the event
# handlers wired up at the bottom of this block.
# NOTE(review): layout nesting was reconstructed from a whitespace-stripped
# copy of this file — confirm against upstream.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Step 1 (left): token entry and model selection.
        with gr.Column(scale=11):
            with gr.Column():
                gr.Markdown("## 1. Load model info")
                input_token = gr.Textbox(
                    max_lines=1,
                    type="password",
                    label="Enter your Hugging Face token",
                    placeholder="READ permission is sufficient"
                )
                gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")
                # Revealed by on_token_change when the token yields model names.
                with gr.Group(visible=False) as group_model:
                    radio_model_names = gr.Radio(label="Choose a model")
                    # Free-text repo id / URL; shown only when "Other" is picked.
                    input_model = gr.Textbox(
                        max_lines=1,
                        label="Model name or URL",
                        placeholder="username/model_name",
                        visible=False,
                    )
                btn_get_ckpts = gr.Button("Load", visible=False)
        # Step 2 (right): conversion options, revealed by get_ckpt_names.
        with gr.Column(scale=10):
            with gr.Column(visible=False) as group_convert:
                gr.Markdown("## 2. Convert to Diffusers🧨")
                radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
                path_in_repo = gr.Textbox(label="Path where the weights will be saved", placeholder="Leave empty for root folder")
                radio_sd_version = gr.Radio(label="Choose the model version", choices=["v1", "v2", "v2.1"])
                gr.Markdown("Conversion may take a few minutes.")
                btn_convert = gr.Button("Convert & Push")
    # Shared status/error area. Kept outside group_convert because
    # on_token_change writes token errors here before that column is shown.
    error_output = gr.Markdown(label="Output")

    # Lightweight UI handlers use queue=False, bypassing the request queue
    # enabled by demo.queue(); only the slow conversion goes through it.
    # NOTE(review): .change presumably fires on every edit of the token box,
    # i.e. one Hub lookup per keystroke — consider .submit/.blur.
    input_token.change(
        fn=on_token_change,
        inputs=input_token,
        outputs=[group_model, radio_model_names, btn_get_ckpts, error_output],
        queue=False,
        scroll_to_output=True)
    # Show the free-text model field only for the "Other" choice.
    radio_model_names.change(
        lambda x: gr.update(visible=x == "Other"),
        inputs=radio_model_names,
        outputs=input_model,
        queue=False,
        scroll_to_output=True)
    btn_get_ckpts.click(
        fn=get_ckpt_names,
        inputs=[input_token, radio_model_names, input_model],
        outputs=[error_output, radio_ckpts, group_convert],
        scroll_to_output=True,
        queue=False
    )
    # Input order must match convert_and_push's positional parameters.
    btn_convert.click(
        fn=convert_and_push,
        inputs=[radio_model_names, input_model, radio_ckpts, radio_sd_version, input_token, path_in_repo],
        outputs=error_output,
        scroll_to_output=True
    )
    # gr.Markdown("""<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/diffusers_library.jpg" width="150"/>""")
    # Footer with author links.
    gr.HTML("""
<div style="border-top: 1px solid #303030;">
<br>
<p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
<a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
<p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>
</div>
""")
# Enable Gradio's request queue (handlers registered with queue=False skip
# it), then start the server. A public share link is requested only when
# running inside Google Colab.
demo.queue()
share_link = utils.is_google_colab()
demo.launch(debug=True, share=share_link)
|