File size: 3,569 Bytes
3b3ffae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Initialize model and pipeline once at startup
base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"

# Build an (uninitialized) UNet with the base model's architecture, then load
# the SDXL-Lightning weights into it. Passing a repo id directly to
# `from_config` is deprecated in diffusers; `load_config` + `from_config` is
# the supported way to do the same thing. float32 is used for CPU compatibility.
unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config).to(torch.float32)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cpu"))

# Create the pipeline on CPU, substituting the Lightning UNet for the base one.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    unet=unet,
    torch_dtype=torch.float32,
).to("cpu")

# Keep the pipeline's existing scheduler settings, but switch to an Euler
# scheduler with "trailing" timestep spacing.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Choices offered in the "Image Elements" checkbox group, in display order.
elements_list = [
    "Kittens",
    "Tea",
    "Home",
    "Snow",
    "Young Girl",
    "Stars",
    "Blanket",
    "Books",
    "Candles",
    "Flowers",
    "Moon",
    "Cookies",
    "Fireplace",
    "Pillows",
    "Mittens",
    "Lanterns",
    "Socks",
    "Hot Chocolate",
    "Snowflakes",
    "Winter Scarf",
    "Marshmallows",
    "Vintage Clock",
    "Knitted Sweater",
    "Fairy Lights",
    "Porcelain Cup",
]

def generate_image(custom_text, elements, steps):
    """Generate one image from the user's message and selected elements.

    Args:
        custom_text: Free-form prompt text from the "Your Message" box. May be
            empty, whitespace-only, or None (e.g. when called via the API).
        elements: Iterable of element names chosen in the checkbox group;
            may be empty or None.
        steps: Number of denoising steps. Converted with ``int()`` and clamped
            to at least 1.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    message = (custom_text or "").strip()
    element_text = ", ".join(elements or [])

    if message and element_text:
        # The default message ends with "...to include:"; joining with ", "
        # would yield "include:, Kittens". After a colon, a space reads right.
        separator = " " if message.endswith(":") else ", "
        prompt = f"{message}{separator}{element_text}"
    else:
        # Use whichever part is present, or a generic prompt if neither is.
        prompt = message or element_text or "a beautiful image"

    image = pipe(
        prompt,
        num_inference_steps=max(1, int(steps)),
        # guidance_scale=0: diffusers only applies classifier-free guidance
        # when guidance_scale > 1, so this runs without CFG.
        guidance_scale=0,
        width=768,
        height=960,
    ).images[0]

    return image

# Create Gradio interface
with gr.Blocks(title="Good Night Image Diffuser") as demo:
    gr.Markdown("# 🌙 Generate Good Night Wish Images")
    gr.Markdown("Create personalized good night images with your message and favorite elements!")

    with gr.Row():
        # Left column: prompt inputs and the generate button.
        with gr.Column(scale=1):
            custom_text = gr.Textbox(
                label="Your Message",
                value="Create a cozy and heartwarming scene. Use a warm, pastel color palette with soft shadows and subtle textures to evoke comfort and nostalgia. Additional elements to include:",
                max_lines=3
            )
            elements = gr.CheckboxGroup(
                label="Image Elements",
                choices=elements_list,
                value=["Kittens", "Moon"],
                info="Select elements to include in your image"
            )
            # step=1 keeps every integer in [1, 8] reachable, including the
            # default of 4 (the step count of the "4step" checkpoint loaded at
            # startup). A step of 2 starting from minimum=1 would only allow
            # 1, 3, 5, 7.
            steps_slider = gr.Slider(
                label="Number of Inference Steps",
                minimum=1,
                maximum=8,
                value=4,
                step=1,
                info="Adjust the number of denoising steps (more steps can improve quality but take longer)"
            )
            generate_btn = gr.Button("✨ Generate Image", variant="primary")

        # Right column: the generated image, displayed at the generation size.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Image",
                width=768,
                height=960,
                elem_id="output-image"
            )

    # Connect components; also exposes the handler as the "generate" API endpoint.
    generate_btn.click(
        fn=generate_image,
        inputs=[custom_text, elements, steps_slider],
        outputs=output_image,
        api_name="generate"
    )

def _main():
    """Start the Gradio app built above when this file is run as a script."""
    demo.launch()


if __name__ == "__main__":
    _main()