Commit d1bffba
Parent(s): 297686d
adding app with CLIP image segmentation
- app.py +93 -0
- images/image2.png +0 -0
- images/room.jpg +0 -0
- requirements.txt +11 -0
    	
app.py ADDED
    
import gradio as gr
import numpy as np
import torch
import cv2
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from skimage.measure import label, regionprops

# CLIPSeg: text-prompted image segmentation
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")


def create_mask(image, image_mask, alpha=0.7):
    mask = np.zeros_like(image)
    # broadcast the single-channel mask to all three color channels
    for i in range(3):
        mask[:, :, i] = image_mask.copy()
    # blend the mask over the image
    overlay_image = cv2.addWeighted(mask, alpha, image, 1 - alpha, 0)
    return overlay_image


def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
    # map a bounding box from the model's output grid back to the original image size
    bbox = np.asarray(bbox) / model_shape
    y1, y2 = bbox[::2] * orig_image_shape[0]
    x1, x2 = bbox[1::2] * orig_image_shape[1]
    return [int(y1), int(x1), int(y2), int(x2)]


def detect_using_clip(image, prompts=[], threshold=0.4):
    model_detections = dict()
    predicted_images = dict()
    inputs = processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():  # disable gradient computation for inference
        outputs = model(**inputs)
    preds = outputs.logits.unsqueeze(1)

    for i, prompt in enumerate(prompts):
        predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
        # binarize the probability map; uint8 is what cv2.resize and label() expect
        predicted_image = np.where(predicted_image > threshold, 255, 0).astype(np.uint8)
        # extract connected components and their bounding boxes
        lbl_0 = label(predicted_image)
        props = regionprops(lbl_0)
        prompt = prompt.lower()
        model_detections[prompt] = [
            rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0])
            for prop in props
        ]
        # cv2.resize takes (width, height)
        predicted_images[prompt] = cv2.resize(predicted_image, (image.shape[1], image.shape[0]))
    return model_detections, predicted_images


def visualize_images(image, detections, predicted_image, prompt):
    alpha = 0.7
    prompt = prompt.lower()
    image_copy = image.copy()
    mask_image = create_mask(image=image_copy, image_mask=predicted_image)

    if prompt not in detections.keys():
        print("prompt not in query ..")
        return image_copy
    for bbox in detections[prompt]:
        # bbox is [y1, x1, y2, x2]; cv2 points are (x, y)
        cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
        cv2.putText(image_copy, str(prompt), (int(bbox[1]), int(bbox[0])), cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
    final_image = cv2.addWeighted(image_copy, alpha, mask_image, 1 - alpha, 0)
    return final_image


def shot(image, labels_text, selected_category):
    prompts = labels_text.split(',')
    prompts = list(map(lambda x: x.strip(), prompts))

    model_detections, predicted_images = detect_using_clip(image, prompts=prompts)

    category = selected_category.strip().lower()
    # fall back to an empty mask when the selected category was not among the prompts
    mask = predicted_images.get(category, np.zeros(image.shape[:2], dtype=np.uint8))
    category_image = visualize_images(image=image, detections=model_detections,
                                      predicted_image=mask, prompt=category)
    return category_image


iface = gr.Interface(fn=shot,
                     inputs=["image", "text", "text"],
                     outputs="image",
                     description="Add an image and a comma-separated list of categories to detect",
                     title="Zero-shot Image Segmentation with Prompts",
                     examples=[
                         ["images/room.jpg", "bed, table, plant, light, window", "plant"],
                         ["images/image2.png", "banner, building, door, sign", "sign"],
                     ],
                     # allow_flagging=False,
                     # analytics_enabled=False,
                     )
iface.launch()
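For reference, a minimal standalone sketch of the CLIPSeg call that detect_using_clip wraps; it is not part of the commit, but the checkpoint, threshold, and example image path match app.py above.

import torch
import numpy as np
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# same checkpoint as app.py
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# one of the example images shipped in this commit
image = Image.open("images/room.jpg")
prompts = ["bed", "plant"]
inputs = processor(text=prompts, images=[image] * len(prompts),
                   padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# one 352x352 probability map per prompt, thresholded at 0.4 as in detect_using_clip
masks = (torch.sigmoid(outputs.logits) > 0.4).numpy().astype(np.uint8) * 255
print(masks.shape)  # (2, 352, 352)

Each prompt yields one 352x352 probability map; app.py thresholds it, extracts connected-component bounding boxes, and maps them back to the original image size via rescale_bbox.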
    	
images/image2.png ADDED

images/room.jpg ADDED
    	
requirements.txt ADDED
transformers
torch
sentencepiece
huggingface_hub
numpy
scikit-image
opencv-python
Pillow
requests
urllib3<2
git+https://github.com/facebookresearch/segment-anything.git
