prithivMLmods committed
Commit 9183b07 · verified · 1 Parent(s): 66a0bd6

Update app.py

Files changed (1):
  1. app.py (+88 −7)
app.py CHANGED
@@ -23,6 +23,11 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
+# Additional imports for 3D model generation
+import tempfile
+import trimesh
+from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
+from diffusers.utils import export_to_ply
 
 DESCRIPTION = """
 # QwQ Edge 💬
@@ -45,6 +50,7 @@ h1 {
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
@@ -116,8 +122,6 @@ if USE_TORCH_COMPILE:
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
-MAX_SEED = np.iinfo(np.int32).max
-
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
@@ -178,6 +182,57 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
+# ============================================================
+# 3D Model Generation using ShapE (Text-to-3D / Image-to-3D)
+# ============================================================
+class Model3D:
+    def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
+        self.pipe.to(self.device)
+
+        self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
+        self.pipe_img.to(self.device)
+
+    def to_glb(self, ply_path: str) -> str:
+        mesh = trimesh.load(ply_path)
+        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
+        mesh = mesh.apply_transform(rot)
+        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
+        mesh = mesh.apply_transform(rot)
+        mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
+        mesh.export(mesh_path.name, file_type="glb")
+        return mesh_path.name
+
+    def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe(
+            prompt,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
+
+    def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe_img(
+            image,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
+
+# Create a global instance of the 3D model generator.
+model_3d = Model3D()
+
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -189,16 +244,41 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    Generates chatbot responses with support for multimodal input, TTS, image generation,
+    and 3D model generation.
+
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
+    - "@3d": triggers 3D model generation using the ShapE pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
+    # ------------------------------
+    # 3D Model Generation Command
+    # ------------------------------
+    if text.strip().lower().startswith("@3d"):
+        # Remove the "@3d" tag and use the remaining text as the prompt.
+        text = text[len("@3d"):].strip()
+        yield "Generating 3D model..."
+        seed = random.randint(0, MAX_SEED)
+        if files:
+            # If an image is provided, use image-to-3D.
+            image = load_image(files[0])
+            glb_file = model_3d.run_image(image, seed=seed)
+        else:
+            # Otherwise, generate a 3D model from the text prompt.
+            glb_file = model_3d.run_text(text, seed=seed)
+        # Yield the generated GLB file as a downloadable file.
+        yield gr.File(glb_file)
+        return
+
+    # ------------------------------
+    # Image Generation Command
+    # ------------------------------
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the rest as prompt
+        # Remove the "@image" tag and use the rest as prompt.
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
         image_paths, used_seed = generate_image_fn(
@@ -214,10 +294,12 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Yield the generated image so that the chat interface displays it.
        yield gr.Image(image_paths[0])
        return  # Exit early
 
+    # ------------------------------
+    # TTS / Regular Text Generation
+    # ------------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -263,7 +345,6 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
-
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -313,7 +394,7 @@ demo = gr.ChatInterface(
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-
+        ["@3d A futuristic spaceship in low-poly style"],
     ],
     cache_examples=False,
     type="messages",
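
For anyone who wants to try the new path outside the Space: the snippet below mirrors what the added Model3D.run_text and Model3D.to_glb do behind the "@3d" command (Shap-E renders a mesh, diffusers exports it to PLY, trimesh re-exports it as GLB). It is a minimal standalone sketch, not part of the commit, and assumes a diffusers install with the Shap-E pipelines, trimesh, and a CUDA device for the float16 weights.

# Standalone sketch of the text-to-3D path wired into "@3d" (not part of the commit).
import tempfile

import numpy as np
import torch
import trimesh
from diffusers import ShapEPipeline
from diffusers.utils import export_to_ply

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to(device)

# Text -> mesh, with the same defaults as Model3D.run_text (guidance 15.0, 64 steps).
mesh_output = pipe(
    "A futuristic spaceship in low-poly style",
    generator=torch.Generator(device=device).manual_seed(0),
    guidance_scale=15.0,
    num_inference_steps=64,
    output_type="mesh",
).images[0]

# Mesh -> PLY -> GLB, applying the same reorientation as Model3D.to_glb.
ply_file = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
export_to_ply(mesh_output, ply_file.name)

mesh = trimesh.load(ply_file.name)
mesh.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
mesh.apply_transform(trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0]))
glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
mesh.export(glb_file.name, file_type="glb")
print("GLB written to:", glb_file.name)

In the chat box itself, the same flow is triggered by a message such as "@3d A futuristic spaceship in low-poly style" (now included in the examples), or by attaching an image to a "@3d" message, which routes to the image-to-3D pipeline instead.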