# XAI-Medical / app.py
import gradio as gr
import torch
from transformers import GPT2Model, ViTModel, GPT2Tokenizer, ViTImageProcessor
from captum.attr import IntegratedGradients
from PIL import Image
import numpy as np
import cv2
# Define the multimodal model: GPT-2 encodes the question, ViT encodes the image,
# and a linear head classifies the fused features into yes/no.
class MultiModalModel(torch.nn.Module):
    def __init__(self, gpt2_model_name="gpt2", vit_model_name="google/vit-base-patch16-224-in21k"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.classifier = torch.nn.Linear(self.gpt2.config.hidden_size + self.vit.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask, pixel_values):
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        text_features = gpt2_outputs.last_hidden_state[:, -1, :]  # last-token hidden state
        vit_outputs = self.vit(pixel_values=pixel_values)
        image_features = vit_outputs.last_hidden_state[:, 0, :]  # [CLS] token
        fused_features = torch.cat((text_features, image_features), dim=1)
        logits = self.classifier(fused_features)
        return logits
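# Shape sketch (assuming the defaults above): input_ids and attention_mask are
# (B, 128) after tokenization, pixel_values is (B, 3, 224, 224); both encoders
# use hidden size 768, so fused_features is (B, 1536) and logits are (B, 2).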
# Load the model weights from the Hugging Face Hub
def load_model():
    model_name = "Muhusjf/ViT-GPT2-multimodal-model"
    model = MultiModalModel()
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
        map_location=torch.device('cpu')
    ))
    model.eval()
    return model
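# Note: load_state_dict_from_url caches the checkpoint locally under the torch
# hub cache dir, and load_state_dict requires the checkpoint's keys to match
# MultiModalModel exactly.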
# Convert a (C, H, W) tensor or array to a PIL image. Float inputs are assumed
# to use the ViT processor's default normalization (mean 0.5, std 0.5, i.e. [-1, 1]).
def convert_tensor_to_pil(tensor_image):
    if isinstance(tensor_image, torch.Tensor):
        tensor_image = tensor_image.numpy()
    image_np = np.transpose(tensor_image, (1, 2, 0))
    if image_np.dtype != np.uint8:
        # Undo the [-1, 1] normalization and clip so negative values don't wrap
        # around when cast to uint8
        image_np = np.clip((image_np + 1.0) / 2.0, 0.0, 1.0)
        image_np = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_np)
# Initialize the model, tokenizer, and image processor
model = load_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

# Custom forward function for Integrated Gradients: Captum attributes with respect
# to the first argument (pixel_values); the text tensors ride along unchanged.
def custom_forward(pixel_values, input_ids, attention_mask):
    logits = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
    return logits

# Initialize Integrated Gradients
integrated_gradients = IntegratedGradients(custom_forward)
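# Integrated Gradients approximates
#   IG_i(x) = (x_i - x'_i) * integral_0^1 dF(x' + a*(x - x'))/dx_i da
# with an n_steps-point quadrature along the straight path from the baseline x'
# (all zeros by default) to the input x.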
# Inference function: encode the image and question, then classify yes/no
def predict_text(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    label = "yes" if prediction == 1 else "no"
    return label
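# Example call (hypothetical filename): predict_text(Image.open("scan.png"),
# "Is there an abnormality?") returns "yes" or "no".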
# Attribution analysis function: run Integrated Gradients on the image and
# outline the most salient regions on the original input
def generate_attribution(image, text):
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length"
    )
    input_ids = inputs["input_ids"].long()
    attention_mask = inputs["attention_mask"].long()
    pixel_values = image_features["pixel_values"]
    # First get the predicted class; attributions are computed for that target
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    # n_steps=1 is a very coarse single-step approximation, presumably chosen for
    # CPU speed; increase n_steps for more faithful attributions
    attributions, _ = integrated_gradients.attribute(
        inputs=pixel_values,
        target=prediction,
        additional_forward_args=(input_ids, attention_mask),
        n_steps=1,
        return_convergence_delta=True
    )
    # Min-max normalize the attribution map to [0, 255] for visualization
    attribution_image = attributions.squeeze().cpu().numpy()
    attribution_image = (attribution_image - attribution_image.min()) / (attribution_image.max() - attribution_image.min())
    attribution_image = np.uint8(255 * attribution_image)
    attribution_image_real = convert_tensor_to_pil(attribution_image)
    # Threshold the attribution map and draw the resulting contours in red on the input
    attribution_gray = cv2.cvtColor(np.array(attribution_image_real), cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(attribution_gray, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    original_image = convert_tensor_to_pil(pixel_values.squeeze(0).numpy())
    original_image_np = np.array(original_image)
    cv2.drawContours(original_image_np, contours, -1, (255, 0, 0), 2)
    return attribution_image_real, Image.fromarray(original_image_np)
# Build the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", interactive=True, height=400)
            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
            clear_button = gr.Button("Clear")
        with gr.Column():
            predict_button = gr.Button("Answer")
            prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
            attribution_button = gr.Button("Generate Attribution")
    with gr.Row():
        attribution_image_1 = gr.Image(label="Attribution Image", interactive=False, height=400)
        attribution_image_2 = gr.Image(label="Attribution with Contours", interactive=False, height=400)

    # Bind button events
    predict_button.click(predict_text, inputs=[input_image, question_input], outputs=prediction_output)
    attribution_button.click(generate_attribution, inputs=[input_image, question_input], outputs=[attribution_image_1, attribution_image_2])
    clear_button.click(lambda: (None, "", ""), outputs=[input_image, question_input, prediction_output])

# Launch the Gradio interface
demo.launch()
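# Optional: demo.launch(share=True) would create a temporary public link for
# access beyond localhost (a standard Gradio launch option).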