Spaces:
Running
Running
import os | |
import base64 | |
import boto3 | |
import json | |
import logging | |
from datetime import datetime | |
from dotenv import load_dotenv | |
from functools import wraps | |
from botocore.config import Config | |
from botocore.exceptions import ClientError | |
load_dotenv() | |
# Custom exceptions (defined first so every helper below can raise them)
class ImageError(Exception):
    """Error raised when image or prompt generation fails.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message):
        # Forward the message to Exception so that str(err) and err.args carry
        # it even when the caller constructs the error with message=... as a
        # keyword argument (otherwise args stays empty and str(err) is "").
        super().__init__(message)
        self.message = message
def handle_bedrock_errors(func):
    """Decorate ``func`` so that its failures surface as ``ImageError``.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper with the same signature and metadata as ``func``.

    Raises:
        ImageError: From the wrapper, when ``func`` raises. A botocore
            ``ClientError`` is reported as "Client error: ...", any other
            exception as "Unexpected error: ...". An ``ImageError`` raised by
            ``func`` itself is propagated unchanged.
    """
    logger = logging.getLogger(__name__)
    logging.basicConfig(level=logging.INFO)

    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ImageError:
            # Already the module's own error type: re-wrapping it would
            # prepend "Unexpected error: " to a message meant for the user.
            raise
        except ClientError as err:
            detail = err.response['Error']['Message']
            logger.error("Bedrock client error: %s", detail)
            raise ImageError(f"Client error: {detail}") from err
        except Exception as err:
            logger.error("Unexpected error: %s", err)
            raise ImageError(f"Unexpected error: {str(err)}") from err

    return wrapper
# AWS credentials, read from the process environment.
aws_id = os.getenv('AWS_ID')
aws_secret = os.getenv('AWS_SECRET')
# Request budget compared against the weighted request count in check_rate_limit.
# NOTE(review): int() raises TypeError at import time when RATE_LIMIT is unset.
rate_limit = int(os.getenv('RATE_LIMIT'))
# S3 bucket (and its region) used for stored responses/images and the rate-limit file.
nova_image_bucket=os.getenv('NOVA_IMAGE_BUCKET')
bucket_region=os.getenv('BUCKET_REGION')
# HTML message carried by the ImageError raised when the rate limit is exceeded.
rate_limit_message = """<div style='text-align: center;'>Rate limit exceeded. Check back later, use the
<a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>"""
# Client wrapper around Amazon Bedrock: generates images (invoke_model) and text prompts (converse)
class BedrockClient:
    """Wrapper around a Bedrock runtime client and an S3 client.

    The Bedrock client runs the model identified by ``model_id``; the S3
    client stores generation artifacts in the bucket named by the module-level
    ``nova_image_bucket`` setting.
    """

    def __init__(self, aws_id, aws_secret, model_id, timeout=300,
                 region_name='us-east-1'):
        """Create the underlying boto3 clients.

        Args:
            aws_id: AWS access key id.
            aws_secret: AWS secret access key.
            model_id: Bedrock model identifier used for every request.
            timeout: Read timeout, in seconds, for Bedrock calls.
            region_name: Region of the Bedrock runtime endpoint. Defaults to
                the previously hard-coded 'us-east-1'.
        """
        self.model_id = model_id
        self.bedrock_client = boto3.client(
            service_name='bedrock-runtime',
            aws_access_key_id=aws_id,
            aws_secret_access_key=aws_secret,
            region_name=region_name,
            config=Config(read_timeout=timeout)
        )
        self.s3_client = boto3.client(
            service_name='s3',
            aws_access_key_id=aws_id,
            aws_secret_access_key=aws_secret,
            region_name=bucket_region
        )

    def _store_response(self, response_body, image_data=None):
        """Store response and image in S3.

        Args:
            response_body: JSON-serialisable value written under
                ``responses/<timestamp>_response.json``.
            image_data: Optional image bytes written under
                ``images/<timestamp>_image.png``.
        """
        # NOTE(review): keys have one-second resolution, so two calls within
        # the same second write to the same keys.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Store response body
        # NOTE(review): generate_image passes the *request* body here; if that
        # is already a JSON string, json.dumps encodes it a second time --
        # confirm this is intended.
        response_key = f'responses/{timestamp}_response.json'
        self.s3_client.put_object(
            Bucket=nova_image_bucket,
            Key=response_key,
            Body=json.dumps(response_body),
            ContentType='application/json'
        )

        # Store image if present
        if image_data:
            image_key = f'images/{timestamp}_image.png'
            self.s3_client.put_object(
                Bucket=nova_image_bucket,
                Key=image_key,
                Body=image_data,
                ContentType='image/png'
            )

    def _handle_error(self, err):
        """Re-raise a botocore ``ClientError`` as ``ImageError``.

        Raises:
            ImageError: Always; the original error is kept as ``__cause__``.
        """
        raise ImageError(f"Client error: {err.response['Error']['Message']}") from err

    def generate_image(self, body):
        """Generate image using Bedrock service.

        Args:
            body: Request body passed to ``invoke_model``.

        Returns:
            The decoded image bytes.

        Raises:
            ImageError: If a Bedrock or S3 call fails with ``ClientError`` or
                the response cannot be interpreted.
        """
        try:
            response = self.bedrock_client.invoke_model(
                body=body,
                modelId=self.model_id,
                accept="application/json",
                contentType="application/json"
            )
            image_data = self._process_response(response)
            self._store_response(
                body,
                image_data
            )
            return image_data
        except ClientError as err:
            self._handle_error(err)

    def generate_prompt(self, body):
        """Generate text through the Bedrock Converse API.

        Args:
            body: Message list passed as ``messages`` to ``converse``.

        Returns:
            The text of the first content item of the reply.

        Raises:
            ImageError: If the call fails with ``ClientError`` or the response
                cannot be interpreted.
        """
        try:
            response = self.bedrock_client.converse(
                modelId=self.model_id,
                messages=body
            )
            return self._process_response(response)
        except ClientError as err:
            self._handle_error(err)

    def _process_response(self, response):
        """Process successful response for both image and text.

        Args:
            response: Mapping returned by ``converse`` or ``invoke_model``.

        Returns:
            The reply text (``converse`` shape) or the decoded bytes of the
            first image (``invoke_model`` shape).

        Raises:
            ImageError: If the response or its decoded payload reports an
                error, or if neither text nor an image can be found.
        """
        if "error" in response:
            raise ImageError(f"Generation error: {response['error']}")

        # Text reply: output -> message -> content[0] -> text
        if "output" in response and "message" in response["output"]:
            message_content = response["output"]["message"]["content"]
            if message_content and "text" in message_content[0]:
                return message_content[0]["text"]

        # Image reply: a readable "body" holding a JSON document
        stream = response.get("body")
        if stream is None:
            # No text and no body: report it instead of failing on None.read()
            raise ImageError("Unexpected response format.")
        response_body = json.loads(stream.read())
        if not isinstance(response_body, dict):
            raise ImageError("Unexpected response format.")

        # The model reports generation failures inside the JSON payload, not
        # on the outer boto3 response, so the check must happen after decoding.
        if response_body.get("error"):
            raise ImageError(f"Generation error: {response_body['error']}")

        images = response_body.get("images")
        if images:  # guards against a missing key and an empty list alike
            return base64.b64decode(images[0].encode('ascii'))

        raise ImageError("Unexpected response format.")
def check_rate_limit(body):
    """Record an image request against the shared rate limit, or reject it.

    Request timestamps are kept in ``rate-limit/jsonData.json`` in the
    ``nova_image_bucket`` bucket, split into 'premium' and 'standard' lists.
    Timestamps older than 20 minutes are dropped. A premium request weighs 2,
    any other quality weighs 1; the request is accepted only if the weighted
    total, including this request, does not exceed ``rate_limit``.

    Args:
        body: JSON string of the image request; the quality is read from
            ``imageGenerationConfig.quality`` and defaults to 'standard'.

    Raises:
        ImageError: If the limit would be exceeded (message is
            ``rate_limit_message``) or the rate-limit file cannot be read for
            a reason other than not existing yet.
    """
    window_seconds = 1200  # 20 minutes
    rate_key = 'rate-limit/jsonData.json'

    request = json.loads(body)
    quality = request.get('imageGenerationConfig', {}).get('quality', 'standard')

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.getenv('AWS_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET'),
        region_name=bucket_region
    )

    try:
        # Get current rate limit data
        response = s3_client.get_object(
            Bucket=nova_image_bucket,
            Key=rate_key
        )
        rate_data = json.loads(response['Body'].read().decode('utf-8'))
    except ClientError as e:
        if e.response['Error']['Code'] != 'NoSuchKey':
            raise ImageError(f"Failed to check rate limit: {str(e)}") from e
        # Initialize if file doesn't exist
        rate_data = {}

    current_time = datetime.now().timestamp()

    # Keep only requests from the last 20 minutes. Using .get() lets a stored
    # file that lacks one of the tiers be repaired instead of raising KeyError.
    cutoff = current_time - window_seconds
    for tier in ('premium', 'standard'):
        rate_data[tier] = [t for t in rate_data.get(tier, []) if t > cutoff]

    # Weighted count of requests in the window: premium counts double
    total_count = len(rate_data['premium']) * 2 + len(rate_data['standard'])

    # Any quality other than 'premium' is accounted as standard
    tier, cost = ('premium', 2) if quality == 'premium' else ('standard', 1)
    if total_count + cost > rate_limit:  # would adding this request exceed the threshold?
        raise ImageError(rate_limit_message)
    rate_data[tier].append(current_time)

    # Update rate limit file
    # NOTE(review): read-modify-write without locking; concurrent requests
    # can overwrite each other's entries.
    s3_client.put_object(
        Bucket=nova_image_bucket,
        Key=rate_key,
        Body=json.dumps(rate_data),
        ContentType='application/json'
    )
def generate_image(body):
    """Rate-limit, then generate an image with the Nova Canvas model.

    Args:
        body: JSON request body forwarded to the rate limiter and to Bedrock.

    Returns:
        The value returned by ``BedrockClient.generate_image`` on success, or
        the text of the ``ImageError`` when the rate check or the generation
        fails.
    """
    try:
        check_rate_limit(body)
        credentials = {
            'aws_id': os.getenv('AWS_ID'),
            'aws_secret': os.getenv('AWS_SECRET'),
        }
        canvas = BedrockClient(model_id='amazon.nova-canvas-v1:0', **credentials)
        return canvas.generate_image(body)
    except ImageError as failure:
        # Failures are handed back to the caller as a plain string
        return str(failure)
def generate_prompt(body):
    """Generate text with the Nova Lite model.

    Args:
        body: Message list forwarded to ``BedrockClient.generate_prompt``.

    Returns:
        Whatever ``BedrockClient.generate_prompt`` returns for ``body``.
    """
    credentials = {
        'aws_id': os.getenv('AWS_ID'),
        'aws_secret': os.getenv('AWS_SECRET'),
    }
    lite = BedrockClient(model_id='us.amazon.nova-lite-v1:0', **credentials)
    return lite.generate_prompt(body)