import gradio as gr
import torch
from transformers import GPT2Model, ViTModel, GPT2Tokenizer, ViTImageProcessor
from captum.attr import IntegratedGradients
from PIL import Image
import numpy as np
import cv2

# Multimodal model: GPT-2 encodes the question, ViT encodes the image,
# and a linear head performs binary (yes/no) classification
class MultiModalModel(torch.nn.Module):
    def __init__(self, gpt2_model_name="gpt2", vit_model_name="google/vit-base-patch16-224-in21k"):
        super(MultiModalModel, self).__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.classifier = torch.nn.Linear(self.gpt2.config.hidden_size + self.vit.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
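        # Use the hidden state at the last position as the text representation.
        # Note: with right padding this is typically a pad/EOS position; the
        # published checkpoint is assumed to have been trained this way.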
        text_features = gpt2_outputs.last_hidden_state[:, -1, :]

        vit_outputs = self.vit(pixel_values=pixel_values)
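        # The [CLS] token embedding serves as the global image representation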
        image_features = vit_outputs.last_hidden_state[:, 0, :]

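        # Fuse modalities by simple concatenation before the linear classifier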
        fused_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(fused_features)
        return logits

# Load the fine-tuned weights from the Hugging Face Hub
def load_model():
    model_name = "Muhusjf/ViT-GPT2-multimodal-model"
    model = MultiModalModel()
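    # Download and cache the checkpoint from the Hub, then load it on CPU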
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
        map_location=torch.device('cpu')
    ))
    model.eval()
    return model

# Convert a CHW tensor or array into a PIL image
def convert_tensor_to_pil(tensor_image):
    if isinstance(tensor_image, torch.Tensor):
        tensor_image = tensor_image.numpy()
    image_np = np.transpose(tensor_image, (1, 2, 0))
    if image_np.min() < 0:
        # ViTImageProcessor normalizes pixel values to [-1, 1]; map back to
        # [0, 1] so the uint8 cast below does not wrap negative values around
        image_np = (image_np + 1.0) / 2.0
    if image_np.max() <= 1.0:
        image_np = image_np * 255
    return Image.fromarray(np.clip(image_np, 0, 255).astype(np.uint8))

# Custom forward for Captum: attributions are computed w.r.t. the first
# argument (pixel_values); the text inputs are passed through unchanged
def custom_forward(pixel_values, input_ids, attention_mask):
    logits = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
    return logits

# Initialize the model, tokenizer, and image processor
model = load_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Integrated Gradients over the custom forward defined above
integrated_gradients = IntegratedGradients(custom_forward)

# Inference: answer a yes/no question about the image
def predict_text(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")

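    # Prompt format assumed to match fine-tuning: "Question: ... Answer:"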
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )

    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]

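    # Forward pass without gradients; argmax over the two logits maps to yes/no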
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
        prediction = torch.argmax(logits, dim=1).item()
        label = "yes" if prediction == 1 else "no"
    return label

# Attribution analysis: visualize which image regions drove the prediction
def generate_attribution(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")

    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )

    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]

    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
        prediction = torch.argmax(logits, dim=1).item()

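    # NOTE: n_steps=1 collapses Integrated Gradients to a single gradient step,
    # trading attribution fidelity for speed; raise n_steps (e.g. 50) for
    # smoother, more faithful maps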
    attributions, _ = integrated_gradients.attribute(
        inputs=pixel_values,
        target=prediction,
        additional_forward_args=(input_ids, attention_mask),
        n_steps=1,
        return_convergence_delta=True
    )

    attribution_image = attributions.squeeze().cpu().numpy()
    # Min-max normalize to [0, 1]; the epsilon guards against a constant map
    attr_range = attribution_image.max() - attribution_image.min()
    attribution_image = (attribution_image - attribution_image.min()) / (attr_range + 1e-8)
    attribution_image = np.uint8(255 * attribution_image)
    attribution_image_real = convert_tensor_to_pil(attribution_image)

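    # Threshold the grayscale attribution map and trace contours around the
    # most salient regions, then draw them over the input image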
    attribution_gray = cv2.cvtColor(np.array(attribution_image_real), cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(attribution_gray, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    original_image = convert_tensor_to_pil(pixel_values.squeeze(0).numpy())
    original_image_np = np.array(original_image)
    cv2.drawContours(original_image_np, contours, -1, (255, 0, 0), 2)

    return attribution_image_real, Image.fromarray(original_image_np)

# Build the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", interactive=True, height=400)
            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
            clear_button = gr.Button("Clear")
        with gr.Column():
            predict_button = gr.Button("Answer")
            prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
            attribution_button = gr.Button("Generate Attribution")
            with gr.Row():
                attribution_image_1 = gr.Image(label="Attribution Image", interactive=False, height=400)
                attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False, height=400)

    # Wire the buttons to their handlers
    predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
    attribution_button.click(generate_attribution, inputs=[input_image, question_input], outputs=[attribution_image_1, attribution_image_2])
    clear_button.click(lambda: (None, "", ""), outputs=[input_image, question_input, prediction_output])

# Launch the Gradio app
demo.launch()