def convert_to_diffusers(state_dict):
    """Rename keys from the ``diffusion_model.*`` layout to the ``transformer.*`` layout.

    Each key is rewritten by applying these steps in order:

    1. A leading ``diffusion_model.`` is replaced with ``transformer.``.
    2. ``self_attn`` becomes ``attn1``; otherwise ``cross_attn`` becomes ``attn2``.
    3. Dot-separated components ``q``/``k``/``v`` become ``to_q``/``to_k``/``to_v``
       and ``o`` becomes ``to_out.0``.
    4. ``ffn.0`` becomes ``ffn.net.0.proj``; otherwise ``ffn.2`` becomes ``ffn.net.2``.

    Keys that match none of the rules are copied through unchanged.

    Args:
        state_dict: Mapping from parameter name to value. Values are passed
            through untouched and the mapping itself is not modified.

    Returns:
        A new dict holding the same values under the renamed keys. If two
        input keys map to the same output key, the later one wins.
    """
    source_prefix = "diffusion_model."
    target_prefix = "transformer."
    projection_names = {"q": "to_q", "k": "to_k", "v": "to_v", "o": "to_out.0"}

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key

        # Rewrite only the leading prefix; str.replace would also rewrite any
        # later occurrence of the same text inside the key.
        if new_key.startswith(source_prefix):
            new_key = target_prefix + new_key[len(source_prefix):]

        # Attention module names (at most one of the two is rewritten per key).
        if "self_attn" in new_key:
            new_key = new_key.replace("self_attn", "attn1")
        elif "cross_attn" in new_key:
            new_key = new_key.replace("cross_attn", "attn2")

        # Projection names are matched as whole dot-separated components so
        # that names merely containing "q", "k", "v" or "o" are left alone.
        new_key = ".".join(
            projection_names.get(part, part) for part in new_key.split(".")
        )

        # Feed-forward sub-layers.
        if "ffn.0" in new_key:
            new_key = new_key.replace("ffn.0", "ffn.net.0.proj")
        elif "ffn.2" in new_key:
            new_key = new_key.replace("ffn.2", "ffn.net.2")

        new_state_dict[new_key] = value

    return new_state_dict


def convert_to_original(state_dict):
    """Rename keys from the ``transformer.*`` layout back to the ``diffusion_model.*`` layout.

    Each key is rewritten by applying these steps in order:

    1. A leading ``transformer.`` is replaced with ``diffusion_model.``.
    2. ``attn1`` becomes ``self_attn``; otherwise ``attn2`` becomes ``cross_attn``.
    3. The dot-separated component pair ``to_out.0`` becomes ``o`` and the
       components ``to_q``/``to_k``/``to_v`` become ``q``/``k``/``v``.
    4. ``ffn.net.0.proj`` becomes ``ffn.0``; otherwise ``ffn.net.2`` becomes ``ffn.2``.

    Keys that match none of the rules are copied through unchanged.

    Args:
        state_dict: Mapping from parameter name to value. Values are passed
            through untouched and the mapping itself is not modified.

    Returns:
        A new dict holding the same values under the renamed keys. If two
        input keys map to the same output key, the later one wins.
    """
    source_prefix = "transformer."
    target_prefix = "diffusion_model."
    projection_names = {"to_q": "q", "to_k": "k", "to_v": "v"}

    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key

        # Rewrite only the leading prefix; str.replace would also rewrite any
        # later occurrence of the same text inside the key.
        if new_key.startswith(source_prefix):
            new_key = target_prefix + new_key[len(source_prefix):]

        # Attention module names (at most one of the two is rewritten per key).
        if "attn1" in new_key:
            new_key = new_key.replace("attn1", "self_attn")
        elif "attn2" in new_key:
            new_key = new_key.replace("attn2", "cross_attn")

        # Projection names are matched as whole dot-separated components,
        # mirroring convert_to_diffusers, so that a name that only contains
        # "to_q" (e.g. "to_qkv") is not truncated.
        parts = new_key.split(".")
        renamed = []
        index = 0
        while index < len(parts):
            part = parts[index]
            followed_by_zero = index + 1 < len(parts) and parts[index + 1] == "0"
            if part == "to_out" and followed_by_zero:
                # "to_out.0" spans two components and collapses into "o".
                renamed.append("o")
                index += 2
            else:
                renamed.append(projection_names.get(part, part))
                index += 1
        new_key = ".".join(renamed)

        # Feed-forward sub-layers.
        if "ffn.net.0.proj" in new_key:
            new_key = new_key.replace("ffn.net.0.proj", "ffn.0")
        elif "ffn.net.2" in new_key:
            new_key = new_key.replace("ffn.net.2", "ffn.2")

        new_state_dict[new_key] = value

    return new_state_dict