Spaces:
Build error
Build error
Update pipline_StableDiffusion_ConsistentID.py
Browse files
pipline_StableDiffusion_ConsistentID.py
CHANGED
|
@@ -21,6 +21,7 @@ from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, Facial
|
|
| 21 |
### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
|
| 22 |
### Thanks for the open source of face-parsing model.
|
| 23 |
from models.BiSeNet.model import BiSeNet
|
|
|
|
| 24 |
|
| 25 |
PipelineImageInput = Union[
|
| 26 |
PIL.Image.Image,
|
|
@@ -31,7 +32,6 @@ PipelineImageInput = Union[
|
|
| 31 |
|
| 32 |
### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
|
| 33 |
class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
| 34 |
-
|
| 35 |
@validate_hf_hub_args
|
| 36 |
def load_ConsistentID_model(
|
| 37 |
self,
|
|
@@ -65,7 +65,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
| 65 |
### BiSeNet
|
| 66 |
self.bise_net = BiSeNet(n_classes = 19)
|
| 67 |
self.bise_net.cuda()
|
| 68 |
-
self.bise_net_cp=
|
| 69 |
self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
|
| 70 |
self.bise_net.eval()
|
| 71 |
# Colors for all 20 parts
|
|
@@ -124,11 +124,12 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
| 124 |
if weight_name.endswith(".safetensors"):
|
| 125 |
state_dict = {"id_encoder": {}, "lora_weights": {}}
|
| 126 |
with safe_open(model_file, framework="pt", device="cpu") as f:
|
|
|
|
| 127 |
for key in f.keys():
|
| 128 |
-
if key.startswith("
|
| 129 |
-
state_dict["
|
| 130 |
-
elif key.startswith("
|
| 131 |
-
state_dict["
|
| 132 |
else:
|
| 133 |
state_dict = torch.load(model_file, map_location="cpu")
|
| 134 |
else:
|
|
|
|
| 21 |
### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
|
| 22 |
### Thanks for the open source of face-parsing model.
|
| 23 |
from models.BiSeNet.model import BiSeNet
|
| 24 |
+
bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="face_parsing.pth", repo_type="model")
|
| 25 |
|
| 26 |
PipelineImageInput = Union[
|
| 27 |
PIL.Image.Image,
|
|
|
|
| 32 |
|
| 33 |
### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
|
| 34 |
class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
|
|
|
|
| 35 |
@validate_hf_hub_args
|
| 36 |
def load_ConsistentID_model(
|
| 37 |
self,
|
|
|
|
| 65 |
### BiSeNet
|
| 66 |
self.bise_net = BiSeNet(n_classes = 19)
|
| 67 |
self.bise_net.cuda()
|
| 68 |
+
self.bise_net_cp=bise_net_cp_path
|
| 69 |
self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
|
| 70 |
self.bise_net.eval()
|
| 71 |
# Colors for all 20 parts
|
|
|
|
| 124 |
if weight_name.endswith(".safetensors"):
|
| 125 |
state_dict = {"id_encoder": {}, "lora_weights": {}}
|
| 126 |
with safe_open(model_file, framework="pt", device="cpu") as f:
|
| 127 |
+
### TODO safetensors add
|
| 128 |
for key in f.keys():
|
| 129 |
+
if key.startswith("FacialEncoder."):
|
| 130 |
+
state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
|
| 131 |
+
elif key.startswith("image_proj."):
|
| 132 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
| 133 |
else:
|
| 134 |
state_dict = torch.load(model_file, map_location="cpu")
|
| 135 |
else:
|