# Modified so that it can run on a Hugging Face Space
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict

# ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))

# Select the features we want to keep
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})

# Soft encoding: assign each answer string an integer ID
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': [answers_to_id[ans] for ans in ex['answers']],
    'question': ex['question'],
    'image': ex['image']
})

id_to_answers = {v: k for k, v in answers_to_id.items()}
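# A quick illustration of the defaultdict trick (the values are hypothetical):
#   answers_to_id['cat'] -> 0   # each unseen answer gets the next free ID
#   answers_to_id['dog'] -> 1
#   id_to_answers        -> {0: 'cat', 1: 'dog'}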
# Map each row index to its encoded answer list
id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}

# Replace the answers with the label list looked up by the first encoded answer ID
# (note: the keys of id_to_labels are row indices, so this lookup may return None)
selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
                                                    'question': ex['question'],
                                                    'image': ex['image']})

flattened_features = []

for ex in selected_dataset:
    flattened_example = {
        'answers': ex['answers'],
        'question': ex['question'],
        'image': ex['image'],
    }
    flattened_features.append(flattened_example)

# ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model_name = 'microsoft/git-base-vqav2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
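# A hedged alternative (assumption, not part of the original script): GIT
# checkpoints are normally used generatively for VQA, i.e. with a causal-LM
# head rather than a classification head:
# from transformers import AutoModelForCausalLM
# git_model = AutoModelForCausalLM.from_pretrained('microsoft/git-base-vqav2')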

# Train the model using the Trainer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        # NOTE: these are dummy placeholder shapes, not real image tensors;
        # see the processor-based sketch below for building actual pixel_values
        'pixel_values': [(4, 3, 244, 244)] * len(tokenized_inputs['input_ids']),
        'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
        # NOTE: these labels are still raw answer strings; for a classification
        # loss they would need to go through the answers_to_id encoding above
        'labels': [[label] for label in examples['answers']]
    }

dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
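
# The pixel_values above are dummy placeholders. A minimal sketch, assuming the
# GIT processor is the intended way to turn the dataset's PIL images into real
# tensors (an assumption, not part of the original script):
#
# from transformers import AutoProcessor
# processor = AutoProcessor.from_pretrained('microsoft/git-base-vqav2')
#
# def preprocess_with_images(examples):
#     encoded = processor(images=examples['image'], text=examples['question'],
#                         padding=True, truncation=True, return_tensors='pt')
#     return {
#         'input_ids': encoded['input_ids'],
#         'attention_mask': encoded['attention_mask'],
#         'pixel_values': encoded['pixel_values'],
#     }
#
# ok_vqa_dataset = dataset.map(preprocess_with_images, batched=True)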

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=20,
    per_device_train_batch_size=4,
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ok_vqa_dataset
)

# ๋ชจ๋ธ ํ•™์Šต
trainer.train()
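
# Optionally persist the fine-tuned weights so the inference app below could
# reload them (a sketch, assuming './results' is used as the serving directory):
# trainer.save_model('./results')
# tokenizer.save_pretrained('./results')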


import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ๊ฐ€์ค‘์น˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
model_name = 'microsoft/git-base-vqav2'  # ์‚ฌ์šฉํ•  ๋ชจ๋ธ์˜ ์ด๋ฆ„
model = BertForSequenceClassification.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

# Define the prediction function
def predict_answer(image, question):
    inputs = tokenizer(question, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Image-related processing would go here
    # (preprocessing of the input image still needs to be added)

    # ๋ชจ๋ธ์— ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ „๋‹ฌํ•˜์—ฌ ์˜ˆ์ธก ์ˆ˜ํ–‰
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Take the label ID with the highest score from the model output
    predicted_label_id = torch.argmax(outputs.logits).item()
    # Map the ID back to an answer string via the mapping built during preprocessing
    predicted_label = id_to_answers.get(predicted_label_id, str(predicted_label_id))

    return predicted_label
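
# Note: the `image` argument is unused above because a text-only BERT
# classification head cannot consume pixels. If a multimodal checkpoint such as
# GIT were loaded instead (an assumption, not what this script does), the image
# could be preprocessed with its processor and passed to the model, e.g.:
#   pixel_values = processor(images=image, return_tensors='pt').pixel_values
#   outputs = model(input_ids=input_ids, attention_mask=attention_mask,
#                   pixel_values=pixel_values)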

iface = gr.Interface(
    fn=predict_answer,
    inputs=["image", "text"],
    outputs="text",
    title="Visual Question Answering",
    description="Input an image and a question to get the model's answer.",
    examples=[
        ["https://your_image_url.jpg", "What is shown in the image?"]
    ]
)

iface.launch()