demo-25-01-1515 / imagenwrapper.py
Tonic's picture
add imagenwrapper howto use
078dd6b unverified
# This wrapper provides several features:
# 1. A `ImageGenerationParams` dataclass to handle parameters with default values
# 2. A `ImageGenerationResult` class to wrap the API response
# 3. The main `ImagenWrapper` class with:
# - Proper initialization with error handling
# - Logging support
# - Two methods for generation:
# - `generate()` using the `ImageGenerationParams` class
# - `generate_simple()` for a more straightforward interface
# Here's how to use it:
# # Example usage:
# # Initialize the wrapper
# wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
# # Method 1: Using ImageGenerationParams
# params = ImageGenerationParams(
# prompt="A beautiful sunset over mountains",
# width=512,
# height=512
# )
# result = wrapper.generate(params)
# # Method 2: Using generate_simple
# result = wrapper.generate_simple(
# prompt="A beautiful sunset over mountains",
# width=512,
# height=512
# )
# # Access the results
# print(f"Image URL: {result.image_url}")
# print(f"Seed used: {result.seed}")
# The wrapper includes:
# - Type hints for better IDE support
# - Error handling and logging
# - Parameter validation
# - Flexible parameter input (both through dataclass and dictionary)
# - Clean result handling through a dedicated class
# You can also add error handling in your code:
# try:
# wrapper = ImagenWrapper("https://bcdb8b7f9c4a57127c.gradio.live/")
# result = wrapper.generate_simple("A beautiful sunset")
# print(f"Generated image: {result}")
# except ConnectionError as e:
# print(f"Failed to connect to API: {e}")
# except RuntimeError as e:
# print(f"Generation failed: {e}")
# except Exception as e:
# print(f"Unexpected error: {e}")
from gradio_client import Client
from typing import Dict, Tuple, Optional, Union
from dataclasses import dataclass
import logging
@dataclass
class ImageGenerationParams:
"""Data class to hold image generation parameters"""
prompt: str
seed: float = 0
randomize_seed: bool = True
width: float = 1024
height: float = 1024
guidance_scale: float = 3.5
num_inference_steps: float = 28
lora_scale: float = 0.7
class ImageGenerationResult:
"""Class to handle the generation result"""
def __init__(self, image_data: Dict, seed: float):
self.image_path = image_data.get('path')
self.image_url = image_data.get('url')
self.size = image_data.get('size')
self.orig_name = image_data.get('orig_name')
self.mime_type = image_data.get('mime_type')
self.is_stream = image_data.get('is_stream', False)
self.meta = image_data.get('meta', {})
self.seed = seed
def __str__(self) -> str:
return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
class ImagenWrapper:
"""Wrapper class for the Imagen Gradio deployment"""
def __init__(self, api_url: str):
"""
Initialize the wrapper with the API URL
Args:
api_url (str): The URL of the Gradio deployment
"""
self.api_url = api_url
self.logger = logging.getLogger(__name__)
try:
self.client = Client(api_url)
self.logger.info(f"Successfully connected to API at {api_url}")
except Exception as e:
self.logger.error(f"Failed to connect to API at {api_url}: {str(e)}")
raise ConnectionError(f"Failed to connect to API: {str(e)}")
def generate(self,
params: Union[ImageGenerationParams, Dict],
) -> ImageGenerationResult:
"""
Generate an image using the provided parameters
Args:
params: Either an ImageGenerationParams object or a dictionary with the parameters
Returns:
ImageGenerationResult: Object containing the generation results
Raises:
ValueError: If parameters are invalid
RuntimeError: If the API call fails
"""
try:
# Convert dict to ImageGenerationParams if necessary
if isinstance(params, dict):
params = ImageGenerationParams(**params)
# Validate parameters
if not params.prompt:
raise ValueError("Prompt cannot be empty")
# Make the API call
result = self.client.predict(
prompt=params.prompt,
seed=params.seed,
randomize_seed=params.randomize_seed,
width=params.width,
height=params.height,
guidance_scale=params.guidance_scale,
num_inference_steps=params.num_inference_steps,
lora_scale=params.lora_scale,
api_name="/infer"
)
# Process the result
if not result or len(result) != 2:
raise RuntimeError("Invalid response from API")
image_data, seed = result
return ImageGenerationResult(image_data, seed)
except Exception as e:
self.logger.error(f"Error during image generation: {str(e)}")
raise RuntimeError(f"Failed to generate image: {str(e)}")
def generate_simple(self,
prompt: str,
**kwargs) -> ImageGenerationResult:
"""
Simplified interface for generating images
Args:
prompt (str): The prompt for image generation
**kwargs: Optional parameters to override defaults
Returns:
ImageGenerationResult: Object containing the generation results
"""
params = ImageGenerationParams(prompt=prompt, **kwargs)
return self.generate(params)