import sys
sys.path.append('.')

from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
import numpy as np
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, show_mask
import torch
from ram_train_eval import RamModel, RamPredictor
from mmengine.config import Config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_size = 512
hidden_size = 256
num_classes = 56

# load sam model
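# SamPredictor segments from a point prompt (the thing the user clicked);
# SamAutomaticMaskGenerator proposes masks for everything else in the image.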
sam = build_sam(checkpoint="./checkpoints/sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam)
mask_generator = SamAutomaticMaskGenerator(sam)

# load ram model
model_path = "./checkpoints/ram_epoch12.pth"
config = dict(
    model=dict(
        pretrained_model_name_or_path='bert-base-uncased',
        load_pretrained_weights=False,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type='attention',
        loss_type='multi_label_ce',
    ),
    load_from=model_path,
)
config = Config(config)
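
# Thin wrapper around RamPredictor: builds the relation model from the config
# above and loads the fine-tuned checkpoint onto the available device.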
class Predictor(RamPredictor):
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._build_model()

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
        self.model.eval()  # the demo only runs inference, so use eval mode

model = Predictor(config)

# visualization
def draw_selected_mask(mask, draw):
    color = (255, 0, 0, 153)  # semi-transparent red marks the selected subject
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)

def draw_object_mask(mask, draw):
    color = (0, 0, 255, 153)  # semi-transparent blue marks the related object
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)
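
# Segment whatever the user clicked and overlay the resulting mask in red.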
def vis_selected(pil_image, coords):
    # downscale very large images so SAM runs at a manageable resolution
    w, h = pil_image.size
    max_edge = 1500
    if w > max_edge or h > max_edge:
        ratio = max(w, h) / max_edge
        new_size = (int(w / ratio), int(h / ratio))
        pil_image.thumbnail(new_size)
        # rescale the clicked coordinates to match the resized image
        x, y = map(int, coords.split(','))
        coords = f"{int(x * new_size[0] / w)},{int(y * new_size[1] / h)}"
    # get coords
    coords_x, coords_y = coords.split(',')
    input_point = np.array([[int(coords_x), int(coords_y)]])
    input_label = np.array([1])
    # load image
    image = np.array(pil_image)
    predictor.set_image(image)
    mask1, score1, logit1, feat1 = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    pil_image = pil_image.convert('RGBA')
    mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
    mask_draw = ImageDraw.Draw(mask_image)
    draw_selected_mask(mask1[0], mask_draw)
    pil_image.alpha_composite(mask_image)
    yield [pil_image]
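
# Render a "Red <relation> Blue" caption strip to paste under each result image.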
def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
    # Define the colors to use for each word
    color_red = (255, 0, 0)
    color_black = (0, 0, 0)
    color_blue = (0, 0, 255)
    # Define the initial font size
    font_size = 40
    # Create a new image with the specified width and a white background
    image = Image.new('RGB', (width, 60), (255, 255, 255))
    # Load the specified font
    font = ImageFont.truetype(font_path, font_size)
    # Keep shrinking the font until all three words fit within the desired width
    while True:
        # Create a draw object for the image
        draw = ImageDraw.Draw(image)
        word_spacing = font_size / 2
        # Draw each word in the appropriate color
        x_offset = word_spacing
        draw.text((x_offset, 0), word1, color_red, font=font)
        x_offset += font.getbbox(word1)[2] + word_spacing
        draw.text((x_offset, 0), word2, color_black, font=font)
        x_offset += font.getbbox(word2)[2] + word_spacing
        draw.text((x_offset, 0), word3, color_blue, font=font)
        # getbbox()[2] gives the text width; font.getsize was removed in Pillow 10
        word_sizes = [font.getbbox(word)[2] for word in [word1, word2, word3]]
        total_width = sum(word_sizes) + word_spacing * 3
        # Stop once the rendered title fits within the desired width
        if total_width <= width:
            break
        # Otherwise shrink the font and start over on a fresh image
        font_size -= 1
        image = Image.new('RGB', (width, 60), (255, 255, 255))
        font = ImageFont.truetype(font_path, font_size)
    return image

def concatenate_images_vertical(image1, image2):
    # Get the dimensions of the two images
    width1, height1 = image1.size
    width2, height2 = image2.size
    # Create a new image with the combined height and the maximum width
    new_image = Image.new('RGBA', (max(width1, width2), height1 + height2))
    # Paste the first image at the top of the new image
    new_image.paste(image1, (0, 0))
    # Paste the second image below the first image
    new_image.paste(image2, (0, height1))
    return new_image
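
# Relate Something: the clicked region becomes the subject; every other SAM mask
# is a candidate object, and RAM predicts a relation for each pair.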
def relate_selected(input_image, k, coords):
    # load image
    pil_image = input_image.convert('RGBA')
    w, h = pil_image.size
    max_edge = 1500
    if w > max_edge or h > max_edge:
        ratio = max(w, h) / max_edge
        new_size = (int(w / ratio), int(h / ratio))
        pil_image.thumbnail(new_size)
        input_image.thumbnail(new_size)
        # rescale the clicked coordinates to match the resized image
        x, y = map(int, coords.split(','))
        coords = f"{int(x * new_size[0] / w)},{int(y * new_size[1] / h)}"
    image = np.array(input_image)
    sam_masks = mask_generator.generate(image)
    # get the subject mask from the clicked point
    coords_x, coords_y = coords.split(',')
    input_point = np.array([[int(coords_x), int(coords_y)]])
    input_label = np.array([1])
    predictor.set_image(image)  # make sure the predictor uses this (possibly resized) image
    mask1, score1, logit1, feat1 = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    # drop masks that nearly duplicate the subject, then keep the top-k candidates
    filtered_masks = sort_and_deduplicate(sam_masks)
    filtered_masks = [d for d in filtered_masks if iou(d['segmentation'], mask1[0]) < 0.95][:k]
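    # Feature layout for RAM: index 0 is the subject's SAM feature, followed by
    # one feature per candidate object, giving a (1, k+1, C) tensor.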
    pil_image_list = []
    # run model
    feat = feat1
    for fm in filtered_masks:
        feat2 = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
        feat = torch.cat((feat, feat2), dim=1)
    matrix_output, rel_triplets = model.predict(feat)
    # relation logits of the subject (index 0) against every candidate object
    subject_output = matrix_output.permute([0, 2, 3, 1])[:, 0, 1:]
    for i in range(len(filtered_masks)):
        output = subject_output[:, i]
        topk_indices = torch.argsort(-output).flatten()
        relation = relation_classes[topk_indices[0]]
        mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)
        draw_selected_mask(mask1[0], mask_draw)
        draw_object_mask(filtered_masks[i]['segmentation'], mask_draw)
        current_pil_image = pil_image.copy()
        current_pil_image.alpha_composite(mask_image)
        title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
        concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
        pil_image_list.append(concate_pil_image)
        # yield inside the loop so results stream into the gallery one by one
        yield pil_image_list
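
# Relate Anything: segment the whole image with SAM and let RAM rank the top-k
# (subject, object, relation) triplets among all mask pairs.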
def relate_anything(input_image, k):
    # downscale very large images
    w, h = input_image.size
    max_edge = 1500
    if w > max_edge or h > max_edge:
        ratio = max(w, h) / max_edge
        new_size = (int(w / ratio), int(h / ratio))
        input_image.thumbnail(new_size)
    # load image
    pil_image = input_image.convert('RGBA')
    image = np.array(input_image)
    sam_masks = mask_generator.generate(image)
    filtered_masks = sort_and_deduplicate(sam_masks)
    # stack one SAM feature per mask into a (1, N, C) tensor for RAM
    feat_list = []
    for fm in filtered_masks:
        feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
        feat_list.append(feat)
    feat = torch.cat(feat_list, dim=1).to(device)
    matrix_output, rel_triplets = model.predict(feat)
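    # rel_triplets comes back ranked best-first; each entry is
    # (subject mask index, object mask index, relation class id)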
    pil_image_list = []
    for i, rel in enumerate(rel_triplets[:k]):
        s, o, r = int(rel[0]), int(rel[1]), int(rel[2])
        relation = relation_classes[r]
        mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)
        draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
        draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
        current_pil_image = pil_image.copy()
        current_pil_image.alpha_composite(mask_image)
        title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
        concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
        pil_image_list.append(concate_pil_image)
        yield pil_image_list

DESCRIPTION = '''# Relate-Anything
### 🚀 🚀 🚀 RAM (Relate-Anything-Model) combines Meta's Segment-Anything model with the ECCV'22 paper: [Panoptic Scene Graph Generation](https://psgdataset.org/).
### 🤖 🤖 🤖 Given an image, RAM finds all the meaningful relations between anything. (Check Tab: Relate Anything)
### 🖱️ 🖱️ 🖱️ You can also click something in the image, and RAM finds anything related to it. (Check Tab: Relate Something)
### 🔥 🔥 🔥 Please star our codebases [OpenPSG](https://github.com/Jingkang50/OpenPSG) and [RAM](https://github.com/Luodian/RelateAnything) if you find them useful / interesting.
### It is recommended to upgrade to a GPU in Settings after duplicating this Space: [Duplicate this Space](https://huggingface.co/spaces/mmlab-ntu/relate-anything-model?duplicate=true)
### Here is a 3-day Gradio link: https://bf5e65e511446cbe60.gradio.live/, expires by 3am, April 28, Singapore Time.
'''
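
# Gradio UI: the left column holds the input image and the two task tabs;
# the right column shows the result gallery.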
block = gr.Blocks()
block = block.queue()
with block:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source="upload", type="pil", value="assets/dog.jpg")
            with gr.Tab("Relate Anything"):
                num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1)
                relate_all_button = gr.Button(value="Relate Anything!")
            with gr.Tab("Relate Something"):
                img_input_coords = gr.Textbox(label="Click something to get input coords")

                def select_handler(evt: gr.SelectData):
                    coords = evt.index
                    return f"{coords[0]},{coords[1]}"

                input_image.select(select_handler, None, img_input_coords)
                run_button_vis = gr.Button(value="Visualize the Selected Thing")
                selected_gallery = gr.Gallery(label="Selected Thing", show_label=True, elem_id="selected_gallery").style(object_fit="scale-down")
                k = gr.Slider(label="Number of things you want to relate", minimum=1, maximum=20, value=5, step=1)
                relate_selected_button = gr.Button(value="Relate it with Anything", interactive=True)
        with gr.Column():
            image_gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="gallery").style(preview=True, columns=5, object_fit="scale-down")

    # relate anything
    relate_all_button.click(fn=relate_anything, inputs=[input_image, num_relation], outputs=[image_gallery], show_progress=True, queue=True)

    # relate selected
    run_button_vis.click(fn=vis_selected, inputs=[input_image, img_input_coords], outputs=[selected_gallery], show_progress=True, queue=True)
    relate_selected_button.click(fn=relate_selected, inputs=[input_image, k, img_input_coords], outputs=[image_gallery], show_progress=True, queue=True)

block.launch()
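
# Note: this demo assumes the SAM and RAM checkpoints exist under ./checkpoints/
# and that segment_anything is the project's modified build whose predict() also
# returns per-mask features (the 'feat' fields used above).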