Spaces: Runtime error
# Modified so that it can run on Hugging Face Spaces
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict
# Load the data
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))

# Keep only the features we need
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})
# Encode the answer strings as integer IDs
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': [answers_to_id[ans] for ans in ex['answers']],
    'question': ex['question'],
    'image': ex['image']
})
id_to_answers = {v: k for k, v in answers_to_id.items()}
id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
                                                    'question': ex['question'],
                                                    'image': ex['image']})
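# NOTE: id_to_labels is keyed by example index, while ex['answers'][0] is an answer ID from
# answers_to_id, so the lookup above mixes two different key spaces. If the intent is to map an
# ID back to its answer string, id_to_answers[ex['answers'][0]] is probably what was meant.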
flattened_features = []
for ex in selected_dataset:
    flattened_example = {
        'answers': ex['answers'],
        'question': ex['question'],
        'image': ex['image'],
    }
    flattened_features.append(flattened_example)
# Load the model
from transformers import AutoModelForSequenceClassification
model_name = 'microsoft/git-base-vqav2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
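# NOTE: 'microsoft/git-base-vqav2' is a GIT checkpoint that only ships a generative
# (causal-LM) head, so AutoModelForSequenceClassification typically cannot build a classifier
# from its config, and this line is a likely source of the runtime error. The generative route
# is sketched at the end of this post.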
# Train the model with Trainer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        'pixel_values': [(4, 3, 244, 244)] * len(tokenized_inputs['input_ids']),
        'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
        'labels': [[label] for label in examples['answers']]
    }
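# NOTE: the 'pixel_values' above are plain shape tuples, not image tensors, so no model can
# actually consume them. A minimal sketch, assuming the checkpoint's own image processor and
# PIL images in examples['image']:
#
#   from transformers import AutoImageProcessor
#   image_processor = AutoImageProcessor.from_pretrained(model_name)
#   pixel_values = image_processor(examples['image'], return_tensors='pt')['pixel_values']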
dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=20,
    per_device_train_batch_size=4,
    logging_steps=500,
)
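# NOTE: TrainingArguments defaults to remove_unused_columns=True, so columns the model's
# forward() does not accept (for a text-only classifier, likely pixel_values and pixel_mask)
# are dropped before batching; pass remove_unused_columns=False to keep them.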
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ok_vqa_dataset
)

# Train the model
trainer.train()
import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Initialize the model and load the weights
model_name = 'microsoft/git-base-vqav2'  # name of the model to use
model = BertForSequenceClassification.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
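# NOTE: these two lines reload the hub checkpoint instead of the model fine-tuned above, and
# BertForSequenceClassification / BertTokenizer do not match the GIT checkpoint's files, so
# from_pretrained will likely fail here as well. Reusing the `model` and `tokenizer` objects
# from the training section (or loading from './results') avoids both problems.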
# Define the prediction function
def predict_answer(image, question):
    inputs = tokenizer(question, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Image-related processing
    # Image preprocessing code should be added here (resizing/normalizing the input image, etc.)

    # Pass the inputs to the model and run the prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Take the label ID with the highest probability from the prediction
    predicted_label_id = torch.argmax(outputs.logits).item()
    predicted_label = id_to_answers.get(predicted_label_id)  # id_to_label_fn was undefined; id_to_answers is built above
    return predicted_label
iface = gr.Interface(
    fn=predict_answer,
    inputs=["image", "text"],
    outputs="text",
    title="Visual Question Answering",
    description="Input an image and a question to get the model's answer.",
    examples=[
        ["https://your_image_url.jpg", "What is shown in the image?"]
    ]
)
iface.launch()
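For reference, here is a minimal sketch of the generative route this GIT checkpoint was trained for, following the pattern in the Transformers GIT documentation. The names generate_answer, git_model, and processor are placeholders, and it assumes the question comes in as a string and the image as a PIL image:

from transformers import AutoProcessor, AutoModelForCausalLM
import torch

processor = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")

def generate_answer(image, question):
    # encode the image and the question; GIT expects the question prefixed with [CLS]
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    question_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = torch.tensor([[processor.tokenizer.cls_token_id] + question_ids])
    # generate an answer continuation and decode it
    generated_ids = git_model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

A function like generate_answer can be dropped into the same gr.Interface in place of predict_answer; depending on the Gradio version, the "image" input may arrive as a NumPy array rather than a PIL image, in which case a conversion such as PIL.Image.fromarray is needed first.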