Spaces:
Sleeping
Sleeping
File size: 5,875 Bytes
078dd6b 87d2db3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
# This wrapper provides several features:
# 1. An `ImageGenerationParams` dataclass that holds the generation parameters and their default values
# 2. An `ImageGenerationResult` class that wraps the API response
# 3. The main `ImagenWrapper` class with:
#    - Initialization with error handling
#    - Logging support
#    - Two methods for generation:
#      - `generate()`, which takes an `ImageGenerationParams` instance (or a dict)
#      - `generate_simple()`, a more straightforward keyword-argument interface
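# Here's how to use it (replace the URL below with your own deployment; Gradio share links such as *.gradio.live are temporary):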
# # Example usage:
# # Initialize the wrapper
# wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
# # Method 1: Using ImageGenerationParams
# params = ImageGenerationParams(
#     prompt="A beautiful sunset over mountains",
#     width=512,
#     height=512,
# )
# result = wrapper.generate(params)
# # Method 2: Using generate_simple
# result = wrapper.generate_simple(
#     prompt="A beautiful sunset over mountains",
#     width=512,
#     height=512,
# )
# # Access the results
# print(f"Image URL: {result.image_url}")
# print(f"Seed used: {result.seed}")
# The wrapper includes:
# - Type hints for better IDE support
# - Error handling and logging
# - Basic parameter validation (the prompt must be non-empty)
# - Flexible parameter input (both through dataclass and dictionary)
# - Clean result handling through a dedicated class
# You can also add error handling in your code:
# try:
#     wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
#     result = wrapper.generate_simple("A beautiful sunset")
#     print(f"Generated image: {result}")
# except ConnectionError as e:
#     print(f"Failed to connect to API: {e}")
# except RuntimeError as e:
#     print(f"Generation failed: {e}")
# except Exception as e:
#     print(f"Unexpected error: {e}")
from gradio_client import Client
from typing import Dict, Tuple, Optional, Union
from dataclasses import dataclass
import logging
@dataclass(init=True, repr=True, eq=True)
class ImageGenerationParams:
    """Parameters for a single image-generation request.

    Only ``prompt`` is required; every other field falls back to a default.
    The field names are the same keyword names that ``ImagenWrapper.generate``
    forwards to the deployment's ``predict`` call.
    """

    # Text description of the image to produce (the only required field).
    prompt: str

    # Seed control. NOTE(review): presumably the server ignores ``seed`` when
    # ``randomize_seed`` is True — confirm against the deployment.
    seed: float = 0
    randomize_seed: bool = True

    # Output dimensions (presumably pixels).
    width: float = 1024
    height: float = 1024

    # Model/sampler knobs, passed through unchanged.
    guidance_scale: float = 3.5
    num_inference_steps: float = 28
    lora_scale: float = 0.7
class ImageGenerationResult:
    """Hold one generated image description together with the seed used.

    The image part of the response may be given either as a local file path
    (``str`` or ``os.PathLike``) or as a mapping with ``path``/``url``/...
    keys. NOTE(review): recent ``gradio_client`` versions download file
    outputs and return a plain path string rather than the mapping shown on
    the auto-generated API page; confirm against the client version in use.
    """

    def __init__(self, image_data: Union[Dict, str], seed: float):
        """
        Args:
            image_data: A local file path, or a mapping that may contain the
                keys ``path``, ``url``, ``size``, ``orig_name``, ``mime_type``,
                ``is_stream`` and ``meta``. Missing keys become ``None``
                (``False`` for ``is_stream``, ``{}`` for ``meta``).
            seed: The seed value returned alongside the image.
        """
        import os

        if isinstance(image_data, (str, os.PathLike)):
            # A bare path carries no URL or server metadata; normalise it to
            # the mapping form so the attribute extraction below is shared.
            image_data = {'path': os.fspath(image_data)}
        self.image_path = image_data.get('path')
        self.image_url = image_data.get('url')
        self.size = image_data.get('size')
        self.orig_name = image_data.get('orig_name')
        self.mime_type = image_data.get('mime_type')
        self.is_stream = image_data.get('is_stream', False)
        self.meta = image_data.get('meta', {})
        self.seed = seed

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(path={self.image_path!r}, "
                f"url={self.image_url!r}, seed={self.seed!r})")

    def __str__(self) -> str:
        return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
class InvalidGenerationParamsError(ValueError, RuntimeError):
    """Raised by ``ImagenWrapper.generate`` when its parameters are rejected.

    Derives from ``ValueError`` (the documented error for invalid parameters)
    and also from ``RuntimeError``, because earlier versions of ``generate``
    wrapped this condition in a ``RuntimeError``; handlers for either type
    keep working.
    """


class ImagenWrapper:
    """Wrapper class for the Imagen Gradio deployment.

    Connects once through ``gradio_client.Client`` and exposes
    :meth:`generate` and :meth:`generate_simple`, which both call a single
    named endpoint (``"/infer"`` unless overridden).
    """

    def __init__(self, api_url: str, *, api_name: str = "/infer", **client_kwargs):
        """
        Initialize the wrapper with the API URL.

        Args:
            api_url (str): The URL of the Gradio deployment.
            api_name (str): Endpoint called by :meth:`generate`; defaults to
                ``"/infer"``.
            **client_kwargs: Extra keyword arguments forwarded unchanged to
                ``gradio_client.Client`` (see its documentation).

        Raises:
            ConnectionError: If the client cannot be created.
        """
        self.api_url = api_url
        self.api_name = api_name
        self.logger = logging.getLogger(__name__)
        try:
            self.client = Client(api_url, **client_kwargs)
        except Exception as e:
            self.logger.exception("Failed to connect to API at %s: %s", api_url, e)
            raise ConnectionError(f"Failed to connect to API: {e}") from e
        self.logger.info("Successfully connected to API at %s", api_url)

    @staticmethod
    def _coerce_params(params) -> ImageGenerationParams:
        """Turn ``params`` into a validated parameter object, or raise InvalidGenerationParamsError."""
        from collections.abc import Mapping

        if isinstance(params, Mapping):
            try:
                params = ImageGenerationParams(**params)
            except TypeError as e:
                # Unknown field names, missing ``prompt``, or non-string keys.
                raise InvalidGenerationParamsError(
                    f"Invalid generation parameters: {e}"
                ) from e
        # Any other object is used duck-typed, as before; only the prompt is
        # checked here.
        prompt = getattr(params, "prompt", None)
        if not isinstance(prompt, str) or not prompt.strip():
            raise InvalidGenerationParamsError("Prompt cannot be empty")
        return params

    @staticmethod
    def _unpack_result(result) -> Tuple[object, float]:
        """Split the raw ``predict`` output into ``(image_data, seed)``."""
        # Require a real 2-item sequence; a bare ``len() == 2`` check would
        # also accept e.g. a 2-character string.
        if not isinstance(result, (tuple, list)) or len(result) != 2:
            raise RuntimeError("Invalid response from API")
        image_data, seed = result
        return image_data, seed

    def generate(self,
                 params: Union[ImageGenerationParams, Dict],
                 ) -> ImageGenerationResult:
        """
        Generate an image using the provided parameters.

        Args:
            params: Either an ImageGenerationParams object or a mapping of its
                field names to values.

        Returns:
            ImageGenerationResult: Object containing the generation results.

        Raises:
            ValueError: If parameters are invalid (raised as
                InvalidGenerationParamsError, which is also a RuntimeError).
                No API call is made in this case.
            RuntimeError: If the API call fails or returns an unexpected
                response; the original exception is chained as ``__cause__``.
        """
        # Validate before the try block so parameter errors are not rewrapped
        # as generic API failures.
        params = self._coerce_params(params)
        try:
            result = self.client.predict(
                prompt=params.prompt,
                seed=params.seed,
                randomize_seed=params.randomize_seed,
                width=params.width,
                height=params.height,
                guidance_scale=params.guidance_scale,
                num_inference_steps=params.num_inference_steps,
                lora_scale=params.lora_scale,
                api_name=self.api_name,
            )
            image_data, seed = self._unpack_result(result)
            return ImageGenerationResult(image_data, seed)
        except Exception as e:
            self.logger.exception("Error during image generation: %s", e)
            raise RuntimeError(f"Failed to generate image: {e}") from e

    def generate_simple(self,
                        prompt: str,
                        **kwargs) -> ImageGenerationResult:
        """
        Simplified interface for generating images.

        Args:
            prompt (str): The prompt for image generation.
            **kwargs: ImageGenerationParams fields overriding the defaults.

        Returns:
            ImageGenerationResult: Object containing the generation results.

        Raises:
            TypeError: If ``kwargs`` contains a name that is not an
                ImageGenerationParams field.
            ValueError, RuntimeError: As documented for :meth:`generate`.
        """
        params = ImageGenerationParams(prompt=prompt, **kwargs)
        return self.generate(params)