File size: 2,185 Bytes
45ce449
6853258
9eb98e1
6853258
 
 
 
 
e68ce70
e26163a
 
6853258
 
365daee
6853258
 
 
 
 
dd47fae
10fe40c
6853258
365daee
00ef2b0
 
365daee
00ef2b0
6853258
 
 
 
 
 
 
 
 
10fe40c
6853258
 
 
 
 
 
e26163a
6853258
10fe40c
 
 
 
 
6853258
 
 
00ef2b0
6853258
 
 
4a092d4
 
 
6853258
10fe40c
 
 
331df7b
48de697
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from PIL import Image
import torch
from transformers.agents.tools import Tool
from transformers.utils import (
    is_accelerate_available,
    is_vision_available,
)

from diffusers import DiffusionPipeline
if is_accelerate_available():
    from accelerate import PartialState

# Natural-language summary of the tool; assigned to the tool class's
# `description` attribute below.
IMAGE_TRANSFORMATION_DESCRIPTION = (
    "This is a tool that transforms an image according to a prompt "
    "and returns the modified image."
)


class ImageTransformationTool(Tool):
    """Agent tool that modifies an image according to a text prompt.

    The tool wraps a diffusers pipeline (by default the
    ``timbrooks/instruct-pix2pix`` checkpoint). The pipeline is not loaded in
    ``__init__``; it is created by :meth:`setup`, which :meth:`forward` calls
    on first use.
    """

    name = "image_transformation"
    default_stable_diffusion_checkpoint = "timbrooks/instruct-pix2pix"
    description = IMAGE_TRANSFORMATION_DESCRIPTION
    inputs = {
        'image': {"type": "image", "description": "the image to transform"},
        'prompt': {"type": "string", "description": "the prompt to use to change the image"}
    }
    output_type = "image"

    def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
        """Store the tool configuration without loading any model weights.

        Args:
            device: Device to run the pipeline on (a ``torch.device`` or
                anything ``torch.device`` accepts, e.g. ``"cuda"``). When
                ``None``, the default device reported by accelerate is
                picked in :meth:`setup`.
            controlnet: Accepted for backward compatibility; not used by
                this tool.
            stable_diffusion: Checkpoint identifier passed to
                ``DiffusionPipeline.from_pretrained``. Falls back to
                ``default_stable_diffusion_checkpoint`` when ``None``.
            **hub_kwargs: Extra keyword arguments forwarded to
                ``DiffusionPipeline.from_pretrained`` in :meth:`setup`.

        Raises:
            ImportError: If accelerate or the vision dependencies (Pillow)
                are not installed.
        """
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the ImageTransformationTool.")

        super().__init__()

        # Honour a caller-supplied checkpoint; only fall back to the class
        # default when none is given.
        if stable_diffusion is None:
            stable_diffusion = self.default_stable_diffusion_checkpoint
        self.stable_diffusion = stable_diffusion

        self.device = device
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Load the diffusion pipeline and move it to the target device."""
        if self.device is None:
            self.device = PartialState().default_device
        elif not isinstance(self.device, torch.device):
            # Callers may pass a plain string such as "cuda"; normalise it so
            # the `.type` check below works.
            self.device = torch.device(self.device)

        self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion, **self.hub_kwargs)

        self.pipeline.to(self.device)
        if self.device.type == "cuda":
            # Cast to half precision when running on a CUDA device.
            self.pipeline.to(torch_dtype=torch.float16)

        self.is_initialized = True

    def forward(self, image, prompt):
        """Transform ``image`` according to ``prompt``.

        Args:
            image: The image to transform; passed to the pipeline as-is.
            prompt: Text instruction describing the desired change. A fixed
                quality-boosting suffix is appended before it is sent to the
                pipeline.

        Returns:
            The first image of the pipeline output.
        """
        if not self.is_initialized:
            self.setup()

        negative_prompt = "low quality, bad quality, deformed, low resolution"
        added_prompt = " , highest quality, highly realistic, very high resolution"

        return self.pipeline(
            prompt + added_prompt,
            image,
            negative_prompt=negative_prompt,
            num_inference_steps=50,
        ).images[0]