File size: 7,493 Bytes
ab6cb7b
 
 
 
 
9c1e305
ab6cb7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9781d82
b49c611
 
bd05777
059f429
 
ab6cb7b
 
 
 
9c1e305
ab6cb7b
9c1e305
ab6cb7b
 
 
9c1e305
ab6cb7b
 
9c1e305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab6cb7b
 
 
 
 
 
 
 
9c1e305
ab6cb7b
 
 
 
 
9c1e305
 
 
 
 
 
 
 
ab6cb7b
 
 
 
 
 
9c1e305
ab6cb7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911f98c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd05777
 
 
 
911f98c
 
bd05777
 
911f98c
 
bd05777
 
911f98c
 
 
 
 
 
 
 
 
 
 
ab6cb7b
 
911f98c
 
 
 
 
 
 
 
 
 
 
ab6cb7b
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import os
import base64
import boto3
import json
import logging
from datetime import datetime
from dotenv import load_dotenv
from functools import wraps
from botocore.config import Config
from botocore.exceptions import ClientError

load_dotenv()
# Custom exceptions (kept at the top of the module, before the code that raises them)
class ImageError(Exception):
    """Error raised by this module when a Bedrock/S3 operation fails.

    Attributes:
        message: Human-readable description of the failure. It is also
            what ``str(error)`` returns.
    """

    def __init__(self, message):
        # Forward the message to Exception so ``args`` and ``str()`` are
        # populated no matter how the error was constructed (positionally
        # or with ``message=...``).
        super().__init__(message)
        self.message = message

def handle_bedrock_errors(func):
    """Decorate ``func`` so that every failure surfaces as ``ImageError``.

    Behaviour of the returned wrapper:
      * the return value of ``func`` is passed through unchanged;
      * an ``ImageError`` raised by ``func`` is logged and re-raised as-is,
        so its message is not prefixed a second time;
      * a botocore ``ClientError`` becomes ``ImageError("Client error: ...")``;
      * any other exception becomes ``ImageError("Unexpected error: ...")``.

    Converted exceptions keep the original error as ``__cause__``.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapped callable (metadata preserved via ``functools.wraps``).
    """
    logger = logging.getLogger(__name__)
    # Kept from the original implementation: logging is configured when the
    # decorator is applied, i.e. at import time of the decorated code.
    logging.basicConfig(level=logging.INFO)

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ImageError as err:
            # Already a domain error with a user-facing message; wrapping it
            # again would produce "Unexpected error: Client error: ...".
            logger.error(f"Image error: {err.message}")
            raise
        except ClientError as err:
            detail = err.response['Error']['Message']
            logger.error(f"Bedrock client error: {detail}")
            raise ImageError(f"Client error: {detail}") from err
        except Exception as err:
            logger.error(f"Unexpected error: {str(err)}")
            raise ImageError(f"Unexpected error: {str(err)}") from err
    return wrapper

# Module configuration, read from environment variables (load_dotenv() has
# already been called above).
aws_id = os.getenv('AWS_ID')
aws_secret = os.getenv('AWS_SECRET')
# Request budget compared against the weighted request count in
# check_rate_limit (premium requests count as 2, standard as 1).
# NOTE(review): os.getenv returns None when RATE_LIMIT is unset, in which case
# int() raises TypeError at import time; confirm the variable is always set.
rate_limit = int(os.getenv('RATE_LIMIT'))
# S3 bucket (and its region) holding stored requests/images and the
# rate-limit state file.
nova_image_bucket=os.getenv('NOVA_IMAGE_BUCKET')
bucket_region=os.getenv('BUCKET_REGION')
# HTML snippet used as the ImageError message when the rate limit is exceeded.
rate_limit_message = """<div style='text-align: center;'>Rate limit exceeded. Check back later, use the 
            <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
            try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>"""

