Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,3 +1,4 @@ | |
|  | |
| 1 | 
             
            from typing import Optional
         | 
| 2 | 
             
            import spaces
         | 
| 3 | 
             
            import gradio as gr
         | 
| @@ -73,7 +74,7 @@ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to g | |
| 73 | 
             
            DEVICE = torch.device('cuda')  
         | 
| 74 |  | 
| 75 | 
             
            @spaces.GPU
         | 
| 76 | 
            -
            @torch.inference_mode()
         | 
| 77 | 
             
            def get_som_response(instruction, image_som):
         | 
| 78 | 
             
                prompt = magma_som_prompt.format(instruction)
         | 
| 79 | 
             
                if magam_model.config.mm_use_image_start_end:
         | 
| @@ -110,7 +111,7 @@ def get_som_response(instruction, image_som): | |
| 110 | 
             
                return response
         | 
| 111 |  | 
| 112 | 
             
            @spaces.GPU
         | 
| 113 | 
            -
            @torch.inference_mode()
         | 
| 114 | 
             
            def get_qa_response(instruction, image):
         | 
| 115 | 
             
                prompt = magma_qa_prompt.format(instruction)
         | 
| 116 | 
             
                if magam_model.config.mm_use_image_start_end:
         | 
| @@ -147,7 +148,7 @@ def get_qa_response(instruction, image): | |
| 147 | 
             
                return response
         | 
| 148 |  | 
| 149 | 
             
            @spaces.GPU
         | 
| 150 | 
            -
            @torch.inference_mode()
         | 
| 151 | 
             
            # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
         | 
