File size: 2,185 Bytes
45ce449 6853258 9eb98e1 6853258 e68ce70 e26163a 6853258 365daee 6853258 dd47fae 10fe40c 6853258 365daee 00ef2b0 365daee 00ef2b0 6853258 10fe40c 6853258 e26163a 6853258 10fe40c 6853258 00ef2b0 6853258 4a092d4 6853258 10fe40c 331df7b 48de697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from PIL import Image
import torch
from transformers.agents.tools import Tool
from transformers.utils import (
is_accelerate_available,
is_vision_available,
)
from diffusers import DiffusionPipeline
if is_accelerate_available():
from accelerate import PartialState
# Summary text assigned to ImageTransformationTool.description.
IMAGE_TRANSFORMATION_DESCRIPTION = (
    "This is a tool that transforms an image according to a prompt "
    "and returns the modified image."
)
class ImageTransformationTool(Tool):
    """Tool that edits an image according to a text prompt with a diffusion pipeline.

    Construction only records configuration; the pipeline itself is loaded
    lazily by ``setup`` the first time ``forward`` runs.

    Args:
        device: Device to run the pipeline on, either a ``torch.device`` or a
            string such as ``"cuda"``. When ``None``, accelerate's
            ``PartialState().default_device`` is used at setup time.
        controlnet: Accepted for signature compatibility; currently unused.
        stable_diffusion: Hub id or local path of the diffusion checkpoint.
            When ``None``, ``default_stable_diffusion_checkpoint`` is used.
        **hub_kwargs: Extra keyword arguments forwarded to
            ``DiffusionPipeline.from_pretrained``.

    Raises:
        ImportError: If accelerate or Pillow is not installed.
    """

    name = "image_transformation"
    default_stable_diffusion_checkpoint = "timbrooks/instruct-pix2pix"
    description = IMAGE_TRANSFORMATION_DESCRIPTION
    inputs = {
        "image": {"type": "image", "description": "the image to transform"},
        "prompt": {"type": "string", "description": "the prompt to use to change the image"},
    }
    output_type = "image"

    def __init__(self, device=None, controlnet=None, stable_diffusion=None, **hub_kwargs) -> None:
        if not is_accelerate_available():
            raise ImportError("Accelerate should be installed in order to use tools.")
        if not is_vision_available():
            raise ImportError("Pillow should be installed in order to use the ImageTransformationTool.")

        super().__init__()

        # Fall back to the default checkpoint only when the caller did not supply one.
        self.stable_diffusion = (
            stable_diffusion if stable_diffusion is not None else self.default_stable_diffusion_checkpoint
        )
        self.device = device
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """Resolve the device, load the diffusion pipeline onto it, and mark the tool ready."""
        if self.device is None:
            self.device = PartialState().default_device
        # Normalize so that string devices ("cuda", "cuda:0", "cpu") work with `.type` below.
        self.device = torch.device(self.device)

        load_kwargs = dict(self.hub_kwargs)
        if self.device.type == "cuda":
            # Load weights in half precision on GPU; a `torch_dtype` the caller
            # passed through hub_kwargs takes precedence.
            load_kwargs.setdefault("torch_dtype", torch.float16)

        self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion, **load_kwargs)
        self.pipeline.to(self.device)
        self.is_initialized = True

    def forward(self, image, prompt):
        """Transform ``image`` according to ``prompt`` and return the first generated image.

        Loads the pipeline on first call if ``setup`` has not run yet.

        Args:
            image: The image to transform (presumably a PIL image, given the
                Pillow requirement; confirm against the agent's input handling).
            prompt: Text instruction describing the desired change.

        Returns:
            The first element of the pipeline output's ``images``.
        """
        if not self.is_initialized:
            self.setup()

        negative_prompt = "low quality, bad quality, deformed, low resolution"
        # Quality modifiers appended to every user prompt.
        added_prompt = " , highest quality, highly realistic, very high resolution"

        output = self.pipeline(
            prompt + added_prompt,
            image=image,
            negative_prompt=negative_prompt,
            num_inference_steps=50,
        )
        return output.images[0]