# Client wrapper around Amazon Bedrock: image generation via invoke_model and prompt generation via converse
class BedrockClient:
    """Wrapper around a Bedrock runtime client and an S3 client.

    ``generate_image`` calls ``invoke_model`` and archives the request body
    and resulting image in S3; ``generate_prompt`` calls ``converse`` and
    returns the generated text. Failures are reported as ``ImageError``.
    """

    def __init__(self, aws_id, aws_secret, model_id, timeout=300):
        """Create the underlying boto3 clients.

        Args:
            aws_id: AWS access key id used for both clients.
            aws_secret: AWS secret access key used for both clients.
            model_id: Bedrock model identifier sent with every request.
            timeout: Read timeout passed to the Bedrock client configuration.
        """
        self.model_id = model_id
        self.bedrock_client = boto3.client(
            service_name='bedrock-runtime',
            aws_access_key_id=aws_id,
            aws_secret_access_key=aws_secret,
            region_name='us-east-1',
            config=Config(read_timeout=timeout)
        )
        self.s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=aws_id,
            aws_secret_access_key=aws_secret,
            region_name=bucket_region
        )

    def _store_response(self, response_body, image_data=None):
        """Archive a JSON document and, optionally, a PNG image in S3.

        Args:
            response_body: Object serialised with ``json.dumps`` and stored
                under ``responses/``. ``generate_image`` passes the *request*
                body here.
                NOTE(review): if that body is already a JSON string, the
                stored object is a JSON-encoded string; confirm this is the
                intended format before changing it.
            image_data: Raw image bytes stored under ``images/``; skipped when
                falsy.
        """
        # Microseconds are part of the key so that two requests completing
        # within the same second do not overwrite each other's objects. The
        # same timestamp is shared by both keys so they can be paired up.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')

        # Store response body
        response_key = f'responses/{timestamp}_response.json'
        self.s3_client.put_object(
            Bucket=nova_image_bucket,
            Key=response_key,
            Body=json.dumps(response_body),
            ContentType='application/json'
        )

        # Store image if present
        if image_data:
            image_key = f'images/{timestamp}_image.png'
            self.s3_client.put_object(
                Bucket=nova_image_bucket,
                Key=image_key,
                Body=image_data,
                ContentType='image/png'
            )

    def _handle_error(self, err):
        """Convert a botocore ``ClientError`` into ``ImageError``.

        Args:
            err: The ``ClientError`` to convert.

        Raises:
            ImageError: Always, carrying the service's error message.
        """
        raise ImageError(f"Client error: {err.response['Error']['Message']}") from err

    def generate_image(self, body):
        """Generate an image with ``invoke_model`` and archive the result.

        Args:
            body: Request body passed verbatim to ``invoke_model``.

        Returns:
            The value produced by ``_process_response`` (decoded image bytes
            for an image response).

        Raises:
            ImageError: For service errors, malformed responses, storage
                errors and any other unexpected failure.
        """
        try:
            response = self.bedrock_client.invoke_model(
                body=body,
                modelId=self.model_id,
                accept="application/json",
                contentType="application/json"
            )
            image_data = self._process_response(response)
            self._store_response(body, image_data)
            return image_data
        except ImageError:
            # Already carries a user-facing message; do not re-wrap.
            raise
        except ClientError as err:
            self._handle_error(err)
        except Exception as err:
            # Timeouts, connection problems, JSON/base64 decoding errors, ...
            # Callers of this class only handle ImageError, so convert here,
            # mirroring what the decorator does for generate_prompt.
            logging.getLogger(__name__).error(f"Unexpected error: {str(err)}")
            raise ImageError(f"Unexpected error: {str(err)}") from err

    @handle_bedrock_errors
    def generate_prompt(self, body):
        """Generate text with the ``converse`` API.

        Args:
            body: Message list passed as ``messages`` to ``converse``.

        Returns:
            The text extracted by ``_process_response``.

        Raises:
            ImageError: Raised by the ``handle_bedrock_errors`` decorator for
                any failure.
        """
        # No local try/except: the decorator already maps ClientError to
        # "Client error: ...". Converting here as well made the decorator
        # re-wrap the resulting ImageError with a second prefix.
        response = self.bedrock_client.converse(
            modelId=self.model_id,
            messages=body
        )
        return self._process_response(response)

    def _process_response(self, response):
        """Extract the payload from a ``converse`` or ``invoke_model`` response.

        This helper is intentionally not decorated: both callers already
        convert unexpected exceptions to ``ImageError``, and decorating it as
        well caused its own ``ImageError`` messages to be wrapped again.

        Args:
            response: Response mapping returned by the Bedrock client.

        Returns:
            The first text block of a converse response, or the decoded bytes
            of the first image of an invoke_model response.

        Raises:
            ImageError: If the response reports an error or has no usable
                payload.
        """
        if "error" in response:
            raise ImageError(f"Generation error: {response['error']}")

        # Text response: output.message.content[0].text
        if "output" in response and "message" in response["output"]:
            message_content = response["output"]["message"]["content"]
            if message_content and "text" in message_content[0]:
                return message_content[0]["text"]

        # Image response: a streaming "body" holding a JSON document.
        raw_body = response.get("body")
        if raw_body is None:
            raise ImageError("Unexpected response format.")
        response_body = json.loads(raw_body.read())

        images = response_body.get("images")
        if images:
            return base64.b64decode(images[0].encode('ascii'))

        # The model reports failures (e.g. blocked content) in an "error"
        # field of the JSON body rather than at the top level of the response.
        body_error = response_body.get("error")
        if body_error:
            raise ImageError(f"Generation error: {body_error}")

        raise ImageError("Unexpected response format.")

