	Text location in screenshots (#12)
- Add screenshot text location feature (090950138af99f7877d5f7a6565213c0be493da2)
- Screenshot text location: pad images (35feaa002ab2d0ef4b3c1820b559d85ecb277fa3)
Co-authored-by: Pedro Cuenca <[email protected]>
- app.py +115 -14
- assets/localization_example_1.jpeg +0 -0
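
The heart of the change is the preprocessing for the new localization path: the screenshot is scaled down to fit within 1920x1080, pasted onto a canvas of exactly that size, and the bounding box Fuyu returns is mapped back to the original resolution. Below is a minimal sketch of that round trip, mirroring the resize_to_max / pad_to_size / scale_factor_to_fit helpers added in the diff that follows (the preprocess wrapper and the numbers in the comments are illustrative, not part of the PR):

    from PIL import Image

    def scale_factor_to_fit(original_size, target_size=(1920, 1080)):
        # Factor by which the screenshot is shrunk to fit the canvas (1.0 if it already fits).
        width, height = original_size
        max_width, max_height = target_size
        if width <= max_width and height <= max_height:
            return 1.0
        return min(max_width / width, max_height / height)

    def preprocess(image, canvas=(1920, 1080)):
        # Downscale if needed, then paste onto a fixed-size black canvas at (0, 0).
        scale = scale_factor_to_fit(image.size, canvas)
        if scale < 1.0:
            image = image.resize((int(image.width * scale), int(image.height * scale)), Image.LANCZOS)
        padded = Image.new("RGB", canvas)
        padded.paste(image)
        return padded, scale

    # Illustrative: a 3840x2160 screenshot gets scale = 0.5, and a coordinate c emitted by
    # the model maps back to the original image as 2 * c / scale (see tokens_to_box below).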
    	
app.py CHANGED

@@ -1,9 +1,8 @@
 import gradio as gr
+import re
 import torch
-from transformers import FuyuForCausalLM, AutoTokenizer
-from transformers.models.fuyu.processing_fuyu import FuyuProcessor
-from transformers.models.fuyu.image_processing_fuyu import FuyuImageProcessor
 from PIL import Image
+from transformers import AutoTokenizer, FuyuForCausalLM, FuyuImageProcessor, FuyuProcessor
 
 model_id = "adept/fuyu-8b"
 dtype = torch.bfloat16
@@ -13,9 +12,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = FuyuForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=dtype)
 processor = FuyuProcessor(image_processor=FuyuImageProcessor(), tokenizer=tokenizer)
 
-
+CAPTION_PROMPT = "Generate a coco-style caption.\n"
+DETAILED_CAPTION_PROMPT = "What is happening in this image?"
 
-def resize_to_max(image, max_width=1080, max_height=1080):
+def resize_to_max(image, max_width=1920, max_height=1080):
     width, height = image.size
     if width <= max_width and height <= max_height:
         return image
@@ -26,23 +26,101 @@ def resize_to_max(image, max_width=1080, max_height=1080):
 
     return image.resize((width, height), Image.LANCZOS)
 
+def pad_to_size(image, canvas_width=1920, canvas_height=1080):
+    width, height = image.size
+    if width >= canvas_width and height >= canvas_height:
+        return image
+
+    # Paste at (0, 0)
+    canvas = Image.new("RGB", (canvas_width, canvas_height))
+    canvas.paste(image)
+    return canvas
+
 def predict(image, prompt):
     # image = image.convert('RGB')
-    image = resize_to_max(image)
-
     model_inputs = processor(text=prompt, images=[image])
     model_inputs = {k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device) for k,v in model_inputs.items()}
 
-    generation_output = model.generate(**model_inputs, max_new_tokens=
+    generation_output = model.generate(**model_inputs, max_new_tokens=50)
     prompt_len = model_inputs["input_ids"].shape[-1]
     return tokenizer.decode(generation_output[0][prompt_len:], skip_special_tokens=True)
 
-def caption(image):
-
+def caption(image, detailed_captioning):
+    if detailed_captioning:
+        caption_prompt = DETAILED_CAPTION_PROMPT
+    else:
+        caption_prompt = CAPTION_PROMPT
+    return predict(image, caption_prompt).lstrip()
 
 def set_example_image(example: list) -> dict:
     return gr.Image.update(value=example[0])
 
+def scale_factor_to_fit(original_size, target_size=(1920, 1080)):
+    width, height = original_size
+    max_width, max_height = target_size
+    if width <= max_width and height <= max_height:
+        return 1.0
+    return min(max_width/width, max_height/height)
+
+def tokens_to_box(tokens, original_size):
+    bbox_start = tokenizer.convert_tokens_to_ids("<0x00>")
+    bbox_end = tokenizer.convert_tokens_to_ids("<0x01>")
+    try:
+        # Assumes a single box
+        bbox_start_pos = (tokens == bbox_start).nonzero(as_tuple=True)[0].item()
+        bbox_end_pos = (tokens == bbox_end).nonzero(as_tuple=True)[0].item()
+
+        if bbox_end_pos != bbox_start_pos + 5:
+            return tokens
+
+        # Retrieve transformed coordinates from tokens
+        coords = tokenizer.convert_ids_to_tokens(tokens[bbox_start_pos+1:bbox_end_pos])
+
+        # Scale back to original image size and multiply by 2
+        scale = scale_factor_to_fit(original_size)
+        top, left, bottom, right = [2 * int(float(c)/scale) for c in coords]
+
+        # Replace the IDs so they get detokenized right
+        replacement = f" <box>{top}, {left}, {bottom}, {right}</box>"
+        replacement = tokenizer.tokenize(replacement)[1:]
+        replacement = tokenizer.convert_tokens_to_ids(replacement)
+        replacement = torch.tensor(replacement).to(tokens)
+
+        tokens = torch.cat([tokens[:bbox_start_pos], replacement, tokens[bbox_end_pos+1:]], 0)
+        return tokens
+    except:
+        gr.Error("Can't convert tokens.")
+        return tokens
+
+def coords_from_response(response):
+    # y1, x1, y2, x2
+    pattern = r"<box>(\d+),\s*(\d+),\s*(\d+),\s*(\d+)</box>"
+
+    match = re.search(pattern, response)
+    if match:
+        # Unpack and change order
+        y1, x1, y2, x2 = [int(coord) for coord in match.groups()]
+        return (x1, y1, x2, y2)
+    else:
+        gr.Error("The string is malformed or does not match the expected pattern.")
+
+def localize(image, query):
+    prompt = f"When presented with a box, perform OCR to extract text contained within it. If provided with text, generate the corresponding bounding box.\n{query}"
+
+    # Downscale and/or pad to 1920x1080
+    padded = resize_to_max(image)
+    padded = pad_to_size(padded)
+
+    model_inputs = processor(text=prompt, images=[padded])
+    model_inputs = {k: v.to(dtype=dtype if torch.is_floating_point(v) else v.dtype, device=device) for k,v in model_inputs.items()}
+
+    generation_output = model.generate(**model_inputs, max_new_tokens=40)
+    prompt_len = model_inputs["input_ids"].shape[-1]
+    tokens = generation_output[0][prompt_len:]
+    tokens = tokens_to_box(tokens, image.size)
+    decoded = tokenizer.decode(tokens, skip_special_tokens=True)
+    coords = coords_from_response(decoded)
+    return image, [(coords, f"Location of \"{query}\"")]
 
 
 css = """
@@ -88,21 +166,44 @@ with gr.Blocks(css=css) as demo:
 
     with gr.Tab("Image Captioning"):
         with gr.Row():
-            
+            with gr.Column():
+                captioning_input = gr.Image(label="Upload your Image", type="pil")
+                detailed_captioning_checkbox = gr.Checkbox(label="Enable detailed captioning")
             captioning_output = gr.Textbox(label="Output")
         captioning_btn = gr.Button("Generate Caption")
 
         gr.Examples(
-            [["assets/captioning_example_1.png"], ["assets/captioning_example_2.png"]],
-            inputs = [captioning_input],
+            [["assets/captioning_example_1.png", False], ["assets/captioning_example_2.png", True]],
+            inputs = [captioning_input, detailed_captioning_checkbox],
            outputs = [captioning_output],
             fn=caption,
             cache_examples=True,
             label='Click on any Examples below to get captioning results quickly 👇'
             )
 
-    captioning_btn.click(fn=caption, inputs=captioning_input, outputs=captioning_output)
+    captioning_btn.click(fn=caption, inputs=[captioning_input, detailed_captioning_checkbox], outputs=captioning_output)
     vqa_btn.click(fn=predict, inputs=[image_input, text_input], outputs=vqa_output)
 
+    with gr.Tab("Find Text in Screenshots"):
+        with gr.Row():
+            with gr.Column():
+                localization_input = gr.Image(label="Upload your Image", type="pil")
+                query_input = gr.Textbox(label="Text to find")
+                localization_btn = gr.Button("Locate Text")
+            with gr.Column():
+                with gr.Row(height=800):
+                    localization_output = gr.AnnotatedImage(label="Text Position")
+
+        gr.Examples(
+            [["assets/localization_example_1.jpeg", "Share your repair"],
+             ["assets/screen2words_ui_example.png", "statistics"]],
+            inputs = [localization_input, query_input],
+            outputs = [localization_output],
+            fn=localize,
+            cache_examples=True,
+            label='Click on any Examples below to get localization results quickly 👇'
+            )
+
+    localization_btn.click(fn=localize, inputs=[localization_input, query_input], outputs=localization_output)
 
 demo.launch(server_name="0.0.0.0")
    	
assets/localization_example_1.jpeg ADDED (binary image, no text diff)
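
To sanity-check the new tab's backend without launching the Gradio UI, the localize function added above can be called directly. A minimal sketch, assuming app.py has been loaded with demo.launch() skipped (e.g. in a REPL) and using the example asset added in this PR:

    from PIL import Image

    img = Image.open("assets/localization_example_1.jpeg")
    image, annotations = localize(img, "Share your repair")

    # localize returns (original_image, [((x1, y1, x2, y2), label)]); the box can be None
    # if the model produced no <box> span for the query.
    box, label = annotations[0]
    print(label, box)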
 
			

 
		