| """ |
| Usage: |
| python3 -m unittest tests.test_image_utils |
| """ |
|
|
| import base64 |
| from io import BytesIO |
| import os |
| import unittest |
|
|
| import numpy as np |
| from PIL import Image |
|
|
| from fastchat.utils import ( |
| resize_image_and_return_image_in_bytes, |
| image_moderation_filter, |
| ) |
| from fastchat.conversation import get_conv_template |
|
|
|
|
def check_byte_size_in_mb(image_base64_str):
    """Return the length of ``image_base64_str`` expressed in MiB.

    Accepts any sized sequence: for a base64 ``str`` each character counts
    as one unit, for raw ``bytes`` each byte does. The divisor is
    1024 * 1024.
    """
    units_per_mb = 1 << 20  # 1024 * 1024
    return len(image_base64_str) / units_per_mb
|
|
|
|
def generate_random_image(target_size_mb, image_format="PNG"):
    """Return a random square RGB image whose encoded size reaches a target.

    The image is filled with uniformly random pixels. Its side length starts
    at the value that gives ``target_size_mb`` worth of raw RGB bytes and is
    grown one pixel at a time until the image, encoded as ``image_format``,
    is at least ``target_size_mb`` MiB.

    Encoding happens in an in-memory buffer, so no ``temp_image.*`` file is
    left in the working directory and parallel test runs cannot clobber a
    shared file.

    Args:
        target_size_mb: Minimum encoded size in MiB (1 MiB = 1024 * 1024 bytes).
        image_format: Pillow format name used when measuring the encoded size.

    Returns:
        A ``PIL.Image.Image`` of random pixels.
    """
    target_size_bytes = target_size_mb * 1024 * 1024

    def _random_rgb_image(side):
        # Fresh uniform noise for every attempt, as in a side x side x 3 uint8 array.
        pixel_data = np.random.randint(0, 256, (side, side, 3), dtype=np.uint8)
        return Image.fromarray(pixel_data)

    def _encoded_size(image):
        # Encode to memory only to measure; the buffer position after save()
        # equals the number of bytes written.
        buffer = BytesIO()
        image.save(buffer, format=image_format)
        return buffer.tell()

    # 3 bytes per RGB pixel gives the first guess for the side length; clamp
    # to 1 so a zero/tiny target does not produce an empty 0x0 array.
    dimension = max(1, int((target_size_bytes / 3) ** 0.5))
    img = _random_rgb_image(dimension)

    # Random noise compresses poorly, so this loop usually runs few times.
    while _encoded_size(img) < target_size_bytes:
        dimension += 1
        img = _random_rgb_image(dimension)

    return img
|
|
|
|
class DontResizeIfLessThanMaxTest(unittest.TestCase):
    """An image already under ``max_image_size_mb`` keeps its encoded size."""

    def test_dont_resize_if_less_than_max(self):
        cap_mb = 5
        source_mb = 0.1
        picture = generate_random_image(source_mb)

        # Baseline: the picture encoded as PNG, measured in MB.
        baseline_png = BytesIO()
        picture.save(baseline_png, format="PNG")
        size_before = check_byte_size_in_mb(baseline_png.getvalue())

        # The cap is far above the roughly source_mb-sized picture, so the
        # helper is expected to hand back an equally sized encoding.
        returned_bytes = resize_image_and_return_image_in_bytes(
            picture, max_image_size_mb=cap_mb
        )
        size_after = check_byte_size_in_mb(returned_bytes.getvalue())

        self.assertEqual(size_before, size_after)
|
|
|
|
class ResizeLargeImageForModerationEndpoint(unittest.TestCase):
    """A large image passed to ``image_moderation_filter`` is not flagged."""

    def test_resize_large_image_and_send_to_moderation_filter(self):
        # NOTE(review): 6 MB presumably exceeds the moderation endpoint's
        # upload limit, forcing image_moderation_filter to resize first —
        # confirm against fastchat.utils.
        initial_size_mb = 6
        img = generate_random_image(initial_size_mb)

        nsfw_flag, csam_flag = image_moderation_filter(img)
        # Random noise is expected to trip neither flag; previously the NSFW
        # flag was asserted twice and the CSAM flag was never checked.
        self.assertFalse(nsfw_flag)
        self.assertFalse(csam_flag)
|
|
|
|
class DontResizeIfMaxImageSizeIsNone(unittest.TestCase):
    """With no size cap (``None``) the encoded image size is unchanged."""

    def test_dont_resize_if_max_image_size_is_none(self):
        source_mb = 0.2
        picture = generate_random_image(source_mb)

        # Baseline: the picture encoded as PNG, measured in MB.
        baseline_png = BytesIO()
        picture.save(baseline_png, format="PNG")
        size_before = check_byte_size_in_mb(baseline_png.getvalue())

        # Passing None disables the cap entirely.
        returned_bytes = resize_image_and_return_image_in_bytes(
            picture, max_image_size_mb=None
        )
        size_after = check_byte_size_in_mb(returned_bytes.getvalue())

        self.assertEqual(size_before, size_after)
|
|
|
|
class OpenAIConversationDontResizeImage(unittest.TestCase):
    """The "chatgpt" template base64-encodes a small image at its original size."""

    def test(self):
        template = get_conv_template("chatgpt")
        source_mb = 0.2
        picture = generate_random_image(source_mb)

        baseline_png = BytesIO()
        picture.save(baseline_png, format="PNG")
        size_before = check_byte_size_in_mb(baseline_png.getvalue())

        # Decode the base64 payload so raw byte counts are compared, not
        # base64 character counts.
        encoded = template.convert_image_to_base64(picture)
        size_after = check_byte_size_in_mb(base64.b64decode(encoded))

        self.assertEqual(size_before, size_after)
|
|
|
|
class ClaudeConversationResizesCorrectly(unittest.TestCase):
    """The Claude template shrinks a large image to fit its size limits."""

    def test(self):
        template = get_conv_template("claude-3-haiku-20240307")
        source_mb = 5
        picture = generate_random_image(source_mb)

        baseline_png = BytesIO()
        picture.save(baseline_png, format="PNG")
        size_before = check_byte_size_in_mb(baseline_png.getvalue())

        encoded = template.convert_image_to_base64(picture)
        # Measure both the base64 text itself and the raw bytes it decodes to.
        encoded_mb = check_byte_size_in_mb(encoded)
        decoded_mb = check_byte_size_in_mb(base64.b64decode(encoded))

        # Upper bound (in MB) asserted on the base64 payload.
        base64_limit_mb = 5
        self.assertLess(decoded_mb, size_before)
        self.assertLessEqual(decoded_mb, template.max_image_size_mb)
        self.assertLessEqual(encoded_mb, base64_limit_mb)
|
|