Anurag181011 committed
Commit c34a1b5 · verified · 1 Parent(s): 3c69e09

Update app.py

Files changed (1)
  1. app.py +106 -62
app.py CHANGED
@@ -1,76 +1,120 @@
-import os
-import torch
-from diffusers import DiffusionPipeline
-import gradio as gr
-
-# Load the base model and apply the LoRA weights for super realism
-def load_pipeline():
-    base_model = "black-forest-labs/FLUX.1-dev"
-    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
-    trigger_word = "Super Realism"  # Recommended trigger word
-
-
-
-    pipe = DiffusionPipeline.from_pretrained(
-        base_model,
-        torch_dtype=torch.bfloat16,
-
-    )
-
-    # Load the LoRA weights into the pipeline
-    pipe.load_lora_weights(lora_repo)
-
-    # Use GPU if available
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    pipe.to(device)
-    return pipe
-
-# Instantiate the pipeline once on Space startup
-pipe = load_pipeline()
-
-# Define a function for image generation
-def generate_image(prompt, seed, width, height, guidance_scale, randomize_seed):
-    # If randomize_seed is selected, allow the model to generate a random seed
-    if randomize_seed:
-        seed = None
-
-    # Ensure the prompt includes realism trigger words if needed
-    if "realistic" not in prompt.lower() and "realism" not in prompt.lower():
-        prompt += " realistic, realism"
-
-    # Generate the image
-    output = pipe(
-        prompt=prompt,
-        seed=seed,
-        width=width,
-        height=height,
-        guidance_scale=guidance_scale
-    )
-    return output.images[0]
-
-# Set up Gradio interface
-iface = gr.Interface(
-    fn=generate_image,
-    inputs=[
-        gr.Textbox(
-            lines=2,
-            label="Prompt",
-            placeholder="Enter your prompt, e.g., 'A tiny astronaut hatching from an egg on the moon, 4k, planet theme'"
-        ),
-        gr.Slider(0, 10000, step=1, value=0, label="Seed (0 for random)"),
-        gr.Slider(256, 1024, step=64, value=1024, label="Width"),
-        gr.Slider(256, 1024, step=64, value=1024, label="Height"),
-        gr.Slider(1, 20, step=0.5, value=6, label="Guidance Scale"),
-        gr.Checkbox(value=True, label="Randomize Seed")
-    ],
-    outputs=gr.Image(type="pil", label="Generated Image"),
-    title="Flux Super Realism LoRA Demo",
-    description=(
-        "This demo uses the Flux Super Realism LoRA model for ultra-realistic image generation. "
-        "You can use the trigger word 'Super Realism' (recommended) along with other realism-related words "
-        "to guide the generation process."
-    ),
-)
-
-if __name__ == "__main__":
-    iface.launch(share=False)
+import gradio as gr
+import torch
+import numpy as np
+from diffusers import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+from functools import lru_cache
+from PIL import Image
+from torchvision import transforms
+
+
+# Cache pipeline loading to improve performance
+@lru_cache(maxsize=1)
+def load_pipeline():
+    # Load base model
+    base_model = "black-forest-labs/FLUX.1-dev"
+    pipe = DiffusionPipeline.from_pretrained(
+        base_model,
+        torch_dtype=torch.bfloat16
+    )
+
+    # Load LoRA weights
+    lora_repo = "strangerzonehf/Flux-Super-Realism-LoRA"
+    pipe.load_lora_weights(lora_repo)
+
+    # Load safety checker
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+        "CompVis/stable-diffusion-safety-checker"
+    )
+    feature_extractor = CLIPFeatureExtractor.from_pretrained(
+        "openai/clip-vit-base-patch32"
+    )
+
+    # Optimizations
+    pipe.enable_xformers_memory_efficient_attention()
+    pipe = pipe.to("cuda")
+
+    return pipe, safety_checker, feature_extractor
+
+pipe, safety_checker, feature_extractor = load_pipeline()
+
+def generate_image(
+    prompt,
+    seed=42,
+    width=1024,
+    height=1024,
+    guidance_scale=6,
+    steps=28,
+    progress=gr.Progress()
+):
+    try:
+        progress(0, desc="Initializing...")
+        generator = torch.Generator(device="cuda").manual_seed(seed)
+
+        # Auto-add trigger words
+        if "super realism" not in prompt.lower():
+            prompt = f"Super Realism, {prompt}"
+
+        # Create callback for progress updates
+        def update_progress(step, _, __):
+            progress((step + 1) / steps, desc="Generating image...")
+
+        # Generate image
+        with torch.inference_mode():
+            image = pipe(
+                prompt=prompt,
+                width=width,
+                height=height,
+                guidance_scale=guidance_scale,
+                num_inference_steps=steps,
+                generator=generator,
+                callback=update_progress
+            ).images[0]
+
+        # Safety check
+        progress(1, desc="Safety checking...")
+        safety_input = feature_extractor(image, return_tensors="pt")
+        np_image = np.array(image)
+        safety_result = safety_checker(
+            images=[np_image],
+            clip_input=safety_input.pixel_values
+        )
+
+        if safety_result.nsfw[0]:
+            return Image.new("RGB", (512, 512)), "NSFW content detected"
+
+        return image, "Generation successful"
+
+    except Exception as e:
+        return Image.new("RGB", (512, 512)), f"Error: {str(e)}"
+
+# Create Gradio interface with rate limiting
+with gr.Blocks() as app:
+    gr.Markdown("# Flux Super Realism Generator")
+
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt", value="A portrait of a person")
+            seed = gr.Slider(0, 1000, value=42, label="Seed")
+            width = gr.Slider(512, 2048, value=1024, label="Width")
+            height = gr.Slider(512, 2048, value=1024, label="Height")
+            guidance = gr.Slider(1, 20, value=6, label="Guidance Scale")
+            steps = gr.Slider(10, 100, value=28, label="Steps")
+            submit = gr.Button("Generate")
+
+        with gr.Column():
+            output_image = gr.Image(label="Result", type="pil")
+            status = gr.Textbox(label="Status")
+
+    submit.click(
+        generate_image,
+        inputs=[prompt, seed, width, height, guidance, steps],
+        outputs=[output_image, status]
+    )
+
+# Rate limiting example (1 request every 30 seconds)
+app.queue(concurrency_count=1, max_size=3).launch()
+
+# For multiple GPU support (advanced)
+# pipe.enable_model_cpu_offload()
+# pipe.enable_sequential_cpu_offload()
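
Review note: in the safety-check path, CLIPFeatureExtractor is referenced in load_pipeline() but never imported, and StableDiffusionSafetyChecker returns a plain (images, has_nsfw_concepts) tuple rather than an object with an .nsfw attribute, so safety_result.nsfw[0] would raise at runtime. A minimal sketch of how that block could look, assuming a transformers release that still ships CLIPFeatureExtractor (newer releases expose the same thing as CLIPImageProcessor):

from transformers import CLIPFeatureExtractor  # missing from the committed imports

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

# Inside generate_image(): the checker returns (images, has_nsfw_concepts)
safety_input = feature_extractor(image, return_tensors="pt")
np_image = np.array(image)
checked_images, has_nsfw = safety_checker(
    images=[np_image],
    clip_input=safety_input.pixel_values
)
if has_nsfw[0]:
    return Image.new("RGB", (512, 512)), "NSFW content detected"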
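
The per-step progress hook is also version-sensitive: recent diffusers releases of the FLUX pipeline expose callback_on_step_end (a function that receives the pipeline, step index, timestep, and a kwargs dict and must return that dict) rather than the older callback argument used here, which would be rejected as an unexpected keyword. A sketch of the equivalent wiring, assuming a diffusers version with the newer API:

# Inside generate_image(), replacing the update_progress callback
def update_progress(pipeline, step, timestep, callback_kwargs):
    # Report fractional progress to the Gradio progress bar
    progress((step + 1) / steps, desc="Generating image...")
    return callback_kwargs

image = pipe(
    prompt=prompt,
    width=width,
    height=height,
    guidance_scale=guidance_scale,
    num_inference_steps=steps,
    generator=generator,
    callback_on_step_end=update_progress
).images[0]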
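
Finally, queue(concurrency_count=...) is a Gradio 3.x signature; Gradio 4 replaced that parameter with default_concurrency_limit, and queueing bounds concurrency and queue depth rather than enforcing a time-based limit such as "1 request every 30 seconds". A sketch of the equivalent call, assuming the Space runs Gradio 4:

# Gradio 4.x: one concurrent worker, at most 3 queued requests
app.queue(max_size=3, default_concurrency_limit=1).launch()

Separately, the trailing enable_model_cpu_offload() / enable_sequential_cpu_offload() comments reduce GPU memory use by offloading weights to the CPU; they do not provide multi-GPU support.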