import gradio as gr
import torch
from transformers import GPT2Model, ViTModel, GPT2Tokenizer, ViTImageProcessor
from captum.attr import IntegratedGradients
from PIL import Image
import numpy as np
import cv2
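
# This Space demos a two-tower visual question answering model: GPT-2 encodes
# the question, ViT encodes the image, and a linear head fuses both into a
# binary yes/no answer. Captum's Integrated Gradients highlights the image
# regions that drove the prediction, shown both as a heatmap and as contours
# drawn on the input image.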

# Define the multimodal model
class MultiModalModel(torch.nn.Module):
    def __init__(self, gpt2_model_name="gpt2", vit_model_name="google/vit-base-patch16-224-in21k"):
        super(MultiModalModel, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        # Binary classifier over the concatenated text and image features
        self.classifier = torch.nn.Linear(self.gpt2.config.hidden_size + self.vit.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at the last position as the text representation
        text_features = gpt2_outputs.last_hidden_state[:, -1, :]
        vit_outputs = self.vit(pixel_values=pixel_values)
        # Use the [CLS] token as the image representation
        image_features = vit_outputs.last_hidden_state[:, 0, :]
        fused_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(fused_features)
        return logits

# Load the pretrained weights from the Hugging Face Hub
def load_model():
    model_name = "Muhusjf/ViT-GPT2-multimodal-model"
    model = MultiModalModel()
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
        map_location=torch.device('cpu')
    ))
    model.eval()
    return model

# Convert a CHW tensor/array to a PIL image
def convert_tensor_to_pil(tensor_image):
    if isinstance(tensor_image, torch.Tensor):
        tensor_image = tensor_image.detach().cpu().numpy()
    image_np = np.transpose(tensor_image, (1, 2, 0))  # CHW -> HWC
    if image_np.dtype != np.uint8:
        # Float inputs come from the ViT processor, which normalizes with
        # mean=0.5 and std=0.5; undo that and clip before casting to uint8,
        # otherwise negative values wrap around when converted.
        image_np = np.clip(image_np * 0.5 + 0.5, 0.0, 1.0)
        image_np = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_np)
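
# Captum's IntegratedGradients wraps a forward function and attributes the
# output with respect to its first argument. Here only the image tensor is
# attributed; the tokenized question is passed through unchanged via
# additional_forward_args.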

# Initialize the model, tokenizer, and image processor
model = load_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Custom forward function for Integrated Gradients; the attributed input
# (pixel_values) must come first
def custom_forward(pixel_values, input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)

# Initialize Integrated Gradients
integrated_gradients = IntegratedGradients(custom_forward)

# Inference function: answer a yes/no question about the image
def predict_text(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    return "yes" if prediction == 1 else "no"
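
# The attribution pipeline below: (1) predict the answer, (2) run Integrated
# Gradients on the image tensor toward the predicted class, (3) min-max
# normalize the attributions into a heatmap, (4) threshold the heatmap and
# trace contours with OpenCV, (5) draw those contours on the original image.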

# Attribution analysis function
def generate_attribution(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    # Attribute the predicted class to the image pixels. n_steps=1 is a very
    # coarse approximation of the path integral, kept low for CPU speed.
    attributions, _ = integrated_gradients.attribute(
        inputs=pixel_values,
        target=prediction,
        additional_forward_args=(input_ids, attention_mask),
        n_steps=1,
        return_convergence_delta=True
    )
    # Min-max normalize the attribution map to [0, 255]
    attribution_image = attributions.squeeze().cpu().numpy()
    attribution_image = (attribution_image - attribution_image.min()) / (attribution_image.max() - attribution_image.min() + 1e-8)
    attribution_image = np.uint8(255 * attribution_image)
    attribution_image_real = convert_tensor_to_pil(attribution_image)
    # Threshold the heatmap and find contours of the most salient regions
    attribution_gray = cv2.cvtColor(np.array(attribution_image_real), cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(attribution_gray, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Draw the contours in red on the de-normalized input image
    original_image = convert_tensor_to_pil(pixel_values.squeeze(0))
    original_image_np = np.array(original_image)
    cv2.drawContours(original_image_np, contours, -1, (255, 0, 0), 2)
    return attribution_image_real, Image.fromarray(original_image_np)

# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", interactive=True, height=400)
            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
            clear_button = gr.Button("Clear")
        with gr.Column():
            predict_button = gr.Button("Answer")
            prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
            attribution_button = gr.Button("Generate Attribution")
    with gr.Row():
        attribution_image_1 = gr.Image(label="Attribution Image", interactive=False, height=400)
        attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False, height=400)

    # Button event bindings
    predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
    attribution_button.click(generate_attribution, inputs=[input_image, question_input], outputs=[attribution_image_1, attribution_image_2])
    clear_button.click(lambda: (None, "", ""), outputs=[input_image, question_input, prediction_output])

# Launch the Gradio interface
demo.launch()