|
import time |
|
import requests |
|
from io import BytesIO |
|
from os import path |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
|
|
class TestImageSetOnline(Dataset):
    """Dataset that downloads images over HTTP and preprocesses them with a
    Hugging Face CLIP-style processor.

    Each item is a tuple ``(id, url, pixel_values, original_size)`` where
    ``original_size`` is the PIL ``(width, height)`` of the downloaded image
    before preprocessing. Failed downloads are retried (forever by default)
    with an exponentially growing sleep between attempts.
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 max_retries=None, request_timeout=None):
        """
        Args:
            processor: callable with the Hugging Face processor interface,
                i.e. ``processor(text=..., images=..., return_tensors=...)``
                returning a mapping that contains ``'pixel_values'``.
            image_list: indexable sequence of records where
                ``image_list[i]['coco_url']`` and ``image_list[i]['id']``
                exist. NOTE(review): a pandas.DataFrame would not work here
                because ``df[i]`` selects a column, not a row; presumably a
                list of dicts — confirm with callers.
            timeout_base (float, optional): initial back-off sleep in seconds
                after a failed request. Defaults to 0.5.
            timeout_mul (int, optional): factor the back-off sleep is
                multiplied by after every failed request (and divided by after
                a success, down to ``timeout_base``). Defaults to 2.
            max_retries (int, optional): number of retries allowed for a
                single URL before the last error is re-raised. ``None``
                (default) retries forever.
            request_timeout (float, optional): timeout in seconds passed to
                ``requests.get``. ``None`` (default) waits indefinitely.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # Current back-off sleep; shared across calls on this instance.
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.max_retries = max_retries
        self.request_timeout = request_timeout

    def _preprocess(self, img):
        """Convert a PIL image to RGB and run the processor on it.

        Returns:
            tuple: ``(pixel_values, original_size)`` for the single image.
        """
        img_s = img.size
        if img.mode != 'RGB':
            # Covers 'L', 'CMYK' and 'RGBA' as well as 'P', 'LA', '1', ...
            img = img.convert('RGB')
        ret = self.processor(text=None, images=img, return_tensors='pt')
        return ret['pixel_values'][0], img_s

    def _fetch(self, _id, url):
        """Download *url* and preprocess it, retrying on any error.

        Returns:
            tuple: ``(_id, url, pixel_values, original_size)``.

        Raises:
            Exception: the last error once ``max_retries`` retries have been
                exhausted (never raised when ``max_retries`` is ``None``).
        """
        retries = 0
        while True:
            try:
                response = requests.get(url, timeout=self.request_timeout)
                response.raise_for_status()
                img, img_s = self._preprocess(
                    Image.open(BytesIO(response.content)))
            except Exception as e:
                # KeyboardInterrupt derives from BaseException, so it is not
                # caught here and propagates immediately.
                print(f"{_id} {url}: {str(e)}")
                if self.max_retries is not None and retries >= self.max_retries:
                    raise
                retries += 1
                time.sleep(self.timeout)
                self.timeout *= self.timeout_mul
            else:
                # Relax the back-off again after a successful request.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return _id, url, img, img_s

    def __getitem__(self, index):
        """Fetch and preprocess ``image_list[index]``.

        Returns:
            tuple: ``(str(row['id']), str(row['coco_url']), pixel_values,
            original_size)``.
        """
        row = self.image_list[index]
        return self._fetch(str(row['id']), str(row['coco_url']))

    def get(self, url):
        """Fetch and preprocess an arbitrary *url*; the url doubles as the id.

        Returns:
            tuple: ``(url, url, pixel_values, original_size)``.
        """
        return self._fetch(url, url)

    def __len__(self):
        return len(self.image_list)

    def __add__(self, other):
        """Return a new dataset whose ``image_list`` is the concatenation of
        both operands' lists.

        Processor, timeouts and any subclass attributes are taken from the
        left operand. Neither operand is modified.
        """
        import copy

        combined = copy.copy(self)
        combined.image_list = self.image_list + other.image_list
        return combined
|
|
|
class TestImageSet(TestImageSetOnline):
    """Offline variant of TestImageSetOnline that reads images from a local
    mirror of images.cocodataset.org.

    The local path of an image is ``droot`` joined with the path component of
    its ``coco_url``, e.g. ``http://images.cocodataset.org/val2017/x.jpg`` ->
    ``<droot>/val2017/x.jpg``.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): local root directory mirroring the COCO image host.
            processor, image_list, timeout_base, timeout_mul: forwarded
                unchanged to TestImageSetOnline.__init__.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the local copy of ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, original_size)`` where ``id`` is
            ``"<url parent dir>_<row id>"`` and ``original_size`` is the PIL
            ``(width, height)`` before preprocessing.
        """
        from urllib.parse import urlsplit

        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])

        # Use the URL path so both http:// and https:// URLs map to the mirror.
        rel_path = urlsplit(url).path.lstrip('/')
        # `with` closes the underlying file deterministically.
        with Image.open(path.join(self.droot, rel_path)) as img:
            img_s = img.size
            if img.mode != 'RGB':
                # Covers 'L', 'CMYK' and 'RGBA' as well as 'P', 'LA', '1', ...
                img = img.convert('RGB')
            ret = self.processor(text=None, images=img, return_tensors='pt')
        return _id, url, ret['pixel_values'][0], img_s
|
|