"""Gradio demo serving a 4-bit-quantized Stable Diffusion 3.5 pipeline.

On start-up the script logs in to the Hugging Face Hub using the
``HF_TOKEN`` environment variable, loads the quantized transformer and the
surrounding pipeline, adds locally stored LoRA tensors onto the matching
transformer parameters, and builds a text-to-image Gradio interface.
"""

import logging
import os

import gradio as gr
import torch
from diffusers import BitsAndBytesConfig
from diffusers import SD3Transformer2DModel
from diffusers import StableDiffusion3Pipeline
from huggingface_hub import login

logger = logging.getLogger(__name__)


def load_lora_model(pipe, lora_model_path):
    """Add the tensors stored in *lora_model_path* onto the pipeline's denoiser.

    The file is expected to hold a mapping from parameter name to tensor.
    Every denoiser parameter whose name appears in that mapping has the
    corresponding tensor added to it in place.

    Args:
        pipe: A diffusers pipeline exposing its denoiser as ``transformer``
            (Stable Diffusion 3) or ``unet`` (UNet-based pipelines).
        lora_model_path: Path to the ``torch.save``-d weight mapping.

    Returns:
        The same ``pipe`` object, modified in place.

    Raises:
        AttributeError: If the pipeline has neither a ``transformer`` nor a
            ``unet`` submodule.
    """
    # SECURITY: .pth files are pickles. weights_only=True restricts unpickling
    # to tensors and plain containers so a tampered file cannot run code.
    lora_weights = torch.load(
        lora_model_path, map_location="cpu", weights_only=True
    )

    # StableDiffusion3Pipeline has no ``unet`` attribute; its denoiser is the
    # ``transformer``. Fall back to ``unet`` so UNet pipelines keep working.
    denoiser = getattr(pipe, "transformer", None)
    if denoiser is None:
        denoiser = getattr(pipe, "unet", None)
    if denoiser is None:
        raise AttributeError(
            "Pipeline has neither a 'transformer' nor a 'unet' submodule "
            "to apply LoRA weights to."
        )

    applied = 0
    for name, param in denoiser.named_parameters():
        if name not in lora_weights:
            continue
        delta = lora_weights[name]
        # Quantized parameters are stored in a packed, non-floating layout,
        # so a float tensor cannot be added to them in place. Skip those
        # (and any shape mismatch) rather than fail with an opaque error.
        if not param.is_floating_point() or param.shape != delta.shape:
            logger.warning(
                "Skipping LoRA tensor %r: parameter is %s%s, tensor is %s%s",
                name,
                param.dtype,
                tuple(param.shape),
                delta.dtype,
                tuple(delta.shape),
            )
            continue
        param.data += delta.to(device=param.device, dtype=param.dtype)
        applied += 1

    if applied == 0:
        logger.warning(
            "No LoRA tensors from %s were applied to the pipeline.",
            lora_model_path,
        )
    return pipe


def generate_image(prompt):
    """Run the module-level pipeline on *prompt* and return the first image."""
    return pipe(prompt).images[0]


# Retrieve the Hugging Face token from the environment (a Space secret) and
# log in with it; without a token the script cannot continue.
token = os.getenv("HF_TOKEN")
if not token:
    raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
login(token=token)

# Quantization configuration: 4-bit NF4 weights with bfloat16 compute.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the Stable Diffusion 3.5 transformer with the quantization applied.
model_id = "stabilityai/stable-diffusion-3.5-large"
model = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Build the pipeline around the quantized transformer.
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    transformer=model,
    torch_dtype=torch.bfloat16,
)
# NOTE(review): bitsandbytes 4-bit kernels have historically targeted CUDA
# devices; confirm the installed versions support CPU execution.
pipe.to("cpu")

# Path to the locally stored LoRA weights.
lora_model_path = "./lora_model.pth"

# Load and apply the LoRA weights.
pipe = load_lora_model(pipe, lora_model_path)

# Gradio interface: text prompt in, generated image out.
iface = gr.Interface(fn=generate_image, inputs="text", outputs="image")

if __name__ == "__main__":
    iface.launch()