Spaces:
Runtime error
Runtime error
Update ip_adapter/ip_adapter.py
Browse files- ip_adapter/ip_adapter.py +12 -10
ip_adapter/ip_adapter.py
CHANGED
|
@@ -117,15 +117,7 @@ class IPAdapter:
|
|
| 117 |
if isinstance(attn_processor, IPAttnProcessor):
|
| 118 |
attn_processor.scale = scale
|
| 119 |
|
| 120 |
-
|
| 121 |
-
for attn_processor in self.pipe.unet.attn_processors.values():
|
| 122 |
-
if isinstance(attn_processor, IPAttnProcessor):
|
| 123 |
-
print('IP attn_scale:')
|
| 124 |
-
print(attn_processor.scale)
|
| 125 |
-
if isinstance(attn_processor):
|
| 126 |
-
print('UNET attn_scale:')
|
| 127 |
-
print(attn_processor.scale)
|
| 128 |
-
|
| 129 |
def generate(
|
| 130 |
self,
|
| 131 |
pil_image,
|
|
@@ -138,7 +130,6 @@ class IPAdapter:
|
|
| 138 |
num_inference_steps=30,
|
| 139 |
**kwargs,
|
| 140 |
):
|
| 141 |
-
self.get_scale()
|
| 142 |
self.set_scale(scale)
|
| 143 |
|
| 144 |
if isinstance(pil_image, List):
|
|
@@ -193,6 +184,16 @@ class IPAdapter:
|
|
| 193 |
class IPAdapterXL(IPAdapter):
|
| 194 |
"""SDXL"""
|
| 195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
def generate(
|
| 197 |
self,
|
| 198 |
pil_image_1,
|
|
@@ -213,6 +214,7 @@ class IPAdapterXL(IPAdapter):
|
|
| 213 |
guidance_scale=7.5,
|
| 214 |
**kwargs,
|
| 215 |
):
|
|
|
|
| 216 |
self.set_scale(scale_1)
|
| 217 |
|
| 218 |
if isinstance(pil_image_1, Image.Image):
|
|
|
|
| 117 |
if isinstance(attn_processor, IPAttnProcessor):
|
| 118 |
attn_processor.scale = scale
|
| 119 |
|
| 120 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def generate(
|
| 122 |
self,
|
| 123 |
pil_image,
|
|
|
|
| 130 |
num_inference_steps=30,
|
| 131 |
**kwargs,
|
| 132 |
):
|
|
|
|
| 133 |
self.set_scale(scale)
|
| 134 |
|
| 135 |
if isinstance(pil_image, List):
|
|
|
|
| 184 |
class IPAdapterXL(IPAdapter):
|
| 185 |
"""SDXL"""
|
| 186 |
|
| 187 |
+
def get_scale(self):
|
| 188 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
| 189 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
| 190 |
+
print('IP attn_scale:')
|
| 191 |
+
print(attn_processor.scale)
|
| 192 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
| 193 |
+
if isinstance(attn_processor):
|
| 194 |
+
print('UNET attn_scale:')
|
| 195 |
+
print(attn_processor.scale)
|
| 196 |
+
|
| 197 |
def generate(
|
| 198 |
self,
|
| 199 |
pil_image_1,
|
|
|
|
| 214 |
guidance_scale=7.5,
|
| 215 |
**kwargs,
|
| 216 |
):
|
| 217 |
+
self.get_scale()
|
| 218 |
self.set_scale(scale_1)
|
| 219 |
|
| 220 |
if isinstance(pil_image_1, Image.Image):
|