import os

import torch
from PIL import Image

def read_images_in_path(path, size=(512, 512)):
    """Load every .png/.jpg/.jpeg file in `path`, sorted by filename,
    converted to RGB and resized to `size`."""
    image_paths = []
    for filename in os.listdir(path):
        if filename.endswith((".png", ".jpg", ".jpeg")):
            image_paths.append(os.path.join(path, filename))
    image_paths = sorted(image_paths)
    return [Image.open(image_path).convert("RGB").resize(size) for image_path in image_paths]

def concatenate_images(image_lists, return_list=False):
    """Arrange a list of image columns into a grid.

    `image_lists[j][i]` is the image in column j, row i; all images are
    assumed to share the same size. Returns a single grid image, or, if
    `return_list` is True, one horizontal strip per row.
    """
    num_rows = len(image_lists[0])
    num_columns = len(image_lists)
    image_width = image_lists[0][0].width
    image_height = image_lists[0][0].height

    grid_width = num_columns * image_width
    # When returning a list, each output image holds a single row of the grid.
    grid_height = image_height if return_list else num_rows * image_height
    if return_list:
        grid_images = [Image.new("RGB", (grid_width, grid_height)) for _ in range(num_rows)]
    else:
        grid_images = [Image.new("RGB", (grid_width, grid_height))]

    for i in range(num_rows):
        row_index = i if return_list else 0
        for j in range(num_columns):
            x_offset = j * image_width
            y_offset = 0 if return_list else i * image_height
            grid_images[row_index].paste(image_lists[j][i], (x_offset, y_offset))

    return grid_images if return_list else grid_images[0]
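
# A minimal layout sketch (the solid-color Image.new placeholders below are
# illustrative, not part of this module): two columns of two images each
# become a 2x2 grid, or two horizontal strips with return_list=True.
#
#     cols = [[Image.new("RGB", (64, 64), "red"), Image.new("RGB", (64, 64), "blue")],
#             [Image.new("RGB", (64, 64), "green"), Image.new("RGB", (64, 64), "white")]]
#     grid = concatenate_images(cols)                    # one 128x128 image
#     rows = concatenate_images(cols, return_list=True)  # two 128x64 strips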

def concatenate_images_single(image_lists):
    """Concatenate a flat list of same-size images into one horizontal strip."""
    num_columns = len(image_lists)
    image_width = image_lists[0].width
    image_height = image_lists[0].height

    grid_image = Image.new("RGB", (num_columns * image_width, image_height))
    for j, image in enumerate(image_lists):
        grid_image.paste(image, (j * image_width, 0))

    return grid_image

def get_captions_for_images(images, device):
    """Caption each PIL image with BLIP-2 (Salesforce/blip2-opt-2.7b).

    The model is loaded in 8-bit on GPU 0 (requires CUDA and bitsandbytes),
    so `device` should refer to that same GPU. The processor and model are
    deleted afterwards so the weights do not stay resident between calls.
    """
    from transformers import Blip2Processor, Blip2ForConditionalGeneration

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", load_in_8bit=True, device_map={"": 0}, torch_dtype=torch.float16
    )

    res = []
    for image in images:
        inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
        generated_ids = model.generate(**inputs)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        res.append(generated_text)

    del processor
    del model
    torch.cuda.empty_cache()  # release cached GPU memory now that the model is gone

    return res
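
if __name__ == "__main__":
    # A minimal usage sketch: the "./images" directory and the output
    # filenames are assumptions for illustration. Captioning loads BLIP-2
    # in 8-bit, which needs a CUDA GPU plus the bitsandbytes package.
    images = read_images_in_path("./images")
    if images:
        concatenate_images_single(images).save("strip.png")
        concatenate_images([images, images]).save("grid.png")
        if torch.cuda.is_available():
            for caption in get_captions_for_images(images, device="cuda"):
                print(caption)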