def check_rate_limit(body):
    """Enforce the shared request budget stored in S3 and record this request.

    The state file ``rate-limit/jsonData.json`` holds two lists of request
    timestamps (``premium`` and ``standard``). Entries older than 20 minutes
    are dropped; the remaining premium entries count as 2 and standard
    entries as 1. If adding this request would exceed ``rate_limit`` the
    request is rejected, otherwise its timestamp is appended and the file is
    written back.

    NOTE(review): the read-modify-write on the S3 object is not atomic, so
    concurrent requests may overwrite each other's updates.

    Args:
        body: JSON-encoded request body. ``imageGenerationConfig.quality``
            selects the cost; anything other than ``'premium'`` is treated
            as standard.

    Raises:
        ImageError: If the limit would be exceeded (message is
            ``rate_limit_message``) or the state file cannot be read,
            parsed or written.
    """
    request = json.loads(body)
    quality = request.get('imageGenerationConfig', {}).get('quality', 'standard')

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.getenv('AWS_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET'),
        region_name=bucket_region
    )
    state_key = 'rate-limit/jsonData.json'

    try:
        # Get current rate limit data
        response = s3_client.get_object(
            Bucket=nova_image_bucket,
            Key=state_key
        )
        rate_data = json.loads(response['Body'].read().decode('utf-8'))
    except ClientError as e:
        if e.response['Error']['Code'] == 'NoSuchKey':
            # Initialize if file doesn't exist
            rate_data = {'premium': [], 'standard': []}
        else:
            raise ImageError(f"Failed to check rate limit: {str(e)}") from e
    except ValueError as e:
        # Undecodable or non-JSON state file (JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses). Callers only
        # handle ImageError, so report it the same way as a failed read.
        raise ImageError(f"Failed to check rate limit: {str(e)}") from e

    # Keep only requests from the last 20 minutes (1200 seconds)
    current_time = datetime.now().timestamp()
    window_start = current_time - 1200

    # Clean up old entries; a missing key is treated as "no requests yet"
    rate_data['premium'] = [t for t in rate_data.get('premium', []) if t > window_start]
    rate_data['standard'] = [t for t in rate_data.get('standard', []) if t > window_start]

    # Weighted number of requests in the window: premium counts double
    total_count = len(rate_data['premium']) * 2 + len(rate_data['standard'])

    # Cost of the incoming request and the list it is recorded in
    if quality == 'premium':
        cost, bucket = 2, 'premium'
    else:  # standard
        cost, bucket = 1, 'standard'

    if total_count + cost > rate_limit:  # adding this request would exceed the budget
        raise ImageError(rate_limit_message)
    rate_data[bucket].append(current_time)

    # Update rate limit file
    try:
        s3_client.put_object(
            Bucket=nova_image_bucket,
            Key=state_key,
            Body=json.dumps(rate_data),
            ContentType='application/json'
        )
    except ClientError as e:
        raise ImageError(f"Failed to update rate limit: {str(e)}") from e
    

def generate_image(body):
    """Run the rate-limit check, then generate an image with Nova Canvas.

    Returns whatever the client's ``generate_image`` returns on success. If
    an ``ImageError`` is raised along the way, its string form is returned
    instead of propagating the exception.
    """
    try:
        check_rate_limit(body)
        canvas_client = BedrockClient(
            model_id='amazon.nova-canvas-v1:0',
            aws_id=os.getenv('AWS_ID'),
            aws_secret=os.getenv('AWS_SECRET'),
        )
        result = canvas_client.generate_image(body)
    except ImageError as failure:
        # Hand the message back to the caller as plain text.
        return str(failure)
    return result
        

def generate_prompt(body):
    """Generate prompt text with the Nova Lite model.

    Builds a ``BedrockClient`` from the ``AWS_ID``/``AWS_SECRET`` environment
    variables and returns the result of its ``generate_prompt`` call.
    Exceptions raised by the client are not caught here.
    """
    credentials = {
        'aws_id': os.getenv('AWS_ID'),
        'aws_secret': os.getenv('AWS_SECRET'),
    }
    lite_client = BedrockClient(model_id='us.amazon.nova-lite-v1:0', **credentials)
    return lite_client.generate_prompt(body)