# This wrapper provides several features:

# 1. An `ImageGenerationParams` dataclass that holds the request parameters and their default values
# 2. An `ImageGenerationResult` class that wraps the API response
# 3. The main `ImagenWrapper` class, with:
#    - Initialization that raises `ConnectionError` if the deployment cannot be reached
#    - Logging through the standard `logging` module
#    - Two methods for generation:
#      - `generate()`, which accepts an `ImageGenerationParams` object or a dictionary
#      - `generate_simple()`, which takes the prompt plus keyword overrides

# Here's how to use it:

# # Example usage:

# # Initialize the wrapper
# wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")

# # Method 1: Using ImageGenerationParams
# params = ImageGenerationParams(
#     prompt="A beautiful sunset over mountains",
#     width=512,
#     height=512
# )
# result = wrapper.generate(params)

# # Method 2: Using generate_simple
# result = wrapper.generate_simple(
#     prompt="A beautiful sunset over mountains",
#     width=512,
#     height=512
# )

# # Access the results
# print(f"Image path: {result.image_path}")
# print(f"Image URL: {result.image_url}")
# print(f"Seed used: {result.seed}")
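
# # Reproduce an earlier image by reusing its seed. This is a sketch that assumes
# # the deployment ignores `seed` while `randomize_seed=True` (the default) and
# # uses it as-is otherwise; check the Space's code if exact reproducibility matters.
# same_image = wrapper.generate_simple(
#     prompt="A beautiful sunset over mountains",
#     seed=result.seed,
#     randomize_seed=False,
#     width=512,
#     height=512
# )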

# The wrapper includes:
# - Type hints for better IDE support
# - Error handling and logging
# - Basic parameter validation (an empty prompt is rejected)
# - Flexible parameter input (either an `ImageGenerationParams` object or a dictionary)
# - Clean result handling through a dedicated class
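
# Parameters can also be passed to `generate()` as a plain dictionary whose keys
# match the `ImageGenerationParams` field names; omitted keys keep their defaults:

# result = wrapper.generate({
#     "prompt": "A beautiful sunset over mountains",
#     "width": 768,
#     "height": 768,
# })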

# You can also add error handling in your code:

# try:
#     wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
#     result = wrapper.generate_simple("A beautiful sunset")
#     print(f"Generated image: {result}")
# except ConnectionError as e:
#     print(f"Failed to connect to API: {e}")
# except ValueError as e:
#     print(f"Invalid parameters: {e}")
# except RuntimeError as e:
#     print(f"Generation failed: {e}")
# except Exception as e:
#     print(f"Unexpected error: {e}")
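
# The wrapper logs through the standard `logging` module but does not configure
# it. Python's default setup hides INFO messages such as the connection notice,
# so enable them in your application if you want to see them:

# import logging
# logging.basicConfig(level=logging.INFO)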



import logging
from dataclasses import dataclass
from typing import Dict, Union

from gradio_client import Client

@dataclass
class ImageGenerationParams:
    """Data class to hold image generation parameters"""
    prompt: str
    seed: float = 0
    randomize_seed: bool = True
    width: float = 1024
    height: float = 1024
    guidance_scale: float = 3.5
    num_inference_steps: float = 28
    lora_scale: float = 0.7

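class ImageGenerationResult:
    """Class to handle the generation result"""
    def __init__(self, image_data: Union[str, Dict], seed: float):
        # Depending on the gradio_client version, the image may arrive either as
        # a file-data dict or as a plain local file path; accept both forms
        if isinstance(image_data, str):
            image_data = {'path': image_data}
        self.image_path = image_data.get('path')
        self.image_url = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        self.meta = image_data.get('meta', {})
        self.seed = seed

    def __str__(self) -> str:
        # Show the URL when the API returned one, otherwise the local file path
        location = self.image_url or self.image_path
        return f"ImageGenerationResult(image={location}, seed={self.seed})"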

class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment"""
    
    def __init__(self, api_url: str):
        """
        Initialize the wrapper with the API URL
        
        Args:
            api_url (str): The URL of the Gradio deployment
        """
        self.api_url = api_url
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url)
            self.logger.info("Successfully connected to API at %s", api_url)
        except Exception as e:
            self.logger.error("Failed to connect to API at %s: %s", api_url, e)
            raise ConnectionError(f"Failed to connect to API: {e}") from e

    def generate(
        self, params: Union[ImageGenerationParams, Dict]
    ) -> ImageGenerationResult:
        """
        Generate an image using the provided parameters
        
        Args:
            params: Either an ImageGenerationParams object or a dictionary with the parameters
            
        Returns:
            ImageGenerationResult: Object containing the generation results
            
        Raises:
            ValueError: If parameters are invalid
            RuntimeError: If the API call fails
        """
        # Convert dict to ImageGenerationParams if necessary; unknown keys make
        # the dataclass constructor raise TypeError, reported here as ValueError
        if isinstance(params, dict):
            try:
                params = ImageGenerationParams(**params)
            except TypeError as e:
                raise ValueError(f"Invalid parameters: {e}") from e

        # Validate parameters outside the try block below, so that a ValueError
        # reaches the caller instead of being re-wrapped as a RuntimeError
        if not params.prompt:
            raise ValueError("Prompt cannot be empty")

        try:
            # Make the API call
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name="/infer"
            )

            # The endpoint is expected to return an (image, seed) pair
            if not isinstance(result, (list, tuple)) or len(result) != 2:
                raise RuntimeError("Invalid response from API")

            image_data, seed = result
            return ImageGenerationResult(image_data, seed)

        except Exception as e:
            self.logger.error("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {e}") from e

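    def generate_simple(self, prompt: str, **kwargs) -> ImageGenerationResult:
        """
        Simplified interface for generating images
        
        Args:
            prompt (str): The prompt for image generation
            **kwargs: Optional overrides for the other ImageGenerationParams
                fields (e.g. width, height, seed)
            
        Returns:
            ImageGenerationResult: Object containing the generation results

        Raises:
            ValueError: If a keyword argument is not a valid parameter name
                (other errors are raised by generate(), see its docstring)
        """
        try:
            params = ImageGenerationParams(prompt=prompt, **kwargs)
        except TypeError as e:
            # Unknown keyword arguments make the dataclass raise TypeError;
            # report them as ValueError, as documented above
            raise ValueError(f"Invalid parameters: {e}") from e
        return self.generate(params)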