prithivMLmods committed
Commit f948054 (verified)
1 Parent(s): 9183b07

Update app.py

Files changed (1):
  1. app.py +51 -36
app.py CHANGED
@@ -54,7 +54,9 @@ MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load text-only model and tokenizer
+# ------------------------------
+# Text Generation Model
+# ------------------------------
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
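Note: the chat path uses the stock transformers loading API, so the checkpoint can be sanity-checked outside the app. A minimal standalone sketch (the generation parameters here are illustrative, not taken from app.py):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("prithivMLmods/FastThink-0.5B-Tiny")
    lm = AutoModelForCausalLM.from_pretrained("prithivMLmods/FastThink-0.5B-Tiny")
    inputs = tok("Hello!", return_tensors="pt")
    out = lm.generate(**inputs, max_new_tokens=32)  # small decode budget for a quick check
    print(tok.decode(out[0], skip_special_tokens=True))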
@@ -94,14 +96,15 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
-# Environment variables and parameters for Stable Diffusion XL
+# ------------------------------
+# Stable Diffusion XL (Image Generation)
+# ------------------------------
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
-# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
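Note: os.getenv("MODEL_VAL_PATH") returns None when the variable is unset, and from_pretrained(None, ...) then fails with an unhelpful error. A fail-fast guard along these lines could help (not part of this commit; the checkpoint named in the message is only an example):

    import os

    MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")
    if MODEL_ID_SD is None:
        # Fail early with a clear message instead of a confusing loading error.
        raise EnvironmentError(
            "MODEL_VAL_PATH is not set; point it at an SDXL checkpoint, "
            "e.g. stabilityai/stable-diffusion-xl-base-1.0"
        )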
@@ -110,15 +113,12 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
-# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
-# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
-# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
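Aside: if sd_pipe.compile() is unavailable in the installed diffusers version (it is not a documented pipeline method in the releases we are aware of), the commonly documented speedup is to compile the UNet directly. A sketch of that alternative, assuming torch 2.x:

    if USE_TORCH_COMPILE:
        # Compile only the UNet, the hot path of the denoising loop.
        sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True)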
 
@@ -166,13 +166,11 @@ def generate_image_fn(
     options["use_resolution_binning"] = True
 
     images = []
-    # Process in batches
     for i in range(0, num_images, BATCH_SIZE):
         batch_options = options.copy()
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        # Wrap the pipeline call in autocast if using CUDA
        if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
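The loop above slices the prompt list into BATCH_SIZE chunks; Python slicing silently shortens the final chunk, so num_images need not be a multiple of BATCH_SIZE. The same pattern in isolation (values are hypothetical):

    prompts = ["a watercolor fox"] * 5
    BATCH_SIZE = 2
    for i in range(0, len(prompts), BATCH_SIZE):
        print(prompts[i:i + BATCH_SIZE])  # chunks of 2, 2, then 1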
@@ -182,50 +180,76 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# ============================================================
+# ------------------------------
 # 3D Model Generation using ShapE (Text-to-3D / Image-to-3D)
-# ============================================================
+# ------------------------------
 class Model3D:
     def __init__(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
         self.pipe.to(self.device)
+        # Ensure the text encoder is in half precision
+        self.pipe.text_encoder = self.pipe.text_encoder.half()
 
         self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
         self.pipe_img.to(self.device)
+        # Ensure the text encoder is in half precision
+        self.pipe_img.text_encoder = self.pipe_img.text_encoder.half()
 
     def to_glb(self, ply_path: str) -> str:
         mesh = trimesh.load(ply_path)
         rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
-        mesh = mesh.apply_transform(rot)
+        mesh.apply_transform(rot)
         rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
-        mesh = mesh.apply_transform(rot)
+        mesh.apply_transform(rot)
         mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
         mesh.export(mesh_path.name, file_type="glb")
         return mesh_path.name
 
     def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe(
-            prompt,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
+        if self.device.type == "cuda":
+            with torch.autocast("cuda", dtype=torch.float16):
+                output = self.pipe(
+                    prompt,
+                    generator=generator,
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_steps,
+                    output_type="mesh",
+                )
+        else:
+            output = self.pipe(
+                prompt,
+                generator=generator,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type="mesh",
+            )
+        images = output.images
         ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
         export_to_ply(images[0], ply_path.name)
         return self.to_glb(ply_path.name)
 
     def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
         generator = torch.Generator(device=self.device).manual_seed(seed)
-        images = self.pipe_img(
-            image,
-            generator=generator,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_steps,
-            output_type="mesh",
-        ).images
+        if self.device.type == "cuda":
+            with torch.autocast("cuda", dtype=torch.float16):
+                output = self.pipe_img(
+                    image,
+                    generator=generator,
+                    guidance_scale=guidance_scale,
+                    num_inference_steps=num_steps,
+                    output_type="mesh",
+                )
+        else:
+            output = self.pipe_img(
+                image,
+                generator=generator,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type="mesh",
+            )
+        images = output.images
         ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
         export_to_ply(images[0], ply_path.name)
         return self.to_glb(ply_path.name)
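The two new branches in run_text and run_image are identical apart from the autocast wrapper. torch.autocast accepts an enabled= flag (a no-op context when False, in recent torch 2.x), so the duplication could be collapsed; a sketch of run_text under that assumption, not what this commit does:

    def run_text(self, prompt, seed=0, guidance_scale=15.0, num_steps=64):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        # enabled=False makes autocast a no-op, so the CPU path needs no separate branch.
        with torch.autocast("cuda", dtype=torch.float16, enabled=self.device.type == "cuda"):
            output = self.pipe(
                prompt,
                generator=generator,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                output_type="mesh",
            )
        images = output.images
        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
        export_to_ply(images[0], ply_path.name)
        return self.to_glb(ply_path.name)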
@@ -259,18 +283,14 @@ def generate(
     # 3D Model Generation Command
     # ------------------------------
     if text.strip().lower().startswith("@3d"):
-        # Remove the "@3d" tag and use the remaining text as the prompt.
         text = text[len("@3d"):].strip()
         yield "Generating 3D model..."
         seed = random.randint(0, MAX_SEED)
         if files:
-            # If an image is provided, use image-to-3D.
             image = load_image(files[0])
             glb_file = model_3d.run_image(image, seed=seed)
         else:
-            # Otherwise, generate a 3D model from the text prompt.
             glb_file = model_3d.run_text(text, seed=seed)
-        # Yield the generated GLB file as a downloadable file.
         yield gr.File(glb_file)
         return
 
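Usage note: a chat message like "@3d a ceramic mug" returns a downloadable GLB, and attaching an image switches to the image-to-3D path. Assuming load_image is diffusers.utils.load_image (its import sits outside this diff) and model_3d is the module-level Model3D instance, it accepts a local path or URL:

    from diffusers.utils import load_image

    image = load_image("mug_photo.png")           # path or URL -> PIL.Image
    glb_file = model_3d.run_image(image, seed=0)  # same call the @3d branch makes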
 
@@ -278,7 +298,6 @@ def generate(
     # Image Generation Command
     # ------------------------------
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the rest as prompt.
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
         image_paths, used_seed = generate_image_fn(
@@ -295,7 +314,7 @@ def generate(
             num_images=1,
         )
         yield gr.Image(image_paths[0])
-        return  # Exit early
+        return
 
     # ------------------------------
     # TTS / Regular Text Generation
@@ -307,11 +326,9 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
-        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -373,7 +390,6 @@ def generate(
     final_response = "".join(outputs)
     yield final_response
 
-    # If TTS was requested, convert the final response to speech.
     if is_tts and voice:
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
@@ -407,5 +423,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    # To create a public link, set share=True in launch().
     demo.queue(max_size=20).launch(share=True)
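For local testing, the SDXL checkpoint has to be in the environment before app.py's module-level loading runs; a hypothetical launcher (the repo id is an example, not part of this commit):

    import os
    import runpy

    os.environ.setdefault("MODEL_VAL_PATH", "stabilityai/stable-diffusion-xl-base-1.0")
    os.environ.setdefault("BATCH_SIZE", "1")
    # Run app.py as a script; this triggers demo.queue(max_size=20).launch(share=True).
    runpy.run_path("app.py", run_name="__main__")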
 