import gradio as gr
import torch
from transformers import GPT2Model, ViTModel, GPT2Tokenizer, ViTImageProcessor
from captum.attr import IntegratedGradients
from PIL import Image
import numpy as np
import cv2
# Define the multimodal model: GPT-2 encodes the question, ViT encodes the
# image, and a linear classifier maps the fused features to yes/no logits
class MultiModalModel(torch.nn.Module):
    def __init__(self, gpt2_model_name="gpt2", vit_model_name="google/vit-base-patch16-224-in21k"):
        super(MultiModalModel, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.classifier = torch.nn.Linear(self.gpt2.config.hidden_size + self.vit.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        # Text representation: hidden state at the final position
        text_features = gpt2_outputs.last_hidden_state[:, -1, :]
        vit_outputs = self.vit(pixel_values=pixel_values)
        # Image representation: the [CLS] token embedding
        image_features = vit_outputs.last_hidden_state[:, 0, :]
        fused_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(fused_features)
        return logits
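
# Sanity-check sketch: with the default checkpoints, both GPT-2 and ViT-base
# use a hidden size of 768, so fused_features has shape (batch, 1536) and the
# classifier maps 1536 -> 2 logits (no / yes). For example:
#   m = MultiModalModel()
#   assert m.classifier.in_features == 768 + 768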

# Load the fine-tuned weights from the Hugging Face Hub
def load_model():
    model_name = "Muhusjf/ViT-GPT2-multimodal-model"
    model = MultiModalModel()
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
        map_location=torch.device('cpu')
    ))
    model.eval()
    return model
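
# Alternative loading sketch (assumes the optional huggingface_hub package,
# which also caches the download locally):
#   from huggingface_hub import hf_hub_download
#   weights = hf_hub_download(repo_id="Muhusjf/ViT-GPT2-multimodal-model",
#                             filename="pytorch_model.bin")
#   model.load_state_dict(torch.load(weights, map_location="cpu"))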

# Convert a CHW tensor (or array) to a PIL image
def convert_tensor_to_pil(tensor_image):
    if isinstance(tensor_image, torch.Tensor):
        tensor_image = tensor_image.numpy()
    image_np = np.transpose(tensor_image, (1, 2, 0))  # CHW -> HWC
    if image_np.max() <= 1.0:
        image_np = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_np)

# Custom forward function for Integrated Gradients: Captum passes the
# attributed input (the image tensor) first, with the text inputs supplied
# as additional forward args
def custom_forward(pixel_values, input_ids, attention_mask):
    logits = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
    return logits

# Initialize Integrated Gradients
integrated_gradients = IntegratedGradients(custom_forward)
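
# Note: generate_attribution below calls attribute() with n_steps=1, a very
# coarse approximation of the path integral (Captum's default is n_steps=50).
# A more faithful but slower attribution could look like:
#   attributions, delta = integrated_gradients.attribute(
#       inputs=pixel_values, target=prediction,
#       additional_forward_args=(input_ids, attention_mask),
#       n_steps=50, return_convergence_delta=True)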

# Initialize the model, tokenizer, and image processor
model = load_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Inference function: answer a yes/no question about the image
def predict_text(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    label = "yes" if prediction == 1 else "no"
    return label
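
# Example call outside the UI (hypothetical file name):
#   print(predict_text(Image.open("example.jpg"), "Is there a dog?"))  # "yes" or "no"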

# Attribution function: highlight the image regions that drove the prediction
def generate_attribution(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    # First get the predicted class, then attribute towards it
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    attributions, _ = integrated_gradients.attribute(
        inputs=pixel_values,
        target=prediction,
        additional_forward_args=(input_ids, attention_mask),
        n_steps=1,  # single step keeps the demo fast at the cost of accuracy
        return_convergence_delta=True
    )
    # Normalize attributions to [0, 255] and render as an image
    attribution_image = attributions.squeeze().cpu().numpy()
    attribution_image = (attribution_image - attribution_image.min()) / (attribution_image.max() - attribution_image.min())
    attribution_image = np.uint8(255 * attribution_image)
    attribution_image_real = convert_tensor_to_pil(attribution_image)
    # Threshold the attribution map and trace contours around salient regions
    attribution_gray = cv2.cvtColor(np.array(attribution_image_real), cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(attribution_gray, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Undo the ViT processor's normalization (mean=0.5, std=0.5) so the
    # original image renders correctly, then draw the contours in red
    original_image = convert_tensor_to_pil((pixel_values.squeeze(0) * 0.5 + 0.5).numpy())
    original_image_np = np.array(original_image)
    cv2.drawContours(original_image_np, contours, -1, (255, 0, 0), 2)
    return attribution_image_real, Image.fromarray(original_image_np)
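
# Example call outside the UI (hypothetical file names):
#   attr_img, contour_img = generate_attribution(Image.open("example.jpg"), "Is it sunny?")
#   attr_img.save("attribution.png")
#   contour_img.save("contours.png")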

# Build the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", interactive=True, height=400)
            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
            clear_button = gr.Button("Clear")
        with gr.Column():
            predict_button = gr.Button("Answer")
            prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
            attribution_button = gr.Button("Generate Attribution")
    with gr.Row():
        attribution_image_1 = gr.Image(label="Attribution Image", interactive=False, height=400)
        attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False, height=400)

    # Wire up the button events
    predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
    attribution_button.click(generate_attribution, inputs=[input_image, question_input], outputs=[attribution_image_1, attribution_image_2])
    clear_button.click(lambda: (None, "", ""), outputs=[input_image, question_input, prediction_output])

# Launch the Gradio app
demo.launch()