# Tweaked so it runs on Hugging Face Spaces
import torch
from transformers import BertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict
# Load the dataset
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))
# Keep only the features we need
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features},
                               remove_columns=[c for c in dataset.column_names if c not in selected_features])
# Label-encode the answers: map each unique answer string to an integer id
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
'answers': [answers_to_id[ans] for ans in ex['answers']],
'question': ex['question'],
'image': ex['image']
})
id_to_answers = {v: k for k, v in answers_to_id.items()}
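# Illustrative sanity check (assumes the 300-example subset above): ids are
# assigned in first-seen order and round-trip through the inverse mapping.
first_answer = dataset[0]['answers'][0]
assert id_to_answers[answers_to_id[first_answer]] == first_answer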
# Decode the ids back to the original answer strings via the inverse mapping
selected_dataset = selected_dataset.map(lambda ex: {'answers': [id_to_answers[i] for i in ex['answers']],
                                                    'question': ex['question'],
                                                    'image': ex['image']})
# Collect the examples as plain dicts (not used further below)
flattened_features = []
for ex in selected_dataset:
flattened_example = {
'answers': ex['answers'],
'question': ex['question'],
'image': ex['image'],
}
flattened_features.append(flattened_example)
# Load the model
from transformers import AutoModelForSequenceClassification

# The intended checkpoint, 'microsoft/git-base-vqav2', is a generative VQA model
# and has no sequence-classification head, so AutoModelForSequenceClassification
# cannot load it. As a text-only stand-in, classify over the answer vocabulary
# with the same BERT encoder used by the tokenizer below.
model_name = 'bert-base-multilingual-cased'
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(answers_to_id))
# Train the model with the Trainer API
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        # A multimodal model would also need real pixel values here; the
        # text-only classifier above uses only the question, so image tensors
        # are omitted. Use each example's first annotated answer as its label.
        'labels': [answers_to_id[ans[0]] for ans in examples['answers']],
    }
dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
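# Quick illustrative check that the formatted dataset yields tensors of the
# shape the Trainer's default collator expects.
sample = ok_vqa_dataset[0]
print(sample['input_ids'].shape, sample['attention_mask'].shape, sample['labels'])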
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=20,
per_device_train_batch_size=4,
logging_steps=500,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=ok_vqa_dataset
)
# Train the model
trainer.train()
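# Optionally persist the fine-tuned weights and tokenizer; the path here is an
# arbitrary choice, not something the original script specified.
trainer.save_model('./results/final')
tokenizer.save_pretrained('./results/final')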
import gradio as gr

# Reuse the model and tokenizer fine-tuned above instead of reloading the
# checkpoint with mismatched BERT classes.
model.eval()
# Define the prediction function
def predict_answer(image, question):
    inputs = tokenizer(question, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    # Image preprocessing would go here for a multimodal model; the text-only
    # classifier above ignores the image input.
    # Run the model on the question to get answer logits
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # Take the label id with the highest logit and map it back to an answer string
    predicted_label_id = torch.argmax(outputs.logits).item()
    predicted_label = id_to_answers.get(predicted_label_id, 'unknown')
    return predicted_label
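# Quick local smoke test (illustrative): the image argument is ignored by the
# text-only classifier above, so passing None is fine here.
print(predict_answer(None, "What is shown in the image?"))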
iface = gr.Interface(
fn=predict_answer,
inputs=["image", "text"],
outputs="text",
title="Visual Question Answering",
description="Input an image and a question to get the model's answer.",
    examples=[
        ["https://your_image_url.jpg", "What is shown in the image?"]
    ]
)
iface.launch()