Text-to-Image · Diffusers · Safetensors

tolgacangoz committed (verified) · Commit fa0e834 · 1 Parent(s): 028489c

Upload anytext.py

Files changed (1):
  1. anytext.py +28 -24
anytext.py CHANGED
@@ -334,13 +334,12 @@ def crop_image(src_img, mask):
    return result


-def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=False):
-    if model_dir is None or not os.path.exists(model_dir):
-        model_dir = hf_hub_download(
-            repo_id="tolgacangoz/anytext",
-            filename="text_embedding_module/OCR/ppv3_rec.pth",
-            cache_dir=HF_MODULES_CACHE,
-        )
+def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
+    model_dir = hf_hub_download(
+        repo_id="tolgacangoz/anytext",
+        filename="text_embedding_module/OCR/ppv3_rec.pth",
+        cache_dir=HF_MODULES_CACHE,
+    )
    if not os.path.exists(model_dir):
        raise ValueError("not find model file path {}".format(model_dir))

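For context, this hunk swaps a hard-coded local checkpoint path for a Hub download. A minimal standalone sketch of the same lookup (the `ocr_weights` name is illustrative; it assumes network access or a warm cache and uses the default cache directory instead of `HF_MODULES_CACHE`):

    from huggingface_hub import hf_hub_download

    # Downloads the OCR recognizer checkpoint on first use, then reuses the
    # cached copy; the returned value is the local file path that the
    # os.path.exists() check above operates on.
    ocr_weights = hf_hub_download(
        repo_id="tolgacangoz/anytext",
        filename="text_embedding_module/OCR/ppv3_rec.pth",
    )
    print(ocr_weights)
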
 
@@ -540,16 +539,17 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):

    def __init__(
        self,
-        version="openai/clip-vit-large-patch14",
        device="cpu",
        max_length=77,
        freeze=True,
        use_fp16=False,
+        variant="fp32",
    ):
        super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
        self.transformer = CLIPTextModel.from_pretrained(
-            version, use_safetensors=True, torch_dtype=torch.float16 if use_fp16 else torch.float32
+            "tolgacangoz/anytext", subfolder="text_encoder", use_safetensors=True,
+            torch_dtype=torch.float16 if use_fp16 else torch.float32, variant=variant,
        ).to(device)
        self.device = device
        self.max_length = max_length
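The embedder now loads its tokenizer and text encoder from subfolders of the pipeline repo instead of `openai/clip-vit-large-patch14`. A hedged sketch of that loading path outside the class (it assumes the `tokenizer/` and `text_encoder/` subfolders exist in `tolgacangoz/anytext`; the new `variant` argument above selects an alternative safetensors filename, e.g. an fp16 variant, when one is published):

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "tolgacangoz/anytext",
        subfolder="text_encoder",
        use_safetensors=True,
        torch_dtype=torch.float32,  # torch.float16 when use_fp16=True in the class above
    )
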
@@ -746,8 +746,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
        self.device = device
        self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
        self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
-        rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
-        self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
+        self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
        args = {}
        args["rec_image_shape"] = "3, 48, 320"
        args["rec_batch_num"] = 6
@@ -1045,7 +1044,8 @@ def retrieve_latents(
        raise AttributeError("Could not access latents of provided encoder_output")


-class AuxiliaryLatentModule(nn.Module):
+class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
+    @register_to_config
    def __init__(
        self,
        font_path,
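AuxiliaryLatentModule moves from a plain `nn.Module` to the diffusers `ModelMixin`/`ConfigMixin` pattern. A minimal sketch of what that pattern provides, with made-up names (`ToyAuxModule`, `hidden_dim`) that are not part of anytext.py:

    import torch
    from diffusers import ConfigMixin, ModelMixin
    from diffusers.configuration_utils import register_to_config


    class ToyAuxModule(ModelMixin, ConfigMixin):
        @register_to_config
        def __init__(self, hidden_dim: int = 8):
            super().__init__()
            self.proj = torch.nn.Linear(hidden_dim, hidden_dim)

        def forward(self, x):
            return self.proj(x)


    module = ToyAuxModule(hidden_dim=16)
    print(module.config.hidden_dim)     # constructor args recorded by @register_to_config
    print(module.dtype, module.device)  # dtype/device tracking provided by ModelMixin

Registering the config is what lets such a module round-trip through `save_pretrained()`/`from_pretrained()` like the other pipeline components.
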
@@ -1229,7 +1229,7 @@ class AnyTextPipeline(
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-        text_encoder ([`~anytext.TextEmbeddingModule`]):
+        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
@@ -1259,7 +1259,7 @@ class AnyTextPipeline(
        self,
        font_path: str,
        vae: AutoencoderKL,
-        text_encoder: TextEmbeddingModule,
+        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
@@ -1267,17 +1267,21 @@ class AnyTextPipeline(
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        auxiliary_latent_module: AuxiliaryLatentModule,
+        text_embedding_module: TextEmbeddingModule,
        trust_remote_code: bool = False,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
-        # self.text_embedding_module = TextEmbeddingModule(
-        #     use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        # )
-        # self.auxiliary_latent_module = AuxiliaryLatentModule(
-        #     vae=vae, use_fp16=unet.dtype == torch.float16, device=unet.device, font_path=font_path
-        # )
+        self.text_embedding_module = TextEmbeddingModule(
+            font_path=font_path,
+            # use_fp16=unet.dtype == torch.float16, device=unet.device,
+        )
+        self.auxiliary_latent_module = AuxiliaryLatentModule(
+            font_path=font_path,
+            vae=vae,
+            # use_fp16=unet.dtype == torch.float16, device=unet.device,
+        )

        if safety_checker is None and requires_safety_checker:
            logger.warning(
@@ -1308,7 +1312,7 @@ class AnyTextPipeline(
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
-            # text_embedding_module=self.text_embedding_module,
+            text_embedding_module=text_embedding_module,
            auxiliary_latent_module=auxiliary_latent_module,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
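`text_embedding_module` is now passed through `register_modules()` rather than being commented out. A toy sketch of what registration does for a sub-module (`ToyPipeline` and `ToyTextEmbedding` are illustrative stand-ins, not the real classes):

    import torch
    from torch import nn
    from diffusers import DiffusionPipeline


    class ToyTextEmbedding(nn.Module):
        def forward(self, token_ids):
            return torch.zeros(len(token_ids), 4)


    class ToyPipeline(DiffusionPipeline):
        def __init__(self, text_embedding_module):
            super().__init__()
            # register_modules() exposes the module as self.text_embedding_module
            # and records it in the pipeline config, which is how registered
            # components participate in save/load and appear in pipe.components.
            self.register_modules(text_embedding_module=text_embedding_module)


    pipe = ToyPipeline(ToyTextEmbedding())
    print(type(pipe.text_embedding_module).__name__)  # ToyTextEmbedding
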
@@ -2177,7 +2181,7 @@ class AnyTextPipeline(
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
-        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_encoder(
+        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
            prompt,
            texts,
            negative_prompt,
@@ -2419,6 +2423,6 @@ class AnyTextPipeline(

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
-        # self.text_embedding_module.to(*args, **kwargs)
+        self.text_embedding_module.to(*args, **kwargs)
        self.auxiliary_latent_module.to(*args, **kwargs)
        return self
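End-to-end, the commit makes the text-embedding path a registered component that is loaded from the Hub rather than from local files. A hedged usage sketch, assuming a diffusers version that can load Hub-hosted pipeline code via `trust_remote_code` and using a placeholder font path (AnyText requires a font file; the exact call signature may differ from the repo's README):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "tolgacangoz/anytext",
        font_path="path/to/a/font.ttf",  # placeholder; required by AnyTextPipeline
        torch_dtype=torch.float16,
        trust_remote_code=True,  # the pipeline class ships as custom code (anytext.py)
    )
    pipe.to("cuda")  # after this commit, .to() also moves text_embedding_module
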
 