File size: 4,119 Bytes
87d2db3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from gradio_client import Client
from typing import Dict, Tuple, Optional, Union
from dataclasses import dataclass
import logging

@dataclass
class ImageGenerationParams:
    """Keyword arguments forwarded to the deployment's ``/infer`` endpoint.

    Only ``prompt`` is required; every other field holds the value the
    wrapper sends when the caller does not override it. The declaration
    order below also defines the positional order of the generated
    ``__init__``, so it must not be rearranged.
    """

    prompt: str                      # text prompt (required)
    seed: float = 0                  # seed value sent to the API
    randomize_seed: bool = True      # sent as-is; presumably tells the server to ignore `seed` — confirm
    width: float = 1024              # output width sent to the API (presumably pixels — confirm)
    height: float = 1024             # output height sent to the API (presumably pixels — confirm)
    guidance_scale: float = 3.5      # guidance scale sent to the API
    num_inference_steps: float = 28  # inference step count sent to the API
    lora_scale: float = 0.7          # LoRA scale sent to the API

class ImageGenerationResult:
    """Wrap one image returned by the generation API together with its seed.

    ``image_data`` is normally a file-description mapping; the keys read here
    are ``path``, ``url``, ``size``, ``orig_name``, ``mime_type``,
    ``is_stream`` and ``meta``. A bare string is also accepted and treated as
    the file path. ``None`` yields a result whose file fields are all ``None``.
    """

    def __init__(self, image_data: Union[Dict, str, None], seed: float):
        """
        Args:
            image_data: File-description mapping, a file path string, or None.
            seed: Seed value reported alongside the image.

        Raises:
            TypeError: If ``image_data`` is not a mapping, a string, or None.
        """
        if image_data is None:
            image_data = {}
        elif isinstance(image_data, str):
            # NOTE(review): some gradio_client versions return a local file
            # path string for image outputs instead of a dict — confirm for
            # the client version in use.
            image_data = {'path': image_data}
        elif not hasattr(image_data, 'get'):
            raise TypeError(
                "image_data must be a mapping, str or None, got "
                f"{type(image_data).__name__}"
            )

        self.image_path = image_data.get('path')
        self.image_url = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        # Normalise an explicit ``"meta": None`` to an empty dict so callers
        # can always treat ``meta`` as a mapping.
        meta = image_data.get('meta')
        self.meta = meta if meta is not None else {}
        self.seed = seed

    def __str__(self) -> str:
        return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(path={self.image_path!r}, "
                f"url={self.image_url!r}, seed={self.seed!r})")

class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment.

    Connects with ``gradio_client.Client`` on construction. ``generate`` and
    ``generate_simple`` call the deployment's ``/infer`` endpoint and wrap the
    ``(image_data, seed)`` reply in an ``ImageGenerationResult``.
    """

    def __init__(self, api_url: str):
        """
        Initialize the wrapper and connect to the deployment.

        Args:
            api_url (str): The URL of the Gradio deployment

        Raises:
            ConnectionError: If the Gradio client cannot be created.
        """
        self.api_url = api_url
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url)
        except Exception as e:
            self.logger.error("Failed to connect to API at %s: %s", api_url, e)
            # Chain the original error so its traceback is not lost.
            raise ConnectionError(f"Failed to connect to API: {str(e)}") from e
        self.logger.info("Successfully connected to API at %s", api_url)

    @staticmethod
    def _coerce_params(params: Union[ImageGenerationParams, Dict]) -> ImageGenerationParams:
        """Return *params* as a validated ``ImageGenerationParams`` or raise ``ValueError``."""
        if isinstance(params, dict):
            try:
                params = ImageGenerationParams(**params)
            except TypeError as e:
                # Unknown keys or a missing 'prompt' surface as TypeError
                # from the dataclass constructor; report them as bad input.
                raise ValueError(f"Invalid generation parameters: {e}") from e

        prompt = params.prompt
        # Reject None/"" and whitespace-only prompts alike.
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            raise ValueError("Prompt cannot be empty")
        return params

    def generate(self,
                 params: Union[ImageGenerationParams, Dict],
                 ) -> ImageGenerationResult:
        """
        Generate an image using the provided parameters.

        Args:
            params: Either an ImageGenerationParams object or a dictionary
                whose keys are ImageGenerationParams field names.

        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            ValueError: If the parameters are invalid (unknown/missing dict
                keys, or an empty/whitespace-only prompt).
            RuntimeError: If the API call fails or returns an unexpected
                response shape.
        """
        # Validate before the try so ValueError reaches the caller unchanged
        # instead of being re-wrapped as RuntimeError below.
        params = self._coerce_params(params)

        try:
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name="/infer"
            )

            # The endpoint is expected to return exactly (image_data, seed);
            # reject other shapes, e.g. a 2-character string that would
            # otherwise unpack "successfully".
            if not isinstance(result, (tuple, list)) or len(result) != 2:
                raise RuntimeError("Invalid response from API")

            image_data, seed = result
            return ImageGenerationResult(image_data, seed)

        except Exception as e:
            self.logger.exception("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {str(e)}") from e

    def generate_simple(self,
                        prompt: str,
                        **kwargs) -> ImageGenerationResult:
        """
        Simplified interface for generating images.

        Args:
            prompt (str): The prompt for image generation
            **kwargs: Optional ImageGenerationParams fields overriding the
                defaults.

        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            TypeError: If ``kwargs`` contains a name that is not an
                ImageGenerationParams field.
            ValueError: If the prompt is empty or whitespace-only.
            RuntimeError: If the API call fails.
        """
        params = ImageGenerationParams(prompt=prompt, **kwargs)
        return self.generate(params)