File size: 1,894 Bytes
4a1e480
1f7551f
4a1e480
1f7551f
4a1e480
1f7551f
 
4a1e480
e0d4444
4a1e480
 
3a07267
4a1e480
 
 
 
69481a1
4a1e480
 
9e8370c
69481a1
4a1e480
 
 
69481a1
4a1e480
 
69d41c4
69481a1
 
9e8370c
 
69481a1
4a1e480
 
69481a1
e0d4444
4a1e480
69481a1
4a1e480
69481a1
 
4a1e480
69481a1
4a1e480
 
e0d4444
4a1e480
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List, Any
import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

# Configurar el dispositivo para ejecutar el modelo
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    raise ValueError("Se requiere ejecutar en GPU")

# Configurar el tipo de dato mixto basado en la capacidad de la GPU
dtype = torch.bfloat16 if torch.cuda.get_device_capability(device.index)[0] >= 8 else torch.float16

class EndpointHandler():
    def __init__(self):
        # Inicializar aquí si es necesario
        pass

    def __call__(self, data: Any) -> List[Any]:
        # Configurar el número de imágenes por prompt
        num_images_per_prompt = 1

        # Cargar los modelos con el tipo de dato y dispositivo correctos
        prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
        decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)

        prompt = data.get("inputs", "Una imagen interesante")  # Asegúrate de pasar un prompt adecuado
        negative_prompt = data.get("negative_prompt", "")

        prior_output = prior(
            prompt=prompt,
            height=512,
            width=512,
            negative_prompt=negative_prompt,
            guidance_scale=7.5,
            num_inference_steps=50,
            num_images_per_prompt=num_images_per_prompt,
        )

        decoder_output = decoder(
            image_embeddings=prior_output["image_embeddings"].half(),
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=7.5,
            output_type="pil",
            num_inference_steps=20
        )

        # Asumiendo que quieres retornar la primera imagen
        return [decoder_output.images[0]]