gokaygokay committed
Commit 6de338c · verified · 1 Parent(s): 2e5c176

Update app.py

Files changed (1)
  1. app.py +151 -156
app.py CHANGED
@@ -1,179 +1,174 @@
- import gradio as gr
- import numpy as np
- import random
  import spaces
  import torch
- from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderTiny, AutoencoderKL
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
- from huggingface_hub import hf_hub_download
- from optimum.quanto import freeze, qfloat8, quantize
- from live_preview_helpers import flux_pipe_call_that_returns_an_iterable_of_images
  import os

- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048
- # Set up environment variables and device
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
- dtype = torch.bfloat16
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load VAE models
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
- good_vae = AutoencoderKL.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     subfolder="vae",
-     torch_dtype=dtype,
-     token=huggingface_token
- ).to(device)

- # Initialize FluxPipeline instead of DiffusionPipeline
- from pipelines import FluxPipeline

  pipe = FluxPipeline.from_pretrained(
      "black-forest-labs/FLUX.1-dev",
-     torch_dtype=dtype,
-     vae=taef1,
      token=huggingface_token
- ).to(device)
-
- # Load and fuse LoRA BEFORE quantizing
- print('Loading and fusing LoRA, please wait...')
- lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
- pipe.load_lora_weights(lora_path)
- pipe.fuse_lora(lora_scale=0.125)
- pipe.unload_lora_weights()
-
- # Quantize the transformer
- print("Quantizing transformer")
- quantize(pipe.transformer, weights=qfloat8)
- freeze(pipe.transformer)
-
- # Quantize the T5 text encoder
- print("Quantizing T5 text encoder")
- quantize(pipe.text_encoder_2, weights=qfloat8)
- freeze(pipe.text_encoder_2)
-
- # Move quantized components to device (if not already)
- pipe.transformer.to(device)
- pipe.text_encoder_2.to(device)
-
- # Move other components to device
- pipe.text_encoder.to(device, dtype=dtype)
- torch.cuda.empty_cache()
-
- @spaces.GPU(duration=75)
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator().manual_seed(seed)

-     for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
-         prompt=prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-         output_type="pil",
-         good_vae=good_vae,
-     ):
-         yield img, seed
-
- examples = [
-     "wbgmsst, a cat, white background",
-     "wbgmsst, a warrior, white background",
-     "wbgmsst, an anime girl, white background",
- ]
-
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 520px;
  }
  """

- with gr.Blocks(css=css) as demo:

-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(f"""# FLUX.1 [dev]
- 12B param rectified flow transformer guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
- [[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-dev)]
- """)
-
-         with gr.Row():
-
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )

-             run_button = gr.Button("Run", scale=0)
-
-         result = gr.Image(label="Result", show_label=False)
-
-         with gr.Accordion("Advanced Settings", open=False):
-
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )

-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-             with gr.Row():
-
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,
-                 )
-
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,
-                 )
-
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance Scale",
-                     minimum=1,
-                     maximum=15,
-                     step=0.1,
-                     value=3.5,
-                 )
-
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=28,
-                 )

-         gr.Examples(
-             examples=examples,
-             fn=infer,
-             inputs=[prompt],
-             outputs=[result, seed],
-             cache_examples="lazy"
-         )
-
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-         outputs=[result, seed]
      )

- demo.launch()
  import spaces
+ import gradio as gr
  import torch
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
+ from diffusers import DiffusionPipeline
+ import random
+ import numpy as np
  import os
+ import subprocess
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

+ # Initialize models
  device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16

+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
+
+ import torch
+ from optimum.quanto import QuantizedDiffusersModel
+
+ from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+
+
+ class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+     base_class = FluxTransformer2DModel

+
+ transformer = QuantizedFluxTransformer2DModel.from_pretrained("Kijai/flux-fp8")
+ transformer.to(device="cuda", dtype=torch.bfloat16)

  pipe = FluxPipeline.from_pretrained(
      "black-forest-labs/FLUX.1-dev",
+     transformer=None,
+     torch_dtype=torch.bfloat16,
      token=huggingface_token
+ )
+
+ pipe.transformer = transformer
+
+ # Initialize Florence model
+ florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
+ florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
+
+ # Prompt Enhancer
+ enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 2048
+
+ # Florence caption function
+ @spaces.GPU
+ def florence_caption(image):
+     # Convert image to PIL if it's not already
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+     generated_ids = florence_model.generate(
+         input_ids=inputs["input_ids"],
+         pixel_values=inputs["pixel_values"],
+         max_new_tokens=1024,
+         early_stopping=False,
+         do_sample=False,
+         num_beams=3,
+     )
+     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+     parsed_answer = florence_processor.post_process_generation(
+         generated_text,
+         task="<MORE_DETAILED_CAPTION>",
+         image_size=(image.width, image.height)
+     )
+     return parsed_answer["<MORE_DETAILED_CAPTION>"]
+
+ # Prompt Enhancer function
+ def enhance_prompt(input_prompt):
+     result = enhancer_long("Enhance the description: " + input_prompt)
+     enhanced_text = result[0]['summary_text']
+     return enhanced_text
+
+ @spaces.GPU(duration=190)
+ def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
+     if image is not None:
+         # Convert image to PIL if it's not already
+         if not isinstance(image, Image.Image):
+             image = Image.fromarray(image)
+
+         prompt = florence_caption(image)
+         print(prompt)
+     else:
+         prompt = text_prompt
+
+     if use_enhancer:
+         prompt = enhance_prompt(prompt)
+
      if randomize_seed:
          seed = random.randint(0, MAX_SEED)

+     generator = torch.Generator(device=device).manual_seed(seed)
+
+     image = pipe(
+         prompt=prompt,
+         generator=generator,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale
+     ).images[0]
+
+     return image, prompt, seed
+
+ custom_css = """
+ .input-group, .output-group {
+     border: 1px solid #e0e0e0;
+     border-radius: 10px;
+     padding: 20px;
+     margin-bottom: 20px;
+     background-color: #f9f9f9;
  }
+ .submit-btn {
+     background-color: #2980b9 !important;
+     color: white !important;
+ }
+ .submit-btn:hover {
+     background-color: #3498db !important;
+ }
+ """
+
+ title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
+ <p><center>
+ <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
+ <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
+ <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
+ <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
+ </center></p>
  """

+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
+     gr.HTML(title)

+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Group(elem_classes="input-group"):
+                 input_image = gr.Image(label="Input Image (Florence-2 Captioner)")

+                 with gr.Accordion("Advanced Settings", open=False):
+                     text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
+                     use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+                     seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                     randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                     width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                     height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                     guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
+                     num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)

+                 generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")

+         with gr.Column(scale=1):
+             with gr.Group(elem_classes="output-group"):
+                 output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
+                 final_prompt = gr.Textbox(label="Final Prompt Used")
+                 used_seed = gr.Number(label="Seed Used")
+
+     generate_btn.click(
+         fn=process_workflow,
+         inputs=[
+             input_image, text_prompt, use_enhancer, seed, randomize_seed,
+             width, height, guidance_scale, num_inference_steps
+         ],
+         outputs=[output_image, final_prompt, used_seed]
      )

+ demo.launch(debug=True)