Tonic commited on
Commit
87d2db3
·
unverified ·
1 Parent(s): 76435b8

add imagen wrapper

Browse files
Files changed (1) hide show
  1. imagenwrapper.py +115 -0
imagenwrapper.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ from typing import Dict, Tuple, Optional, Union
3
+ from dataclasses import dataclass
4
+ import logging
5
+
6
+ @dataclass
7
+ class ImageGenerationParams:
8
+ """Data class to hold image generation parameters"""
9
+ prompt: str
10
+ seed: float = 0
11
+ randomize_seed: bool = True
12
+ width: float = 1024
13
+ height: float = 1024
14
+ guidance_scale: float = 3.5
15
+ num_inference_steps: float = 28
16
+ lora_scale: float = 0.7
17
+
18
+ class ImageGenerationResult:
19
+ """Class to handle the generation result"""
20
+ def __init__(self, image_data: Dict, seed: float):
21
+ self.image_path = image_data.get('path')
22
+ self.image_url = image_data.get('url')
23
+ self.size = image_data.get('size')
24
+ self.orig_name = image_data.get('orig_name')
25
+ self.mime_type = image_data.get('mime_type')
26
+ self.is_stream = image_data.get('is_stream', False)
27
+ self.meta = image_data.get('meta', {})
28
+ self.seed = seed
29
+
30
+ def __str__(self) -> str:
31
+ return f"ImageGenerationResult(url={self.image_url}, seed={self.seed})"
32
+
33
+ class ImagenWrapper:
34
+ """Wrapper class for the Imagen Gradio deployment"""
35
+
36
+ def __init__(self, api_url: str):
37
+ """
38
+ Initialize the wrapper with the API URL
39
+
40
+ Args:
41
+ api_url (str): The URL of the Gradio deployment
42
+ """
43
+ self.api_url = api_url
44
+ self.logger = logging.getLogger(__name__)
45
+ try:
46
+ self.client = Client(api_url)
47
+ self.logger.info(f"Successfully connected to API at {api_url}")
48
+ except Exception as e:
49
+ self.logger.error(f"Failed to connect to API at {api_url}: {str(e)}")
50
+ raise ConnectionError(f"Failed to connect to API: {str(e)}")
51
+
52
+ def generate(self,
53
+ params: Union[ImageGenerationParams, Dict],
54
+ ) -> ImageGenerationResult:
55
+ """
56
+ Generate an image using the provided parameters
57
+
58
+ Args:
59
+ params: Either an ImageGenerationParams object or a dictionary with the parameters
60
+
61
+ Returns:
62
+ ImageGenerationResult: Object containing the generation results
63
+
64
+ Raises:
65
+ ValueError: If parameters are invalid
66
+ RuntimeError: If the API call fails
67
+ """
68
+ try:
69
+ # Convert dict to ImageGenerationParams if necessary
70
+ if isinstance(params, dict):
71
+ params = ImageGenerationParams(**params)
72
+
73
+ # Validate parameters
74
+ if not params.prompt:
75
+ raise ValueError("Prompt cannot be empty")
76
+
77
+ # Make the API call
78
+ result = self.client.predict(
79
+ prompt=params.prompt,
80
+ seed=params.seed,
81
+ randomize_seed=params.randomize_seed,
82
+ width=params.width,
83
+ height=params.height,
84
+ guidance_scale=params.guidance_scale,
85
+ num_inference_steps=params.num_inference_steps,
86
+ lora_scale=params.lora_scale,
87
+ api_name="/infer"
88
+ )
89
+
90
+ # Process the result
91
+ if not result or len(result) != 2:
92
+ raise RuntimeError("Invalid response from API")
93
+
94
+ image_data, seed = result
95
+ return ImageGenerationResult(image_data, seed)
96
+
97
+ except Exception as e:
98
+ self.logger.error(f"Error during image generation: {str(e)}")
99
+ raise RuntimeError(f"Failed to generate image: {str(e)}")
100
+
101
+ def generate_simple(self,
102
+ prompt: str,
103
+ **kwargs) -> ImageGenerationResult:
104
+ """
105
+ Simplified interface for generating images
106
+
107
+ Args:
108
+ prompt (str): The prompt for image generation
109
+ **kwargs: Optional parameters to override defaults
110
+
111
+ Returns:
112
+ ImageGenerationResult: Object containing the generation results
113
+ """
114
+ params = ImageGenerationParams(prompt=prompt, **kwargs)
115
+ return self.generate(params)