# Copyright Philip Brown, ppbrown@github
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

###########################################################################
# This pipeline attempts to use a model that has an SDXL VAE, a T5 text
# encoder, and an SDXL UNet.
# At the present time there are no pretrained models that give pleasing
# output, so as yet (2025/06/10) this pipeline is somewhat of a tech
# demo, proving that the pieces can at least be put together.
# Hopefully, it will encourage someone with the hardware available to
# throw enough resources into training one up.

from typing import Optional

import torch.nn as nn
from transformers import (
    CLIPImageProcessor,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    T5EncoderModel,
)

from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers

# Note: at this time, the intent is to use the T5 encoder mentioned
# below with zero changes. The model therefore deliberately does not
# store its own copy of the T5 encoder bytes (since they are not
# unique!) but instead takes advantage of Hugging Face hub cache loading.
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

# The caller is expected to load this, or an equivalent, as the model
# name for now, e.g.:
#   pipe = StableDiffusionXL_T5Pipeline.from_pretrained(SDXL_NAME)
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"


class LinearWithDtype(nn.Linear):
    # nn.Linear has no .dtype attribute of its own, but the diffusers
    # component machinery expects every registered module to have one.
    @property
    def dtype(self):
        return self.weight.dtype


class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
    _expected_modules = [
        "vae",
        "unet",
        "scheduler",
        "tokenizer",
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    _optional_components = [
        "image_encoder",
        "feature_extractor",
        "t5_encoder",
        "t5_projection",
        "t5_pooled_projection",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        tokenizer: CLIPTokenizer,
        t5_encoder=None,
        t5_projection=None,
        t5_pooled_projection=None,
        image_encoder: CLIPVisionModelWithProjection = None,
        feature_extractor: CLIPImageProcessor = None,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None,
    ):
        DiffusionPipeline.__init__(self)

        if t5_encoder is None:
            self.t5_encoder = T5EncoderModel.from_pretrained(
                T5_NAME, torch_dtype=unet.dtype
            )
        else:
            self.t5_encoder = t5_encoder

        # ----- build T5 4096 => 2048 dim projection -----
        if t5_projection is None:
            self.t5_projection = LinearWithDtype(4096, 2048)  # trainable
        else:
            self.t5_projection = t5_projection
        self.t5_projection.to(dtype=unet.dtype)

        # ----- build T5 4096 => 1280 dim projection -----
        if t5_pooled_projection is None:
            self.t5_pooled_projection = LinearWithDtype(4096, 1280)  # trainable
        else:
            self.t5_pooled_projection = t5_pooled_projection
        self.t5_pooled_projection.to(dtype=unet.dtype)

        print("dtype of Linear is ", self.t5_projection.dtype)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            t5_encoder=self.t5_encoder,
            t5_projection=self.t5_projection,
            t5_pooled_projection=self.t5_pooled_projection,
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
        )
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if getattr(self, "vae", None)
            else 8
        )
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet")
            and self.unet is not None
            and hasattr(self.unet.config, "sample_size")
            else 128
        )

        self.watermark = None

        # Parts of the original SDXL class complain if these attributes
        # are not at least PRESENT.
        self.text_encoder = self.text_encoder_2 = None

    # ------------------------------------------------------------------
    # Encode a text prompt (T5-XXL + 4096 -> 2048 projection)
    # Returns exactly four tensors in the order SDXL's __call__ expects.
    # ------------------------------------------------------------------
    def encode_prompt(
        self,
        prompt,
        num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        **_,
    ):
        """
        Returns
        -------
        prompt_embeds                : Tensor [B, T, 2048]
        negative_prompt_embeds       : Tensor [B, T, 2048] | None
        pooled_prompt_embeds         : Tensor [B, 1280]
        negative_pooled_prompt_embeds: Tensor [B, 1280] | None

        where B = batch * num_images_per_prompt
        """

        # --- helper to tokenize on the pipeline's device ----------------
        def _tok(text: str):
            tok_out = self.tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ).to(self.device)
            return tok_out.input_ids, tok_out.attention_mask

        # ---------- positive stream -------------------------------------
        ids, mask = _tok(prompt)
        h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state  # [b, T, 4096]
        tok_pos = self.t5_projection(h_pos)                      # [b, T, 2048]
        pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1))  # [b, 1280]

        # expand for multiple images per prompt
        tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
        pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)

        # ---------- negative / CFG stream --------------------------------
        if do_classifier_free_guidance:
            neg_text = "" if negative_prompt is None else negative_prompt
            ids_n, mask_n = _tok(neg_text)
            h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
            tok_neg = self.t5_projection(h_neg)
            pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
            tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
            pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
        else:
            tok_neg = pool_neg = None

        # ----------------- final ordered return --------------------------
        # 1) positive token embeddings
        # 2) negative token embeddings   (or None)
        # 3) positive pooled embeddings
        # 4) negative pooled embeddings  (or None)
        return tok_pos, tok_neg, pool_pos, pool_neg
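

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline).
# Assumptions beyond the code above: the T5 repo ships tokenizer files
# loadable via AutoTokenizer (the __init__ type hint says CLIPTokenizer,
# but the ids produced by self.tokenizer feed the T5 encoder, so a
# T5-compatible tokenizer is assumed here); EulerDiscreteScheduler is an
# arbitrary choice; and a CUDA device with room for the fp16 UNet plus
# the T5-XXL encoder is available.
if __name__ == "__main__":
    import torch
    from diffusers import EulerDiscreteScheduler
    from transformers import AutoTokenizer

    dtype = torch.float16

    # Load the SDXL components individually; __init__ then pulls the
    # shared T5 encoder from the hub cache and builds fresh (untrained)
    # projection layers.
    pipe = StableDiffusionXL_T5Pipeline(
        vae=AutoencoderKL.from_pretrained(
            SDXL_NAME, subfolder="vae", torch_dtype=dtype
        ),
        unet=UNet2DConditionModel.from_pretrained(
            SDXL_NAME, subfolder="unet", torch_dtype=dtype
        ),
        scheduler=EulerDiscreteScheduler.from_pretrained(
            SDXL_NAME, subfolder="scheduler"
        ),
        tokenizer=AutoTokenizer.from_pretrained(T5_NAME),  # assumed T5 tokenizer
    ).to("cuda")

    # With untrained projection layers the image will not be coherent;
    # this only demonstrates that the plumbing runs end to end.
    image = pipe("a watercolor fox in a forest", num_inference_steps=30).images[0]
    image.save("t5sdxl_demo.png")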