gokaygokay commited on
Commit
c3d416d
·
1 Parent(s): 03dc1fe
Files changed (1) hide show
  1. app.py +57 -37
app.py CHANGED
@@ -1,27 +1,20 @@
1
  import spaces
2
- import argparse
3
  import os
4
  import time
5
  from os import path
6
- from safetensors.torch import load_file
7
  from huggingface_hub import hf_hub_download
8
- import imageio
9
  import numpy as np
10
  import torch
11
  import rembg
12
  from PIL import Image
13
  from torchvision.transforms import v2
 
14
  from pytorch_lightning import seed_everything
15
  from omegaconf import OmegaConf
16
- from einops import rearrange, repeat
17
- from tqdm import tqdm
18
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
19
  import gradio as gr
20
  import shutil
21
  import tempfile
22
- from functools import partial
23
- from optimum.quanto import quantize, qfloat8, freeze
24
- from diffusers import FluxPipeline
25
  from src.utils.train_util import instantiate_from_config
26
  from src.utils.camera_util import (
27
  FOV_to_intrinsics,
@@ -30,6 +23,9 @@ from src.utils.camera_util import (
30
  )
31
  from src.utils.mesh_util import save_obj, save_glb
32
  from src.utils.infer_util import remove_background, resize_foreground, images_to_video
 
 
 
33
 
34
  # Set up cache path
35
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
@@ -71,19 +67,11 @@ else:
71
  print("CUDA installation not found")
72
 
73
 
74
- device = 'cuda'
75
-
76
- base_model = "black-forest-labs/FLUX.1-dev"
77
- file_flux = hf_hub_download("marduk191/Flux.1_collection", "flux.1_dev_8x8_e4m3fn-marduk191.safetensors")
78
- pipe = FluxPipeline.from_single_file(file_flux, torch_dtype=torch.bfloat16, token=huggingface_token)
79
-
80
- # Load and fuse LoRA BEFORE quantizing
81
- print('Loading and fusing lora, please wait...')
82
- lora_path = hf_hub_download("gokaygokay/Flux-Game-Assets-LoRA-v2", "game_asst.safetensors")
83
- pipe.load_lora_weights(lora_path)
84
- pipe.fuse_lora(lora_scale=1.0)
85
- pipe.unload_lora_weights()
86
 
 
87
 
88
  # Load 3D generation models
89
  config_path = 'configs/instant-mesh-large.yaml'
@@ -143,20 +131,7 @@ def preprocess(input_image, do_remove_background):
143
  input_image = resize_foreground(input_image, 0.85)
144
  return input_image
145
 
146
- ts_cutoff = 2
147
 
148
- @spaces.GPU
149
- def generate_flux_image(prompt, height, width, steps, scales, seed):
150
- pipe.to(device)
151
- return pipe(
152
- prompt=prompt,
153
- width=int(height),
154
- height=int(width),
155
- num_inference_steps=int(steps),
156
- generator=torch.Generator().manual_seed(int(seed)),
157
- guidance_scale=float(scales),
158
- timestep_to_start_cfg=ts_cutoff,
159
- ).images[0]
160
 
161
 
162
  @spaces.GPU
@@ -209,6 +184,45 @@ def make3d(images):
209
 
210
  return mesh_fpath, mesh_glb_fpath
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
213
  gr.Markdown(
214
  """
@@ -236,7 +250,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
236
  steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
237
  scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
238
 
239
- seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)
 
240
 
241
  generate_btn = gr.Button("Generate 3D Model", variant="primary")
242
 
@@ -251,8 +266,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
251
 
252
  mv_images = gr.State()
253
 
254
- def process_pipeline(prompt, height, width, steps, scales, seed):
255
- flux_image = generate_flux_image(prompt, height, width, steps, scales, seed)
 
 
 
 
 
256
  processed_image = preprocess(flux_image, do_remove_background=True)
257
  mv_images, show_image = generate_mvs(processed_image, steps, seed)
258
  obj_path, glb_path = make3d(mv_images)
@@ -260,7 +280,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
260
 
261
  generate_btn.click(
262
  fn=process_pipeline,
263
- inputs=[prompt, height, width, steps, scales, seed],
264
  outputs=[flux_output, mv_show_images, output_model_obj, output_model_glb]
265
  )
266
 
 
1
  import spaces
 
2
  import os
3
  import time
4
  from os import path
 
5
  from huggingface_hub import hf_hub_download
 
6
  import numpy as np
7
  import torch
8
  import rembg
9
  from PIL import Image
10
  from torchvision.transforms import v2
11
+ from einops import rearrange
12
  from pytorch_lightning import seed_everything
13
  from omegaconf import OmegaConf
 
 
14
  from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
15
  import gradio as gr
16
  import shutil
17
  import tempfile
 
 
 
18
  from src.utils.train_util import instantiate_from_config
19
  from src.utils.camera_util import (
20
  FOV_to_intrinsics,
 
23
  )
24
  from src.utils.mesh_util import save_obj, save_glb
25
  from src.utils.infer_util import remove_background, resize_foreground, images_to_video
26
+ import random
27
+ import requests
28
+ import io
29
 
30
  # Set up cache path
31
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 
67
  print("CUDA installation not found")
68
 
69
 
70
+ API_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
71
+ headers = {"Authorization": f"Bearer {API_TOKEN}"}
72
+ timeout = 100
 
 
 
 
 
 
 
 
 
73
 
74
+ device = 'cuda'
75
 
76
  # Load 3D generation models
77
  config_path = 'configs/instant-mesh-large.yaml'
 
131
  input_image = resize_foreground(input_image, 0.85)
132
  return input_image
133
 
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
 
137
  @spaces.GPU
 
184
 
185
  return mesh_fpath, mesh_glb_fpath
186
 
187
+ # Remove the FluxPipeline setup and replace with the query function
188
+ def query(prompt, steps=28, cfg_scale=3.5, randomize_seed=True, seed=-1, width=1024, height=1024):
189
+ if not prompt:
190
+ return None
191
+
192
+ lora_id = "gokaygokay/Flux-Game-Assets-LoRA-v2"
193
+ API_URL = f"https://api-inference.huggingface.co/models/{lora_id}"
194
+
195
+ if randomize_seed:
196
+ seed = random.randint(1, 4294967296)
197
+
198
+ prompt = f"{prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
199
+
200
+ payload = {
201
+ "inputs": prompt,
202
+ "steps": steps,
203
+ "cfg_scale": cfg_scale,
204
+ "seed": seed,
205
+ "parameters": {
206
+ "width": width,
207
+ "height": height
208
+ }
209
+ }
210
+
211
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=100)
212
+ if response.status_code != 200:
213
+ if response.status_code == 503:
214
+ raise gr.Error("The model is being loaded")
215
+ raise gr.Error(f"Error {response.status_code}")
216
+
217
+ try:
218
+ image_bytes = response.content
219
+ image = Image.open(io.BytesIO(image_bytes))
220
+ return image
221
+ except Exception as e:
222
+ print(f"Error when trying to open the image: {e}")
223
+ return None
224
+
225
+ # Update the Gradio interface
226
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
227
  gr.Markdown(
228
  """
 
250
  steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, step=1, value=28)
251
  scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
252
 
253
+ seed = gr.Number(label="Seed", value=-1, precision=0)
254
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
255
 
256
  generate_btn = gr.Button("Generate 3D Model", variant="primary")
257
 
 
266
 
267
  mv_images = gr.State()
268
 
269
+ def process_pipeline(prompt, height, width, steps, scales, seed, randomize_seed):
270
+ # Generate Flux image using the API
271
+ prompt_real = f"wbgmsst, {prompt}, white background"
272
+ flux_image = query(prompt_real, steps, scales, randomize_seed, seed, width, height)
273
+ if flux_image is None:
274
+ raise gr.Error("Failed to generate image")
275
+
276
  processed_image = preprocess(flux_image, do_remove_background=True)
277
  mv_images, show_image = generate_mvs(processed_image, steps, seed)
278
  obj_path, glb_path = make3d(mv_images)
 
280
 
281
  generate_btn.click(
282
  fn=process_pipeline,
283
+ inputs=[prompt, height, width, steps, scales, seed, randomize_seed],
284
  outputs=[flux_output, mv_show_images, output_model_obj, output_model_glb]
285
  )
286