prithivMLmods committed on
Commit
bfa5350
·
verified ·
1 Parent(s): 8492499

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -114,6 +114,16 @@ class Model:
114
  export_to_ply(images[0], ply_path.name)
115
  return self.to_glb(ply_path.name)
116
 
 
 
 
 
 
 
 
 
 
 
117
  # -----------------------------------------------------------------------------
118
  # Gradio UI configuration
119
  # -----------------------------------------------------------------------------
@@ -367,6 +377,9 @@ def generate(
367
  - "@3d": triggers 3D model generation using the ShapE pipeline.
368
  - "@web": triggers a web command. Use "visit" to visit a URL (e.g., "@web visit https://example.com")
369
  or "search" to perform a DuckDuckGo search (e.g., "@web search AI news").
 
 
 
370
  """
371
  text = input_dict["text"]
372
  files = input_dict.get("files", [])
@@ -389,7 +402,6 @@ def generate(
389
  new_filename = f"mesh_{uuid.uuid4()}.glb"
390
  new_filepath = os.path.join(static_folder, new_filename)
391
  shutil.copy(glb_path, new_filepath)
392
-
393
  yield gr.File(new_filepath)
394
  return
395
 
@@ -477,6 +489,7 @@ def generate(
477
  time.sleep(0.01)
478
  yield buffer
479
  else:
 
480
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
481
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
482
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -498,6 +511,7 @@ def generate(
498
  t.start()
499
 
500
  outputs = []
 
501
  for new_text in streamer:
502
  outputs.append(new_text)
503
  yield "".join(outputs)
@@ -505,6 +519,16 @@ def generate(
505
  final_response = "".join(outputs)
506
  yield final_response
507
 
 
 
 
 
 
 
 
 
 
 
508
  if is_tts and voice:
509
  output_file = asyncio.run(text_to_speech(final_response, voice))
510
  yield gr.Audio(output_file, autoplay=True)
 
114
  export_to_ply(images[0], ply_path.name)
115
  return self.to_glb(ply_path.name)
116
 
117
+ # -----------------------------------------------------------------------------
118
+ # Helper function for 3D generation using the Model class
119
+ # -----------------------------------------------------------------------------
120
+
121
+ def generate_3d_fn(prompt: str, seed: int, guidance_scale: float, num_steps: int, randomize_seed: bool):
122
+ seed = randomize_seed_fn(seed, randomize_seed)
123
+ model_3d = Model()
124
+ glb_path = model_3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
125
+ return glb_path, seed
126
+
127
  # -----------------------------------------------------------------------------
128
  # Gradio UI configuration
129
  # -----------------------------------------------------------------------------
 
377
  - "@3d": triggers 3D model generation using the ShapE pipeline.
378
  - "@web": triggers a web command. Use "visit" to visit a URL (e.g., "@web visit https://example.com")
379
  or "search" to perform a DuckDuckGo search (e.g., "@web search AI news").
380
+
381
+ Additionally, for every default (plain text) query (i.e. no special command), after the answer
382
+ the bot will append reference links for further reading.
383
  """
384
  text = input_dict["text"]
385
  files = input_dict.get("files", [])
 
402
  new_filename = f"mesh_{uuid.uuid4()}.glb"
403
  new_filepath = os.path.join(static_folder, new_filename)
404
  shutil.copy(glb_path, new_filepath)
 
405
  yield gr.File(new_filepath)
406
  return
407
 
 
489
  time.sleep(0.01)
490
  yield buffer
491
  else:
492
+ # --- Default plain text branch ---
493
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
494
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
495
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
511
  t.start()
512
 
513
  outputs = []
514
+ yield "Thinking..."
515
  for new_text in streamer:
516
  outputs.append(new_text)
517
  yield "".join(outputs)
 
519
  final_response = "".join(outputs)
520
  yield final_response
521
 
522
+ # --- Append Reference Links after the answer ---
523
+ try:
524
+ # Use the original query as the search term to fetch reference links (limit to 3 results)
525
+ search_tool = DuckDuckGoSearchTool(max_results=3)
526
+ reference_links = search_tool.forward(input_dict["text"])
527
+ reference_message = "\n\nFor more info, visit:\n" + reference_links
528
+ yield reference_message
529
+ except Exception as e:
530
+ yield f"\n\n[Error retrieving reference links: {str(e)}]"
531
+
532
  if is_tts and voice:
533
  output_file = asyncio.run(text_to_speech(final_response, voice))
534
  yield gr.Audio(output_file, autoplay=True)