	Update app.py
Cleanup + cuda init update

app.py CHANGED
@@ -21,16 +21,33 @@ def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
         imports.remove("flash_attn")
     return imports
 
-
-
-with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
-
-
-
-
-
+def load_model():
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
+        model = AutoModelForCausalLM.from_pretrained(
+            "microsoft/Florence-2-large-ft",
+            trust_remote_code=True
+        ).to(device).eval()
+        processor = AutoProcessor.from_pretrained(
+            "microsoft/Florence-2-large-ft",
+            trust_remote_code=True
+        )
+    return model, processor, device
+
+model = None
+processor = None
+device = None
+
+@spaces.GPU
+def initialize_model():
+    global model, processor, device
+    model, processor, device = load_model()
 
 def run_example(task_prompt, image, text_input=None):
+    global model, processor, device
+    if model is None or processor is None:
+        initialize_model()
+
     prompt = task_prompt if text_input is None else task_prompt + text_input
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
     with torch.inference_mode():
@@ -38,6 +55,9 @@ def run_example(task_prompt, image, text_input=None):
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
     return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))
 
+colormap = ['blue', 'orange', 'green', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'red',
+            'lime', 'indigo', 'violet', 'aqua', 'magenta', 'coral', 'gold', 'tan', 'skyblue']
+
 def fig_to_pil(fig):
     buf = io.BytesIO()
     fig.savefig(buf, format='png', dpi=300, bbox_inches='tight')
@@ -85,7 +105,7 @@ def draw_ocr_bboxes(image, prediction):
     bboxes, labels = prediction['quad_boxes'], prediction['labels']
     for box, label in zip(bboxes, labels):
         color = random.choice(colormap)
-        box_array = np.array(box).reshape(-1, 2) 
+        box_array = np.array(box).reshape(-1, 2)
         polygon = patches.Polygon(box_array, edgecolor=color, fill=False, linewidth=2)
         ax.add_patch(polygon)
         plt.text(box_array[0, 0], box_array[0, 1], label, color='white', fontsize=10, bbox=dict(facecolor=color, alpha=0.8))
@@ -101,7 +121,7 @@ def plot_bbox(image, data):
         draw.text((x1, y1), label, fill="white")
     return np.array(img_draw)
 
-@spaces.GPU 
+@spaces.GPU
 def process_video(input_video_path, task_prompt):
     cap = cv2.VideoCapture(input_video_path)
     if not cap.isOpened():
@@ -118,7 +138,7 @@ def process_video(input_video_path, task_prompt):
 
     processed_frames = 0
     frame_results = []
-    color_map = {} 
+    color_map = {}
 
     def get_color(label):
         if label not in color_map:
@@ -229,6 +249,10 @@ def process_video_p(input_video, task, text_input):
         return None, "Error: Video processing failed. Check logs above for info.", str(frame_results)
     return result, result, str(frame_results)
 
+@spaces.GPU
+def process_image_with_gpu(image, task, text):
+    return process_image(image, task, text)
+
 with gr.Blocks() as demo:
     gr.HTML("<h1><center>Microsoft Florence-2-large-ft</center></h1>")
 
@@ -300,7 +324,16 @@ with gr.Blocks() as demo:
 
     video_task_dropdown.change(fn=update_video_text_input, inputs=video_task_dropdown, outputs=video_text_input)
 
-    submit_btn.click( 
-
+    submit_btn.click(
+        fn=process_image_with_gpu,
+        inputs=[input_img, task_dropdown, text_input],
+        outputs=[output_text, output_image]
+    )
+
+    video_submit_btn.click(
+        fn=process_video_p,
+        inputs=[input_video, video_task_dropdown, video_text_input],
+        outputs=[output_video, output_video, frame_results_output]
+    )
 
 demo.launch()
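
Note on the "cuda init update": ZeroGPU Spaces only attach a GPU while a function decorated with @spaces.GPU is running, and they typically raise a runtime error if CUDA is initialized in the main process, so this commit defers device selection and model loading into load_model()/initialize_model() and triggers them lazily from the request handlers. The sketch below restates that pattern in a self-contained form; it mirrors the names used in the diff, but the flash_attn guard inside workaround_fixed_get_imports and the generate() arguments are assumptions, since those lines sit outside the changed hunks.

# A minimal sketch of the lazy CUDA/model initialization idiom this commit adopts.
# Names mirror the diff above; treat it as illustrative, not a copy of the Space's app.py.
import os
from unittest.mock import patch

import spaces  # Hugging Face Spaces helper package (provides the @spaces.GPU decorator)
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
from transformers.dynamic_module_utils import get_imports

MODEL_ID = "microsoft/Florence-2-large-ft"

def workaround_fixed_get_imports(filename: str | os.PathLike) -> list[str]:
    # Florence-2's remote code declares flash_attn even when it is not installed;
    # dropping it from the reported imports lets the model load without it.
    imports = get_imports(filename)
    if "flash_attn" in imports:  # assumed guard; the original body is not shown in the diff
        imports.remove("flash_attn")
    return imports

model = None       # nothing touches CUDA at import time
processor = None
device = None

def load_model():
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with patch("transformers.dynamic_module_utils.get_imports", workaround_fixed_get_imports):
        m = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True).to(dev).eval()
        p = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    return m, p, dev

@spaces.GPU  # on ZeroGPU, a GPU is attached only while a decorated call is running
def initialize_model():
    global model, processor, device
    model, processor, device = load_model()

def run_example(task_prompt, image, text_input=None):
    # Lazily initialize on the first request instead of at module import;
    # in app.py the @spaces.GPU decorator sits on the wrappers that call into this.
    if model is None or processor is None:
        initialize_model()
    prompt = task_prompt if text_input is None else task_prompt + text_input
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(  # assumed arguments; the call is outside the hunks
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.size[0], image.size[1])
    )

With this split, importing app.py to build the Gradio UI no longer touches CUDA; the first image or video request pays the model-load cost inside a GPU-decorated call instead.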