from typing import Any, Dict, List, Optional, Tuple

from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image

from ..data import ImageFolderDataset
from ..models import create_diffusion_model, create_segmentation_model
from ..utils import get_object_mask


class InpaintPipeline:
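    """Inpaint objects in a folder of images.

    The pipeline chains a segmentation model, which produces per-image object
    masks, with a control-model-conditioned Stable Diffusion inpainting model
    that regenerates the masked regions from a text prompt.
    """
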
    def __init__(
        self,
        segmentation_model_name: str,
        diffusion_model_name: str,
        control_model_name: str,
        images_root: str,
        prompts_path: Optional[str] = None,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting",
        image_size: Tuple[int, int] = (512, 512),
        image_extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
        segmentation_model_size: str = "large",
    ):
        self.segmentation_model = create_segmentation_model(
            segmentation_model_name=segmentation_model_name,
            model_size=segmentation_model_size,
        )

        self.diffusion_model = create_diffusion_model(
            diffusion_model_name=diffusion_model_name,
            control_model_name=control_model_name,
            sd_model_name=sd_model_name,
        )

        self.data_loader = self.build_data_loader(
            images_root=images_root,
            prompts_path=prompts_path,
            image_size=image_size,
            image_extensions=image_extensions,
        )

    def build_data_loader(
        self,
        images_root: str,
        prompts_path: Optional[str] = None,
        image_size: Tuple[int, int] = (512, 512),
        image_extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp"),
        batch_size: int = 1,
    ) -> DataLoader:
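        """Build a sequential DataLoader over ``images_root``.

        ``shuffle`` is fixed to False so results stay aligned with the
        dataset's file order.
        """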
        dataset = ImageFolderDataset(
            images_root, prompts_path, image_size, image_extensions
        )
        data_loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=8
        )

        return data_loader

    def run(self, data_loader: Optional[DataLoader] = None) -> List[Dict[str, Any]]:
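        """Segment and inpaint every batch in the data loader.

        If ``data_loader`` is given, it replaces the loader built at init
        time. Returns the inpainted output images from all batches, in order.
        """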
        if data_loader is not None:
            self.data_loader = data_loader

        results = []
        for images, prompts in self.data_loader:
            # The DataLoader yields batched tensors; the models expect PIL images.
            images = [to_pil_image(img) for img in images]

            semantic_maps = self.segmentation_model.process(images)

            # Binary mask per image for the target class; class_id=0 selects
            # the first class in the segmentation map (adjust as needed).
            object_masks = [
                get_object_mask(seg_map, class_id=0) for seg_map in semantic_maps
            ]

            outputs = self.diffusion_model.process(
                images=images,
                prompts=list(prompts),  # one prompt per image in the batch
                mask_images=object_masks,
                negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            )
            results += outputs["output_images"]

        return results
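

# Minimal usage sketch (commented out because this module uses relative
# imports and is meant to be imported from its package, not run directly).
# The import path and model names below are placeholders, not values this
# package necessarily accepts; substitute the names your factories support.
#
#     from <package>.pipelines.inpaint import InpaintPipeline
#
#     pipeline = InpaintPipeline(
#         segmentation_model_name="oneformer",           # placeholder
#         diffusion_model_name="sd_controlnet_inpaint",  # placeholder
#         control_model_name="lllyasviel/control_v11p_sd15_inpaint",  # placeholder
#         images_root="data/images",
#         prompts_path="data/prompts.txt",
#     )
#     inpainted_images = pipeline.run()  # list of output images, one per input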