Spaces:
Runtime error
Runtime error
Commit
·
1f359be
1
Parent(s):
95ede1d
Update app.py
Browse files
app.py
CHANGED
|
@@ -705,6 +705,15 @@ def change_radio_display(task_type, mask_source_radio):
|
|
| 705 |
num_relation_visible = True
|
| 706 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
| 707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
if __name__ == "__main__":
|
| 709 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
| 710 |
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
|
@@ -721,7 +730,12 @@ if __name__ == "__main__":
|
|
| 721 |
load_ram_model()
|
| 722 |
|
| 723 |
os.system("pip list")
|
| 724 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 725 |
block = gr.Blocks().queue()
|
| 726 |
with block:
|
| 727 |
with gr.Row():
|
|
|
|
| 705 |
num_relation_visible = True
|
| 706 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
| 707 |
|
| 708 |
+
def get_model_device(module):
|
| 709 |
+
if isinstance(module, torch.nn.DataParallel):
|
| 710 |
+
module = module.module
|
| 711 |
+
for submodule in module.children():
|
| 712 |
+
if hasattr(submodule, "_parameters"):
|
| 713 |
+
parameters = submodule._parameters
|
| 714 |
+
if "weight" in parameters:
|
| 715 |
+
return parameters["weight"].device
|
| 716 |
+
|
| 717 |
if __name__ == "__main__":
|
| 718 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
| 719 |
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
|
|
|
| 730 |
load_ram_model()
|
| 731 |
|
| 732 |
os.system("pip list")
|
| 733 |
+
print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
|
| 734 |
+
print(f'sam_model__{get_model_device(sam_model)}')
|
| 735 |
+
print(f'sd_model__{get_model_device(sd_pipe)}')
|
| 736 |
+
print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
|
| 737 |
+
print(f'ram_model__{get_model_device(ram_model)}')
|
| 738 |
+
|
| 739 |
block = gr.Blocks().queue()
|
| 740 |
with block:
|
| 741 |
with gr.Row():
|