File size: 3,694 Bytes
0241217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""Helper functions for Panoptic Narrative Grounding."""

import os
from os.path import join, isdir, exists
from typing import List

import torch
from PIL import Image
from skimage import io
import numpy as np
import textwrap
import matplotlib.pyplot as plt
from matplotlib import transforms
from imgaug.augmentables.segmaps import SegmentationMapsOnImage


def rainbow_text(x, y, ls, lc, fig, ax, **kw):
    """
    Take a list of strings ``ls`` and colors ``lc`` and place them next to each
    other, with text ls[i] being shown in color lc[i].

    The first string is placed at ``(x, y)`` in axes coordinates
    (``ax.transAxes``). Each following string is offset to the right by the
    rendered width of the previous one. A trailing space is appended to every
    string. Extra keyword arguments ``kw`` are forwarded to ``ax.text``.
    Strings beyond the shorter of ``ls``/``lc`` are ignored (``zip``).

    Ref: https://stackoverflow.com/questions/9169052/partial-coloring-of-text-in-matplotlib
    """
    t = ax.transAxes

    for s, c in zip(ls, lc):
        text = ax.text(x, y, s + " ", color=c, transform=t, **kw)
        # Draw once so the text's window extent (display pixels) is known.
        text.draw(fig.canvas.get_renderer())
        ex = text.get_window_extent()
        # The next string starts where this one ends: same transform, shifted
        # right by this string's rendered width. Use the public accessor
        # rather than the private ``_transform`` attribute.
        t = transforms.offset_copy(text.get_transform(), x=ex.width, units='dots')


_NOT_GIVEN = object()


def find_first_index_greater_than(elements, key, default=_NOT_GIVEN):
    """Return the index of the first element of ``elements`` that is ``> key``.

    Args:
        elements: iterable of values comparable with ``key``.
        key: threshold to compare against.
        default: value returned when no element exceeds ``key``. If omitted,
            ``ValueError`` is raised instead.

    Raises:
        ValueError: if no element is greater than ``key`` and no ``default``
            was given. (Letting ``StopIteration`` escape from a plain function
            is turned into ``RuntimeError`` inside generators (PEP 479) and
            silently truncates iteration when called from ``map``/``filter``.)
    """
    for index, value in enumerate(elements):
        if value > key:
            return index
    if default is _NOT_GIVEN:
        raise ValueError(f"no element greater than {key!r}")
    return default


def split_caption_phrases(caption_phrases, colors, max_char_in_a_line=50):
    """Group caption phrases (and their parallel colours) into display lines.

    Phrases are never split. With ``c_i`` the cumulative character count of
    phrases ``0..i`` (``len`` of the phrases only; separators are not
    counted), phrase ``i`` goes to 0-based line ``k`` when
    ``k * max_char_in_a_line < c_i <= (k + 1) * max_char_in_a_line``.
    Lines that would be empty (e.g. when one phrase is longer than a whole
    line) are omitted. Every phrase appears in exactly one line.

    Args:
        caption_phrases: sequence of strings making up the caption.
        colors: sequence of colours parallel to ``caption_phrases``.
        max_char_in_a_line: target number of characters per line.

    Returns:
        ``(utt_per_line, col_per_line)``: two parallel lists holding, per
        line, a slice of ``caption_phrases`` and the matching slice of
        ``colors``. Both are empty when ``caption_phrases`` is empty.

    Raises:
        ValueError: if ``max_char_in_a_line`` is not positive.
    """
    if max_char_in_a_line <= 0:
        raise ValueError(
            f"max_char_in_a_line must be positive, got {max_char_in_a_line}")
    if len(caption_phrases) == 0:
        return [], []

    char_lengths = np.cumsum([len(x) for x in caption_phrases])
    n_thresholds = int(char_lengths[-1]) // max_char_in_a_line
    thresholds = [max_char_in_a_line * i for i in range(1, 1 + n_thresholds)]

    # Cumulative lengths are non-decreasing, so searchsorted(side="right")
    # yields the index of the first phrase whose cumulative length exceeds t,
    # or len(caption_phrases) when none does.
    ends = [int(np.searchsorted(char_lengths, t, side="right")) for t in thresholds]
    # Whatever follows the last threshold forms the final line.
    ends.append(len(caption_phrases))

    utt_per_line = []
    col_per_line = []
    start_index = 0
    for index in ends:
        if index > start_index:  # skip lines that would be empty
            utt_per_line.append(caption_phrases[start_index:index])
            col_per_line.append(colors[start_index:index])
            start_index = index

    return utt_per_line, col_per_line


def show_image_and_caption(image: Image, caption_phrases: list, colors: list = None):
    """Show an image beside its colour-coded caption, then call ``plt.show()``.

    Builds a 1x2 figure (15x4 inches): the left panel shows ``image`` with
    ticks hidden. The right panel renders the caption with ``rainbow_text``
    at font size 15, wrapped by ``split_caption_phrases`` at 50 characters
    per line, starting at y=0.7 (axes coordinates) and stepping down by 0.11
    per line.

    Args:
        image: anything accepted by ``Axes.imshow``.
        caption_phrases: caption split into phrases.
        colors: one colour per phrase; defaults to all ``"black"``.
    """
    if colors is None:
        colors = ["black"] * len(caption_phrases)

    fig, (image_ax, caption_ax) = plt.subplots(1, 2, figsize=(15, 4))

    image_ax.imshow(image)
    image_ax.set_xticks([])
    image_ax.set_yticks([])

    lines, line_colors = split_caption_phrases(
        caption_phrases, colors, max_char_in_a_line=50)
    baseline = 0.7
    for phrases, phrase_colors in zip(lines, line_colors):
        rainbow_text(0., baseline, phrases, phrase_colors,
                     size=15, ax=caption_ax, fig=fig,
                     horizontalalignment='left', verticalalignment='center')
        baseline -= 0.11
    caption_ax.axis("off")

    fig.tight_layout()
    plt.show()


def show_images_and_caption(
        images: List,
        caption_phrases: list,
        colors: list = None,
        image_xlabels: List = None,
        figsize=None,
        show=False,
        xlabelsize=14,
    ):
    """Plot one or more images followed by a colour-coded caption panel.

    The figure has ``len(images) + 1`` panels in one row. Each image gets a
    panel with ticks hidden and an x-label underneath. The last panel renders
    the caption with ``rainbow_text`` at font size 23, wrapped by
    ``split_caption_phrases`` at 40 characters per line.

    Args:
        images: images to draw with ``imshow``, one panel each.
        caption_phrases: caption split into phrases. The first phrase is
            shown with an upper-cased first character; the caller's list is
            not modified.
        colors: one colour per phrase; defaults to all ``"black"``.
        image_xlabels: optional x-label per image. Missing entries (including
            the default ``None``) are rendered as empty labels.
        figsize: figure size; defaults to ``(5 * len(images) + 8, 4)``.
        show: if true, call ``plt.show()`` after laying out the figure.
        xlabelsize: font size of the image x-labels.
    """
    if colors is None:
        colors = ["black" for _ in range(len(caption_phrases))]

    # Copy so the caller's list is left untouched.
    caption_phrases = list(caption_phrases)
    if caption_phrases:
        # Upper-case only the first character: str.capitalize() would also
        # lower-case the rest of the phrase (proper nouns, acronyms, ...).
        first = caption_phrases[0]
        caption_phrases[0] = first[:1].upper() + first[1:]

    if figsize is None:
        figsize = (5 * len(images) + 8, 4)

    # Pad with empty labels so every image panel has one.
    image_xlabels = list(image_xlabels) if image_xlabels is not None else []
    image_xlabels += [""] * (len(images) - len(image_xlabels))

    # squeeze=False keeps a 2-D array of axes even with a single panel
    # (i.e. no images), so axes[-1] below is always valid.
    fig, axes = plt.subplots(1, len(images) + 1, figsize=figsize, squeeze=False)
    axes = axes[0]

    for i, image in enumerate(images):
        ax = axes[i]
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel(image_xlabels[i], fontsize=xlabelsize)

    ax = axes[-1]
    utt_per_line, col_per_line = split_caption_phrases(caption_phrases, colors, max_char_in_a_line=40)
    y = 0.7
    for U, C in zip(utt_per_line, col_per_line):
        rainbow_text(
            0., y,
            U,
            C,
            size=23, ax=ax, fig=fig,
            horizontalalignment='left',
            verticalalignment='center',
        )
        y -= 0.11

    ax.axis("off")

    fig.tight_layout()

    if show:
        plt.show()