import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from typing import List
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load CLIPSeg model and processor
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
# Load SAM model and processor
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Cache for the most recent image's preprocessed pixel values and SAM embedding
cache_data = None
# Prompts to segment damaged area and car
prompts = ['damaged', 'car']
damage_threshold = 0.4
vehicle_threshold = 0.5
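# Pipeline overview: CLIPSeg scores the image against the text prompts to find a
# coarse "damaged" region and to confirm a vehicle is present; the damage region's
# bounding box is then handed to SAM as a box prompt to produce a refined mask.
# The two thresholds above gate which pixels count as damage/vehicle and can be
# tuned to trade precision against recall.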
def bbox_normalization(bbox, width, height):
    # Rescale a bbox from CLIPSeg's 352x352 output grid to the original image size
    height_coeff = height/352
    width_coeff = width/352
    normalized_bbox = [int(bbox[0]*width_coeff), int(bbox[1]*height_coeff),
                       int(bbox[2]*width_coeff), int(bbox[3]*height_coeff)]
    return normalized_bbox
def bbox_area(bbox):
    # Area of an [x_min, y_min, x_max, y_max] box
    area = (bbox[2]-bbox[0])*(bbox[3]-bbox[1])
    return area
def segment_to_bbox(segment_indexs):
    # Collect every pixel assigned to the class (index 1) and return the
    # enclosing box as [x_min, y_min, x_max, y_max]
    x_points = []
    y_points = []
    for y, list_val in enumerate(segment_indexs):
        for x, val in enumerate(list_val):
            if val == 1:
                x_points.append(x)
                y_points.append(y)
    if not x_points:
        # No pixel cleared the threshold: return an empty box
        return [0, 0, 0, 0]
    return [np.min(x_points), np.min(y_points), np.max(x_points), np.max(y_points)]
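# Illustrative example (hypothetical values): for a 352x352 label map whose "1"
# pixels span columns 100-180 and rows 40-90, segment_to_bbox returns
# [100, 40, 180, 90]; bbox_normalization then rescales that box to the original
# image resolution.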
def clipseg_prediction(image):
    # CLIPSeg scores every pixel of a 352x352 grid against each text prompt
    inputs = clip_processor(text=prompts, images=[image] * len(prompts), padding="max_length", return_tensors="pt").to(device)
    # predict
    with torch.no_grad():
        outputs = clip_model(**inputs)
    preds = outputs.logits.unsqueeze(1)
    # Apply the thresholds to decide whether the image contains a vehicle and a damaged area
    flat_preds = torch.sigmoid(preds.squeeze()).reshape((preds.shape[0], -1)).cpu()
    # Initialize dummy "unlabeled" masks filled with the thresholds
    flat_damage_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), damage_threshold)
    flat_vehicle_preds_with_threshold = torch.full((2, flat_preds.shape[-1]), vehicle_threshold)
    flat_damage_preds_with_threshold[1:2, :] = flat_preds[0]   # damage
    flat_vehicle_preds_with_threshold[1:2, :] = flat_preds[1]  # vehicle
    # Get the top mask index for each pixel
    damage_inds = torch.topk(flat_damage_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    vehicle_inds = torch.topk(flat_vehicle_preds_with_threshold, 1, dim=0).indices.reshape((preds.shape[-2], preds.shape[-1]))
    # bbox creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)
    # Vehicle check: the vehicle region should cover more area than the damage region
    if bbox_area(vehicle_bbox) > bbox_area(damage_bbox):
        height, width = image.shape[0], image.shape[1]
        return True, bbox_normalization(damage_bbox, width, height)
    else:
        return False, []
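# Illustrative call (hypothetical numbers): for a 1280x720 photo with visible
# damage, clipseg_prediction(img) might return (True, [412, 198, 655, 340]), i.e.
# a damage bounding box in original-image pixel coordinates; for an image without
# a plausible vehicle it returns (False, []).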
     
@torch.no_grad()
def forward_pass(image_input: np.ndarray, bbox: List[int]) -> np.ndarray:
    global cache_data
    image_input = Image.fromarray(image_input)
    # Use the CLIPSeg-derived damage bounding box as a box prompt for SAM
    inputs = processor(image_input, input_boxes=[[bbox]], return_tensors="pt").to(device)
    # Re-use the cached image embedding when the same image is segmented again
    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
        embedding = model.get_image_embeddings(inputs["pixel_values"])
        cache_data = [inputs["pixel_values"], embedding]
    original_sizes = inputs.pop("original_sizes")
    reshaped_input_sizes = inputs.pop("reshaped_input_sizes")
    del inputs["pixel_values"]
    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), original_sizes.cpu(), reshaped_input_sizes.cpu()
    )
    # (H, W, num_masks) boolean masks at the original image resolution
    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
    return masks
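# Note on caching: computing the SAM image embedding is the expensive step, so it
# is cached per image (keyed by the preprocessed pixel values) and re-used when
# the same image is segmented again; only the lightweight prompt encoder and mask
# decoder run on repeat calls.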
def main_func(image_input):
    # gr.Image() delivers the uploaded picture as a numpy array
    classification, bbox = clipseg_prediction(image_input)
    if classification:
        masks = forward_pass(image_input, bbox)
        # Overlay the first predicted mask on the original image
        final_mask = masks[:, :, 0]
        mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
        mask_colors[final_mask, :] = np.array([[128, 0, 0]])
        return Image.fromarray((mask_colors * 0.6 + image_input * 0.4).astype('uint8'), 'RGB')
    else:
        # No vehicle/damage found: return the input image unchanged
        return Image.fromarray(image_input)
def reset_data():
    # Clear the cached SAM image embedding when a new image is uploaded
    global cache_data
    cache_data = None
with gr.Blocks() as demo:
    gr.Markdown("# Demo to run Vehicle damage detection")
    gr.Markdown("""This app uses the SAM model and clipseg model to get a vehicle damage area from image.""")
    with gr.Row():
        image_input = gr.Image()
        image_output = gr.Image()
    
    image_button = gr.Button("Segment Image", variant='primary')
    image_button.click(main_func, inputs=image_input, outputs=image_output)
    image_input.upload(reset_data)
demo.launch()
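# To run locally (assuming torch, transformers, gradio and Pillow are installed),
# save this file as e.g. app.py and run `python app.py`, then open the printed
# local URL. Passing share=True to demo.launch() exposes a temporary public link.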