ford442 committed
Commit 0e803ed · verified · 1 Parent(s): 4c3e085

Update ip_adapter/ip_adapter.py

Files changed (1): ip_adapter/ip_adapter.py (+6 -6)
ip_adapter/ip_adapter.py CHANGED

@@ -45,7 +45,7 @@ class IPAdapter:
         self.set_ip_adapter()
 
         # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16)
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.bfloat16)
         self.clip_image_processor = CLIPImageProcessor()
         # image proj model
         self.image_proj_model = self.init_proj()
@@ -56,7 +56,7 @@ class IPAdapter:
             cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
             clip_embeddings_dim=self.image_encoder.config.projection_dim,
             clip_extra_context_tokens=self.num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.bfloat16)
         return image_proj_model
 
     def set_ip_adapter(self):
@@ -76,7 +76,7 @@ class IPAdapter:
                 attn_procs[name] = AttnProcessor()
             else:
                 attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,
-                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.float16)
+                                                   scale=1.0, num_tokens=self.num_tokens).to(self.device, dtype=torch.bfloat16)
         unet.set_attn_processor(attn_procs)
         if hasattr(self.pipe, "controlnet"):
             if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -113,7 +113,7 @@ class IPAdapter:
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
+        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.bfloat16)).image_embeds
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds
@@ -257,7 +257,7 @@ class IPAdapterPlus(IPAdapter):
             embedding_dim=self.image_encoder.config.hidden_size,
             output_dim=self.pipe.unet.config.cross_attention_dim,
             ff_mult=4
-        ).to(self.device, dtype=torch.float16)
+        ).to(self.device, dtype=torch.bfloat16)
         return image_proj_model
 
     @torch.inference_mode()
@@ -265,7 +265,7 @@ class IPAdapterPlus(IPAdapter):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
         clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self.device, dtype=torch.bfloat16)
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
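
The net effect of the commit is that every explicit dtype cast in IPAdapter and IPAdapterPlus (the CLIP image encoder, the image projection model, the IP attention processors, and the CLIP pixel values) now uses torch.bfloat16 instead of torch.float16. The adapter does not cast the wrapped diffusion pipeline itself, so the pipeline passed in should also be loaded in bfloat16 to avoid dtype mismatches. Below is a minimal usage sketch; the constructor arguments follow the upstream IP-Adapter project's API, and the base-model id and checkpoint paths are placeholders, not part of this commit.

import torch
from diffusers import StableDiffusionPipeline
from ip_adapter.ip_adapter import IPAdapter

device = "cuda"

# Load the base pipeline in bfloat16 so the UNet matches the dtype the
# patched adapter now uses for its image encoder, projection model, and
# IP attention processors.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder base model
    torch_dtype=torch.bfloat16,
)

# Constructor arguments follow the upstream IP-Adapter API; both paths
# below are placeholders for this sketch.
ip_model = IPAdapter(
    pipe,
    image_encoder_path="models/image_encoder",
    ip_ckpt="models/ip-adapter_sd15.bin",
    device=device,
)

As a design note, bfloat16 keeps the same 16-bit memory footprint as float16 but has the wider exponent range of float32, trading mantissa precision for resistance to overflow and underflow, which is the usual motivation for a cast like this on hardware with native bfloat16 support.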