Upload anytext.py

anytext.py  (CHANGED: +28 −24)
@@ -334,13 +334,12 @@ def crop_image(src_img, mask):
     return result


-def create_predictor(
-    …
-)
+def create_predictor(model_lang="ch", device="cpu", use_fp16=False):
+    model_dir = hf_hub_download(
+        repo_id="tolgacangoz/anytext",
+        filename="text_embedding_module/OCR/ppv3_rec.pth",
+        cache_dir=HF_MODULES_CACHE,
+    )
     if not os.path.exists(model_dir):
         raise ValueError("not find model file path {}".format(model_dir))

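Note: `create_predictor` now resolves the OCR recognizer checkpoint itself through `huggingface_hub` instead of taking a pre-downloaded `rec_model_dir`. A minimal standalone sketch of that download step, reusing the repo id and filename from the hunk above (leaving `cache_dir` at its default here is an assumption; the module itself passes `HF_MODULES_CACHE`):

# Sketch: fetch the OCR recognizer weights the same way the new create_predictor does.
# hf_hub_download returns the local filesystem path of the cached file.
from huggingface_hub import hf_hub_download

ocr_checkpoint = hf_hub_download(
    repo_id="tolgacangoz/anytext",
    filename="text_embedding_module/OCR/ppv3_rec.pth",
    # cache_dir left at the default here; the module passes HF_MODULES_CACHE instead
)
print(ocr_checkpoint)  # local path under the Hugging Face cache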
@@ -540,16 +539,17 @@ class FrozenCLIPEmbedderT3(AbstractEncoder):

     def __init__(
         self,
-        version="openai/clip-vit-large-patch14",
         device="cpu",
         max_length=77,
         freeze=True,
         use_fp16=False,
+        variant="fp32",
     ):
         super().__init__()
-        self.tokenizer = CLIPTokenizer.from_pretrained(…
+        self.tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
         self.transformer = CLIPTextModel.from_pretrained(
-            …
+            "tolgacangoz/anytext", subfolder="text_encoder", use_safetensors=True,
+            torch_dtype=torch.float16 if use_fp16 else torch.float32, variant=variant,
         ).to(device)
         self.device = device
         self.max_length = max_length
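Note: the embedder no longer takes a `version` argument; the tokenizer and text encoder are pinned to subfolders of the `tolgacangoz/anytext` repo, with an optional fp16 safetensors variant. A minimal sketch of the same loading path outside the pipeline (the variant names and the presence of both variants in the repo are assumptions):

# Sketch: load the CLIP tokenizer/text encoder from repo subfolders, as the new
# FrozenCLIPEmbedderT3.__init__ does, then encode a prompt.
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("tolgacangoz/anytext", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "tolgacangoz/anytext",
    subfolder="text_encoder",
    use_safetensors=True,
    torch_dtype=torch.float32,  # or torch.float16 with variant="fp16" (variant name is an assumption)
)

tokens = tokenizer(
    "A storefront sign that says 'OPEN'",
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="pt",
)
prompt_embeds = text_encoder(**tokens).last_hidden_state  # shape: (1, 77, hidden_size)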
@@ -746,8 +746,7 @@ class TextEmbeddingModule(ModelMixin, ConfigMixin):
         self.device = device
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
-        …
-        self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
+        self.text_predictor = create_predictor(device=device, use_fp16=use_fp16).eval()
         args = {}
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
@@ -1045,7 +1044,8 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")


-class AuxiliaryLatentModule(…
+class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
+    @register_to_config
     def __init__(
         self,
         font_path,
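Note: `AuxiliaryLatentModule` now follows the same `ModelMixin`/`ConfigMixin` pattern as `TextEmbeddingModule`, so its constructor arguments are recorded by `@register_to_config` and it inherits `save_pretrained`/`from_pretrained` plus device and dtype handling. A toy sketch of that pattern (the class name and fields below are illustrative, not the real module):

# Sketch of the diffusers ModelMixin + ConfigMixin pattern adopted above.
# ToyAuxModule and its arguments are made up for illustration.
import torch
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config


class ToyAuxModule(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, font_path: str = "font.ttf", hidden_dim: int = 8):
        super().__init__()
        # @register_to_config captures the constructor kwargs into self.config.
        self.proj = torch.nn.Linear(hidden_dim, hidden_dim)


module = ToyAuxModule(font_path="fonts/Arial_Unicode.ttf", hidden_dim=16)
print(module.config.font_path)  # config values are serialized by save_pretrained()
module.to(torch.float16)        # ModelMixin supplies device/dtype management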
@@ -1229,7 +1229,7 @@ class AnyTextPipeline(
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
-        text_encoder ([`~…
+        text_encoder ([`~transformers.CLIPTextModel`]):
             Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
         tokenizer ([`~transformers.CLIPTokenizer`]):
             A `CLIPTokenizer` to tokenize text.
@@ -1259,7 +1259,7 @@ class AnyTextPipeline(
         self,
         font_path: str,
         vae: AutoencoderKL,
-        text_encoder: …
+        text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
@@ -1267,17 +1267,21 @@ class AnyTextPipeline(
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPImageProcessor,
         auxiliary_latent_module: AuxiliaryLatentModule,
+        text_embedding_module: TextEmbeddingModule,
         trust_remote_code: bool = False,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
-        …
+        self.text_embedding_module = TextEmbeddingModule(
+            font_path=font_path,
+            # use_fp16=unet.dtype == torch.float16, device=unet.device,
+        )
+        self.auxiliary_latent_module = AuxiliaryLatentModule(
+            font_path=font_path,
+            vae=vae,
+            # use_fp16=unet.dtype == torch.float16, device=unet.device,
+        )

         if safety_checker is None and requires_safety_checker:
             logger.warning(
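Note: the pipeline now accepts a `text_embedding_module` component and builds both custom modules from `font_path` (and the `vae`) inside `__init__`. One plausible loading sketch, under the assumption that this file is published as remote pipeline code in the same repo (the repo id, font path, and the need for `trust_remote_code` are assumptions, not confirmed usage):

# Hedged sketch: load anytext.py as a custom/remote-code pipeline.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/anytext",                 # assumed checkpoint repo
    font_path="fonts/Arial_Unicode.ttf",   # hypothetical local font file
    trust_remote_code=True,                # needed when the pipeline class ships with the repo
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")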
@@ -1308,7 +1312,7 @@ class AnyTextPipeline(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
-            …
+            text_embedding_module=text_embedding_module,
             auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -2177,7 +2181,7 @@ class AnyTextPipeline(
             self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
         )
         draw_pos = draw_pos.to(device=device) if isinstance(draw_pos, torch.Tensor) else draw_pos
-        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.…
+        prompt_embeds, negative_prompt_embeds, text_info, np_hint = self.text_embedding_module(
             prompt,
             texts,
             negative_prompt,
@@ -2419,6 +2423,6 @@ class AnyTextPipeline(

     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
-        …
+        self.text_embedding_module.to(*args, **kwargs)
         self.auxiliary_latent_module.to(*args, **kwargs)
         return self
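Note: the `to` override now forwards device and dtype moves to both custom modules in addition to the components handled by `super().to(...)`. A quick check, assuming `pipe` was built as in the loading sketch above:

# Sketch: a single to() call should carry both custom modules along.
pipe = pipe.to("cuda")
print(pipe.text_embedding_module.device)    # expected: cuda
print(pipe.auxiliary_latent_module.device)  # expected: cuda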