import os
import random

import torch
from datasets import load_dataset
from PIL import Image

from OmniGen import OmniGenProcessor
from OmniGen.processor import OmniGenCollator
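
# Data utilities for OmniGen fine-tuning: DatasetFromJson loads
# (instruction, input_images, output_image) records from a JSON file and
# preprocesses them; TrainDataCollator pads and batches the results.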


class DatasetFromJson(torch.utils.data.Dataset):
    def __init__(
        self,
        json_file: str,
        image_path: str,
        processor: OmniGenProcessor,
        image_transform,
        max_input_length_limit: int = 18000,
        condition_dropout_prob: float = 0.1,
        keep_raw_resolution: bool = True,
    ):
        self.image_transform = image_transform
        self.processor = processor
        self.condition_dropout_prob = condition_dropout_prob
        self.max_input_length_limit = max_input_length_limit
        self.keep_raw_resolution = keep_raw_resolution

        self.data = load_dataset('json', data_files=json_file)['train']
        self.image_path = image_path

    def process_image(self, image_file):
        # Resolve relative file names against the dataset's image root, if given.
        if self.image_path is not None:
            image_file = os.path.join(self.image_path, image_file)
        image = Image.open(image_file).convert('RGB')
        return self.image_transform(image)

    def get_example(self, index):
        example = self.data[index]

        instruction, input_images, output_image = example['instruction'], example['input_images'], example['output_image']
        # With probability condition_dropout_prob, drop the condition (replace the
        # instruction with the special '<cfg>' token and discard the input images)
        # so the model also learns the unconditional branch for classifier-free guidance.
        if random.random() < self.condition_dropout_prob:
            instruction = '<cfg>'
            input_images = None
        if input_images is not None:
            input_images = [self.process_image(x) for x in input_images]
        mllm_input = self.processor.process_multi_modal_prompt(instruction, input_images)

        output_image = self.process_image(output_image)

        return (mllm_input, output_image)


    def __getitem__(self, index):
        # Retry with randomly chosen substitute samples so a single corrupted or
        # over-long example does not crash the training run.
        for _ in range(8):
            try:
                mllm_input, output_image = self.get_example(index)
                if len(mllm_input['input_ids']) > self.max_input_length_limit:
                    raise RuntimeError(f"cur number of tokens={len(mllm_input['input_ids'])}, larger than max_input_length_limit={self.max_input_length_limit}")
                return mllm_input, output_image
            except Exception as e:
                print("error when loading data: ", e)
                print(self.data[index])
                index = random.randint(0, len(self.data) - 1)
        raise RuntimeError("Too many bad data samples.")

    def __len__(self):
        return len(self.data)
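

# Usage sketch (illustrative, not from the original file): one plausible way to
# build the dataset. The normalization constants below are assumptions; the
# real training script supplies its own image_transform.
def _example_build_dataset(json_file: str, image_path: str, processor: OmniGenProcessor):
    from torchvision import transforms  # imported locally; only this sketch needs it
    image_transform = transforms.Compose([
        transforms.ToTensor(),                                   # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # rescale to [-1, 1]
    ])
    return DatasetFromJson(
        json_file=json_file,
        image_path=image_path,
        processor=processor,
        image_transform=image_transform,
    )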



class TrainDataCollator(OmniGenCollator):
    def __init__(self, pad_token_id: int, hidden_size: int, keep_raw_resolution: bool):
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size
        self.keep_raw_resolution = keep_raw_resolution

    def __call__(self, features):
        mllm_inputs = [f[0] for f in features]
        output_images = [f[1].unsqueeze(0) for f in features]
        target_img_size = [[x.size(-2), x.size(-1)] for x in output_images]

        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)

        if not self.keep_raw_resolution:
            # At a fixed resolution every image has the same shape, so the lists
            # can be concatenated into single batched tensors.
            output_images = torch.cat(output_images, dim=0)
            if len(all_pixel_values) > 0:
                all_pixel_values = torch.cat(all_pixel_values, dim=0)
            else:
                all_pixel_values = None

        data = {"input_ids": all_padded_input_ids,
        "attention_mask": all_attention_mask,
        "position_ids": all_position_ids,
        "input_pixel_values": all_pixel_values,
        "input_image_sizes": all_image_sizes,
        "padding_images": all_padding_images,
        "output_images": output_images,
        }
        return data
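

# End-to-end sketch (illustrative assumptions, not the original training loop):
# wiring the dataset and collator into a PyTorch DataLoader. The checkpoint id,
# pad token, and hidden size below are placeholders; match them to your model config.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")  # assumed checkpoint id
    dataset = _example_build_dataset("data.json", "images/", processor)
    collator = TrainDataCollator(
        pad_token_id=processor.text_tokenizer.pad_token_id,  # assumes the processor exposes its tokenizer
        hidden_size=3072,                                    # assumed transformer width
        keep_raw_resolution=True,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collator)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, len(batch["output_images"]))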