| 152 | 
             
            def process(
         | 
| 153 | 
             
                image_input,
         | 
| @@ -158,98 +159,103 @@ def process( | |
| 158 | 
             
                instruction,
         | 
| 159 | 
             
            ) -> Optional[Image.Image]:
         | 
| 160 |  | 
| 161 | 
            -
                 | 
| 162 | 
            -
             | 
| 163 | 
            -
             | 
| 164 | 
            -
             | 
| 165 | 
            -
             | 
| 166 | 
            -
                     | 
| 167 | 
            -
             | 
| 168 | 
            -
             | 
| 169 | 
            -
             | 
| 170 | 
            -
             | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
             | 
| 174 | 
            -
             | 
| 175 | 
            -
             | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
| 178 | 
            -
                     | 
| 179 | 
            -
             | 
| 180 | 
            -
             | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
                     | 
| 184 | 
            -
             | 
| 185 | 
            -
             | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
| 190 | 
            -
                     | 
| 191 | 
            -
                         | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
             | 
| 197 | 
            -
                     | 
| 198 | 
            -
             | 
| 199 | 
            -
             | 
| 200 | 
            -
             | 
| 201 | 
            -
                     | 
| 202 | 
            -
             | 
| 203 | 
            -
             | 
| 204 | 
            -
             | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
                    if  | 
| 210 | 
            -
                         | 
|  | |
|  | |
|  | |
| 211 | 
             
                    else:
         | 
| 212 | 
             
                        bbox_for_mark = None
         | 
| 213 | 
            -
             | 
| 214 | 
            -
                    bbox_for_mark | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
| 217 | 
            -
                    # draw bbox_for_mark on the image
         | 
| 218 | 
            -
                    image_som = plot_boxes_with_marks(
         | 
| 219 | 
            -
                        image_input, 
         | 
| 220 | 
            -
                        [label_coordinates_yxhw[str(mark_id)]], 
         | 
| 221 | 
            -
                        som_generator, 
         | 
| 222 | 
            -
                        edgecolor=(255,127,111), 
         | 
| 223 | 
            -
                        alpha=30, 
         | 
| 224 | 
            -
                        fn_save=None, 
         | 
| 225 | 
            -
                        normalized_to_pixel=False,
         | 
| 226 | 
            -
                        add_mark=False
         | 
| 227 | 
            -
                    )
         | 
| 228 | 
            -
                else:
         | 
| 229 | 
            -
                    try:
         | 
| 230 | 
            -
                        if 'box' in magma_response:
         | 
| 231 | 
            -
                            pred_bbox = extract_bbox(magma_response)
         | 
| 232 | 
            -
                            click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
         | 
| 233 | 
            -
                            click_point = [item / 1000 for item in click_point]
         | 
| 234 | 
            -
                        else:
         | 
| 235 | 
            -
                            click_point = pred_2_point(magma_response)
         | 
| 236 | 
            -
                        # de-normalize click_point (width, height)
         | 
| 237 | 
            -
                        click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
         | 
| 238 | 
            -
             | 
| 239 | 
            -
                        image_som = plot_circles_with_marks(
         | 
| 240 | 
             
                            image_input, 
         | 
| 241 | 
            -
                            [ | 
| 242 | 
            -
                            som_generator,
         | 
| 243 | 
             
                            edgecolor=(255,127,111), 
         | 
| 244 | 
            -
                             | 
| 245 | 
            -
                            fn_save=None,
         | 
| 246 | 
             
                            normalized_to_pixel=False,
         | 
| 247 | 
             
                            add_mark=False
         | 
| 248 | 
             
                        )
         | 
| 249 | 
            -
                     | 
| 250 | 
            -
                         | 
| 251 | 
            -
             | 
| 252 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 253 |  | 
| 254 | 
             
            with gr.Blocks() as demo:
         | 
| 255 | 
             
                gr.Markdown(MARKDOWN)
         | 
| @@ -291,4 +297,4 @@ with gr.Blocks() as demo: | |
| 291 |  | 
| 292 | 
             
            demo.launch(debug=True, show_error=True, share=True)
         | 
| 293 | 
             
            # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
         | 
| 294 | 
            -
            # demo.queue().launch(share=False)
         | 
|  | |
| 1 | 
            +
            import traceback
         | 
| 2 | 
             
            from typing import Optional
         | 
| 3 | 
             
            import spaces
         | 
| 4 | 
             
            import gradio as gr
         | 
|  | |
| 74 | 
             
            DEVICE = torch.device('cuda')  
         | 
| 75 |  | 
| 76 | 
             
            @spaces.GPU
         | 
| 77 | 
            +
            # @torch.inference_mode()
         | 
| 78 | 
             
            def get_som_response(instruction, image_som):
         | 
| 79 | 
             
                prompt = magma_som_prompt.format(instruction)
         | 
| 80 | 
             
                if magam_model.config.mm_use_image_start_end:
         | 
|  | |
| 111 | 
             
                return response
         | 
| 112 |  | 
| 113 | 
             
            @spaces.GPU
         | 
| 114 | 
            +
            # @torch.inference_mode()
         | 
| 115 | 
             
            def get_qa_response(instruction, image):
         | 
| 116 | 
             
                prompt = magma_qa_prompt.format(instruction)
         | 
| 117 | 
             
                if magam_model.config.mm_use_image_start_end:
         | 
|  | |
| 148 | 
             
                return response
         | 
| 149 |  | 
| 150 | 
             
            @spaces.GPU
         | 
| 151 | 
            +
            # @torch.inference_mode()
         | 
| 152 | 
             
            # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
         | 
| 153 | 
             
            def process(
         | 
| 154 | 
             
                image_input,
         | 
|  | |
| 159 | 
             
                instruction,
         | 
| 160 | 
             
            ) -> Optional[Image.Image]:
         | 
| 161 |  | 
| 162 | 
            +
                try:
         | 
| 163 | 
            +
                    # image_save_path = 'imgs/saved_image_demo.png'
         | 
| 164 | 
            +
                    # image_input.save(image_save_path)
         | 
| 165 | 
            +
                    # image = Image.open(image_save_path)
         | 
| 166 | 
            +
                    box_overlay_ratio = image_input.size[0] / 3200
         | 
| 167 | 
            +
                    draw_bbox_config = {
         | 
| 168 | 
            +
                        'text_scale': 0.8 * box_overlay_ratio,
         | 
| 169 | 
            +
                        'text_thickness': max(int(2 * box_overlay_ratio), 1),
         | 
| 170 | 
            +
                        'text_padding': max(int(3 * box_overlay_ratio), 1),
         | 
| 171 | 
            +
                        'thickness': max(int(3 * box_overlay_ratio), 1),
         | 
| 172 | 
            +
                    }
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
         | 
| 175 | 
            +
                    text, ocr_bbox = ocr_bbox_rslt
         | 
| 176 | 
            +
                    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)  
         | 
| 177 | 
            +
                    parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
         | 
| 178 | 
            +
                    
         | 
| 179 | 
            +
                    if len(instruction) == 0:
         | 
| 180 | 
            +
                        print('finish processing')
         | 
| 181 | 
            +
                        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
         | 
| 182 | 
            +
                        return image, str(parsed_content_list)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    elif instruction.startswith('Q:'):
         | 
| 185 | 
            +
                        response = get_qa_response(instruction, image_input)
         | 
| 186 | 
            +
                        return image_input, response
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    # parsed_content_list = str(parsed_content_list)
         | 
| 189 | 
            +
                    # convert xywh to yxhw
         | 
| 190 | 
            +
                    label_coordinates_yxhw = {}
         | 
| 191 | 
            +
                    for key, val in label_coordinates.items():
         | 
| 192 | 
            +
                        if val[2] < 0 or val[3] < 0:
         | 
| 193 | 
            +
                            continue
         | 
| 194 | 
            +
                        label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
         | 
| 195 | 
            +
                    image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    # convert xywh to xyxy
         | 
| 198 | 
            +
                    for key, val in label_coordinates.items():
         | 
| 199 | 
            +
                        label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    # normalize label_coordinates
         | 
| 202 | 
            +
                    for key, val in label_coordinates.items():
         | 
| 203 | 
            +
                        label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
         | 
| 204 | 
            +
                    
         | 
| 205 | 
            +
                    magma_response = get_som_response(instruction, image_som)
         | 
| 206 | 
            +
                    print("magma repsonse: ", magma_response)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                    # map magma_response into the mark id
         | 
| 209 | 
            +
                    mark_id = extract_mark_id(magma_response)
         | 
| 210 | 
            +
                    if mark_id is not None:
         | 
| 211 | 
            +
                        if str(mark_id) in label_coordinates:
         | 
| 212 | 
            +
                            bbox_for_mark = label_coordinates[str(mark_id)]
         | 
| 213 | 
            +
                        else:
         | 
| 214 | 
            +
                            bbox_for_mark = None
         | 
| 215 | 
             
                    else:
         | 
| 216 | 
             
                        bbox_for_mark = None
         | 
| 217 | 
            +
                    
         | 
| 218 | 
            +
                    if bbox_for_mark:
         | 
| 219 | 
            +
                        # draw bbox_for_mark on the image
         | 
| 220 | 
            +
                        image_som = plot_boxes_with_marks(
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 221 | 
             
                            image_input, 
         | 
| 222 | 
            +
                            [label_coordinates_yxhw[str(mark_id)]], 
         | 
| 223 | 
            +
                            som_generator, 
         | 
| 224 | 
             
                            edgecolor=(255,127,111), 
         | 
| 225 | 
            +
                            alpha=30, 
         | 
| 226 | 
            +
                            fn_save=None, 
         | 
| 227 | 
             
                            normalized_to_pixel=False,
         | 
| 228 | 
             
                            add_mark=False
         | 
| 229 | 
             
                        )
         | 
| 230 | 
            +
                    else:
         | 
| 231 | 
            +
                        try:
         | 
| 232 | 
            +
                            if 'box' in magma_response:
         | 
| 233 | 
            +
                                pred_bbox = extract_bbox(magma_response)
         | 
| 234 | 
            +
                                click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
         | 
| 235 | 
            +
                                click_point = [item / 1000 for item in click_point]
         | 
| 236 | 
            +
                            else:
         | 
| 237 | 
            +
                                click_point = pred_2_point(magma_response)
         | 
| 238 | 
            +
                            # de-normalize click_point (width, height)
         | 
| 239 | 
            +
                            click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                            image_som = plot_circles_with_marks(
         | 
| 242 | 
            +
                                image_input, 
         | 
| 243 | 
            +
                                [click_point],
         | 
| 244 | 
            +
                                som_generator,
         | 
| 245 | 
            +
                                edgecolor=(255,127,111), 
         | 
| 246 | 
            +
                                linewidth=3,
         | 
| 247 | 
            +
                                fn_save=None,
         | 
| 248 | 
            +
                                normalized_to_pixel=False,
         | 
| 249 | 
            +
                                add_mark=False
         | 
| 250 | 
            +
                            )
         | 
| 251 | 
            +
                        except:
         | 
| 252 | 
            +
                            image_som = image_input
         | 
| 253 | 
            +
                    
         | 
| 254 | 
            +
                    return image_som, str(parsed_content_list)
         | 
| 255 | 
            +
                except Exception as e:
         | 
| 256 | 
            +
                    print('error in process')
         | 
| 257 | 
            +
                    traceback.print_exc()
         | 
| 258 | 
            +
                    return image_input, 'error in process'
         | 
| 259 |  | 
| 260 | 
             
            with gr.Blocks() as demo:
         | 
| 261 | 
             
                gr.Markdown(MARKDOWN)
         | 
|  | |
| 297 |  | 
| 298 | 
             
            demo.launch(debug=True, show_error=True, share=True)
         | 
| 299 | 
             
            # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
         | 
| 300 | 
            +
            # demo.queue().launch(share=False)
         | 
