Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	File size: 4,119 Bytes
			
			| 87d2db3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | from gradio_client import Client
from typing import Dict, Tuple, Optional, Union
from dataclasses import dataclass
import logging
@dataclass
class ImageGenerationParams:
    """Container for the arguments that describe a single generation request.

    ``prompt`` is the only required field; every other field carries a default
    and can be overridden by keyword (or positionally, in declaration order).
    """

    # Text description of the image to generate; the only mandatory value.
    prompt: str

    # Seed handling.
    # NOTE(review): seed/width/height/num_inference_steps are annotated float
    # but default to ints — confirm the endpoint accepts either.
    seed: float = 0
    randomize_seed: bool = True

    # Requested image size.
    width: float = 1024
    height: float = 1024

    # Remaining generation knobs, passed through unchanged by the wrapper.
    guidance_scale: float = 3.5
    num_inference_steps: float = 28
    lora_scale: float = 0.7
class ImageGenerationResult:
    """Normalized view of one image plus the seed returned by the API.

    Args:
        image_data: The image part of the API response. Accepted forms:
            - a mapping with any of the keys ``path``, ``url``, ``size``,
              ``orig_name``, ``mime_type``, ``is_stream``, ``meta``
              (missing keys become ``None``; ``is_stream`` defaults to
              ``False`` and ``meta`` to an empty dict);
            - a plain filepath string (stored as ``image_path``);
            - ``None`` (every field takes its default).
        seed: The seed value returned alongside the image.
    """
    def __init__(self, image_data: Union[Dict, str, None], seed: float):
        if image_data is None:
            image_data = {}
        elif isinstance(image_data, str):
            # NOTE(review): gradio_client downloads file outputs by default and
            # returns a local filepath rather than a FileData dict; normalize it
            # so the .get() lookups below don't raise AttributeError.
            image_data = {'path': image_data}
        self.image_path = image_data.get('path')
        self.image_url = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        self.meta = image_data.get('meta', {})
        self.seed = seed

    def __str__(self) -> str:
        return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment's ``/infer`` endpoint."""

    def __init__(self, api_url: str):
        """
        Initialize the wrapper with the API URL.

        Args:
            api_url (str): The URL of the Gradio deployment

        Raises:
            ConnectionError: If the Gradio client cannot be created
        """
        self.api_url = api_url
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url)
        except Exception as e:
            self.logger.error("Failed to connect to API at %s: %s", api_url, e)
            # Chain the original error so the root cause stays in the traceback.
            raise ConnectionError(f"Failed to connect to API: {str(e)}") from e
        self.logger.info("Successfully connected to API at %s", api_url)

    @staticmethod
    def _coerce_params(params):
        """Return *params* as a validated parameter object.

        Dicts are converted to ImageGenerationParams; other objects are used
        as-is and only need the attributes read by ``generate``.

        Raises:
            ValueError: If a dict has unknown/missing keys, or the prompt is
                missing, empty, or whitespace-only.
        """
        if isinstance(params, dict):
            try:
                params = ImageGenerationParams(**params)
            except TypeError as e:
                # Unknown or missing keyword -> report as invalid parameters.
                raise ValueError(f"Invalid generation parameters: {e}") from e
        prompt = getattr(params, 'prompt', None)
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            raise ValueError("Prompt cannot be empty")
        return params

    def generate(self,
                 params: Union["ImageGenerationParams", Dict],
                 ) -> "ImageGenerationResult":
        """
        Generate an image using the provided parameters.

        Args:
            params: Either an ImageGenerationParams object or a dictionary with the parameters

        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            ValueError: If parameters are invalid (raised before any API call)
            RuntimeError: If the API call fails or returns a malformed response
        """
        # Validate outside the API try-block so the documented ValueError is
        # not swallowed and re-wrapped as RuntimeError.
        try:
            params = self._coerce_params(params)
        except ValueError as e:
            self.logger.error("Invalid image generation parameters: %s", e)
            raise

        try:
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name="/infer"
            )

            # Expect exactly (image_data, seed). Require a real sequence so that,
            # e.g., a 2-character string is not silently unpacked.
            if not isinstance(result, (list, tuple)) or len(result) != 2:
                raise RuntimeError("Invalid response from API")

            image_data, seed = result
            return ImageGenerationResult(image_data, seed)

        except Exception as e:
            self.logger.error("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {str(e)}") from e

    def generate_simple(self,
                        prompt: str,
                        **kwargs) -> "ImageGenerationResult":
        """
        Simplified interface for generating images.

        Args:
            prompt (str): The prompt for image generation
            **kwargs: Optional ImageGenerationParams fields to override defaults

        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            TypeError: If kwargs contains a name that is not an ImageGenerationParams field
            ValueError, RuntimeError: As raised by ``generate``
        """
        params = ImageGenerationParams(prompt=prompt, **kwargs)
        return self.generate(params)
