gokaygokay committed on
Commit f469d2f · verified · 1 Parent(s): 62ea6c3

Update app.py

Files changed (1)
  1. app.py +241 -141
app.py CHANGED
@@ -1,168 +1,268 @@
  import spaces
- import gradio as gr
  import torch
  from PIL import Image
- from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
- from diffusers import DiffusionPipeline
- import random
- import numpy as np
- import os
- import subprocess
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

- # Initialize models
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.bfloat16

  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

- from diffusers import FluxPipeline, FluxTransformer2DModel

- transformer = FluxTransformer2DModel.from_single_file(
-     "https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors",
      torch_dtype=torch.float16,
  )
-
- pipe = FluxPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     transformer=None,
-     torch_dtype=torch.bfloat16,
-     token=huggingface_token
  )

- pipe.transformer = transformer

- # Initialize Florence model
- florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
- florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)

- # Prompt Enhancer
- enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)

- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048

- # Florence caption function
- @spaces.GPU
- def florence_caption(image):
-     # Convert image to PIL if it's not already
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image)
-
-     inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
-     generated_ids = florence_model.generate(
-         input_ids=inputs["input_ids"],
-         pixel_values=inputs["pixel_values"],
-         max_new_tokens=1024,
-         early_stopping=False,
-         do_sample=False,
-         num_beams=3,
-     )
-     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-     parsed_answer = florence_processor.post_process_generation(
-         generated_text,
-         task="<MORE_DETAILED_CAPTION>",
-         image_size=(image.width, image.height)
-     )
-     return parsed_answer["<MORE_DETAILED_CAPTION>"]
-
- # Prompt Enhancer function
- def enhance_prompt(input_prompt):
-     result = enhancer_long("Enhance the description: " + input_prompt)
-     enhanced_text = result[0]['summary_text']
-     return enhanced_text
-
- @spaces.GPU(duration=190)
- def process_workflow(image, text_prompt, use_enhancer, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
-     if image is not None:
-         # Convert image to PIL if it's not already
-         if not isinstance(image, Image.Image):
-             image = Image.fromarray(image)
-
-         prompt = florence_caption(image)
-         print(prompt)
      else:
-         prompt = text_prompt
-
-     if use_enhancer:
-         prompt = enhance_prompt(prompt)
-
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-
-     generator = torch.Generator(device=device).manual_seed(seed)
-
-     image = pipe(
-         prompt=prompt,
-         generator=generator,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         guidance_scale=guidance_scale
      ).images[0]

-     return image, prompt, seed
-
- custom_css = """
- .input-group, .output-group {
-     border: 1px solid #e0e0e0;
-     border-radius: 10px;
-     padding: 20px;
-     margin-bottom: 20px;
-     background-color: #f9f9f9;
- }
- .submit-btn {
-     background-color: #2980b9 !important;
-     color: white !important;
- }
- .submit-btn:hover {
-     background-color: #3498db !important;
- }
- """
-
- title = """<h1 align="center">FLUX.1-dev with Florence-2 Captioner and Prompt Enhancer</h1>
- <p><center>
- <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev" target="_blank">[FLUX.1-dev Model]</a>
- <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
- <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
- <p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
- </center></p>
- """
-
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
-     gr.HTML(title)
-
      with gr.Row():
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="input-group"):
-                 input_image = gr.Image(label="Input Image (Florence-2 Captioner)")

              with gr.Accordion("Advanced Settings", open=False):
-                 text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
-                 use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
-                 seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
-                 randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                 width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
-                 height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
-                 guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=15, step=0.1, value=3.5)
-                 num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)

-             generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
-
-         with gr.Column(scale=1):
-             with gr.Group(elem_classes="output-group"):
-                 output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
-                 final_prompt = gr.Textbox(label="Final Prompt Used")
-                 used_seed = gr.Number(label="Seed Used")
-
      generate_btn.click(
-         fn=process_workflow,
-         inputs=[
-             input_image, text_prompt, use_enhancer, seed, randomize_seed,
-             width, height, guidance_scale, num_inference_steps
-         ],
-         outputs=[output_image, final_prompt, used_seed]
      )

- demo.launch(debug=True)

  import spaces
+ import argparse
+ import os
+ import time
+ from os import path
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+ import imageio
+ import numpy as np
  import torch
+ import rembg
  from PIL import Image
+ from torchvision.transforms import v2
+ from pytorch_lightning import seed_everything
+ from omegaconf import OmegaConf
+ from einops import rearrange, repeat
+ from tqdm import tqdm
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
+ import gradio as gr
+ import shutil
+ import tempfile
+ from functools import partial
+ from optimum.quanto import quantize, qfloat8, freeze
+ from diffusers import FluxPipeline

+ from src.utils.train_util import instantiate_from_config
+ from src.utils.camera_util import (
+     FOV_to_intrinsics,
+     get_zero123plus_input_cameras,
+     get_circular_camera_poses,
+ )
+ from src.utils.mesh_util import save_obj, save_glb
+ from src.utils.infer_util import remove_background, resize_foreground, images_to_video
+
+ # Set up cache path
+ cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
+ os.environ["TRANSFORMERS_CACHE"] = cache_path
+ os.environ["HF_HUB_CACHE"] = cache_path
+ os.environ["HF_HOME"] = cache_path

  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

+ if not path.exists(cache_path):
+     os.makedirs(cache_path, exist_ok=True)
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ class timer:
+     def __init__(self, method_name="timed process"):
+         self.method = method_name
+     def __enter__(self):
+         self.start = time.time()
+         print(f"{self.method} starts")
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         end = time.time()
+         print(f"{self.method} took {str(round(end - self.start, 2))}s")
+
+ def find_cuda():
+     cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
+     if cuda_home and os.path.exists(cuda_home):
+         return cuda_home
+     nvcc_path = shutil.which('nvcc')
+     if nvcc_path:
+         cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
+         return cuda_path
+     return None
+
+ cuda_path = find_cuda()
+ if cuda_path:
+     print(f"CUDA installation found at: {cuda_path}")
+ else:
+     print("CUDA installation not found")
+

+ device = torch.device('cuda')

+ base_model = "black-forest-labs/FLUX.1-dev"
+ pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=huggingface_token)
+
+ # Load and fuse LoRA BEFORE quantizing
+ print('Loading and fusing lora, please wait...')
+ lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
+ pipe.load_lora_weights(lora_path)
+ pipe.fuse_lora(lora_scale=1.0)
+ pipe.unload_lora_weights()
+ pipe.enable_model_cpu_offload()
+
+
+ # Load 3D generation models
+ config_path = 'configs/instant-mesh-large.yaml'
+ config = OmegaConf.load(config_path)
+ config_name = os.path.basename(config_path).replace('.yaml', '')
+ model_config = config.model_config
+ infer_config = config.infer_config
+
+ IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
+
+ # Load diffusion model for 3D generation
+ print('Loading diffusion model ...')
+ pipeline = DiffusionPipeline.from_pretrained(
+     "sudo-ai/zero123plus-v1.2",
+     custom_pipeline="zero123plus",
      torch_dtype=torch.float16,
  )
+ pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
+     pipeline.scheduler.config, timestep_spacing='trailing'
  )

+ # Load custom white-background UNet
+ unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
+ state_dict = torch.load(unet_ckpt_path, map_location='cpu')
+ pipeline.unet.load_state_dict(state_dict, strict=True)

+ pipeline = pipeline.to(device)

+ # Load reconstruction model
+ print('Loading reconstruction model ...')
+ model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
+ model = instantiate_from_config(model_config)
+ state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
+ state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
+ model.load_state_dict(state_dict, strict=True)

+ model = model.to(device)

+ print('Loading Finished!')
+
+ def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
+     c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
+     if is_flexicubes:
+         cameras = torch.linalg.inv(c2ws)
+         cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
      else:
+         extrinsics = c2ws.flatten(-2)
+         intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
+         cameras = torch.cat([extrinsics, intrinsics], dim=-1)
+         cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
+     return cameras
+
+ def preprocess(input_image, do_remove_background):
+     rembg_session = rembg.new_session() if do_remove_background else None
+     if do_remove_background:
+         input_image = remove_background(input_image, rembg_session)
+         input_image = resize_foreground(input_image, 0.85)
+     return input_image
+
+ ts_cutoff = 2
+
+ @spaces.GPU
+ def generate_flux_image(prompt, height, width, steps, scales, seed):
+     return pipe(
+         prompt=prompt,
+         width=int(width),
+         height=int(height),
+         num_inference_steps=int(steps),
+         generator=torch.Generator().manual_seed(int(seed)),
+         guidance_scale=float(scales),
+         timestep_to_start_cfg=ts_cutoff,
      ).images[0]
+
+
+ @spaces.GPU
+ def generate_mvs(input_image, sample_steps, sample_seed):
+     seed_everything(sample_seed)
+     z123_image = pipeline(
+         input_image,
+         num_inference_steps=sample_steps
+     ).images[0]
+     show_image = np.asarray(z123_image, dtype=np.uint8)
+     show_image = torch.from_numpy(show_image)
+     show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
+     show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
+     show_image = Image.fromarray(show_image.numpy())
+     return z123_image, show_image
+
+ @spaces.GPU
+ def make3d(images):
+     global model
+     if IS_FLEXICUBES:
+         model.init_flexicubes_geometry(device, use_renderer=False)
+     model = model.eval()
+
+     images = np.asarray(images, dtype=np.float32) / 255.0
+     images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float()
+     images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2)
+
+     input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device)
+     render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
+
+     images = images.unsqueeze(0).to(device)
+     images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+     mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+     mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+     mesh_dirname = os.path.dirname(mesh_fpath)
+     mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
+
+     with torch.no_grad():
+         planes = model.forward_planes(images, input_cameras)
+         mesh_out = model.extract_mesh(
+             planes,
+             use_texture_map=False,
+             **infer_config,
+         )
+         vertices, faces, vertex_colors = mesh_out
+         vertices = vertices[:, [1, 2, 0]]
+         save_glb(vertices, faces, vertex_colors, mesh_glb_fpath)
+         save_obj(vertices, faces, vertex_colors, mesh_fpath)

+     return mesh_fpath, mesh_glb_fpath
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(
+         """
+         <div style="text-align: center; max-width: 650px; margin: 0 auto;">
+             <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem;">Flux Image to 3D Model Generator</h1>
+         </div>
+         """
+     )
+
      with gr.Row():
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(
+                 label="Your Image Description",
+                 placeholder="E.g., A serene landscape with mountains and a lake at sunset",
+                 lines=3
+             )

              with gr.Accordion("Advanced Settings", open=False):
+                 with gr.Group():
+                     with gr.Row():
+                         height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
+                         width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
+
+                     with gr.Row():
+                         steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
+                         scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
+
+                     seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)

+             generate_btn = gr.Button("Generate 3D Model", variant="primary")
+
+         with gr.Column(scale=4):
+             flux_output = gr.Image(label="Generated Flux Image")
+             mv_show_images = gr.Image(label="Generated Multi-views")
+             with gr.Row():
+                 with gr.Tab("OBJ"):
+                     output_model_obj = gr.Model3D(label="Output Model (OBJ Format)")
+                 with gr.Tab("GLB"):
+                     output_model_glb = gr.Model3D(label="Output Model (GLB Format)")
+
+     mv_images = gr.State()
+
+     def process_pipeline(prompt, height, width, steps, scales, seed):
+         flux_image = generate_flux_image(prompt, height, width, steps, scales, seed)
+         processed_image = preprocess(flux_image, do_remove_background=True)
+         mv_images, show_image = generate_mvs(processed_image, steps, seed)
+         obj_path, glb_path = make3d(mv_images)
+         return flux_image, show_image, obj_path, glb_path
+
      generate_btn.click(
+         fn=process_pipeline,
+         inputs=[prompt, height, width, steps, scales, seed],
+         outputs=[flux_output, mv_show_images, output_model_obj, output_model_glb]
      )

+ if __name__ == "__main__":
+     demo.launch()
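
The new app.py imports quantize, qfloat8, and freeze from optimum.quanto, and the LoRA comment says the weights are fused "BEFORE quantizing", but no quantization call appears in the committed code. If fp8 weight quantization was the intent, a minimal sketch using the standard optimum-quanto API could be placed after pipe.unload_lora_weights() and before pipe.enable_model_cpu_offload(); this is an illustration of that assumption, not part of the commit:

# Hypothetical follow-up (not in this commit): quantize the fused Flux transformer
# and the T5 text encoder to fp8 weights to reduce VRAM, then freeze to materialize
# the quantized tensors.
from optimum.quanto import quantize, qfloat8, freeze

quantize(pipe.transformer, weights=qfloat8)
freeze(pipe.transformer)
quantize(pipe.text_encoder_2, weights=qfloat8)
freeze(pipe.text_encoder_2)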