File size: 605 Bytes
f428b3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from transformers.agents.tools import Tool
from huggingface_hub import InferenceClient


class TextToImageTool(Tool):
    """Agent tool that generates an image from a text prompt.

    Generation is delegated to ``huggingface_hub.InferenceClient.text_to_image``
    using the model id stored in ``model_sdxl``.
    """

    model_sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
    # Points at the model that ``forward`` actually queries, so the advertised
    # checkpoint and the real one cannot drift apart.
    default_checkpoint = model_sdxl
    description = "This is a tool that creates an image according to a prompt, which is a text description."
    name = "image_generator"
    inputs = {"prompt": {"type": "text", "description": "the image description"}}
    output_type = "image"
    # The inference client is built lazily on first use instead of when the
    # class body runs: importing this module then has no side effects, and an
    # instance/subclass that overrides ``model_sdxl`` gets a matching client.
    client = None

    def forward(self, prompt):
        """Generate an image for *prompt*.

        Args:
            prompt: Text description of the image to create.

        Returns:
            The value returned by ``InferenceClient.text_to_image`` (a PIL
            image according to the huggingface_hub docs — confirm against the
            installed version).
        """
        if self.client is None:
            # Assigned on the instance, so each tool instance owns its client
            # and the class attribute stays ``None``.
            self.client = InferenceClient(self.model_sdxl)
        return self.client.text_to_image(prompt)