import gradio as gr
import torch
from transformers import GPT2Model, ViTModel, GPT2Tokenizer, ViTImageProcessor
from captum.attr import IntegratedGradients
from PIL import Image
import numpy as np
import cv2


class MultiModalModel(torch.nn.Module):
    """Binary yes/no VQA classifier fusing GPT-2 text and ViT image features.

    The question is encoded by GPT-2 (last-token hidden state) and the image
    by ViT ([CLS] token hidden state); the two vectors are concatenated and
    fed through a single linear layer producing 2 logits (index 0 = "no",
    index 1 = "yes").
    """

    def __init__(self, gpt2_model_name="gpt2",
                 vit_model_name="google/vit-base-patch16-224-in21k"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(gpt2_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        self.classifier = torch.nn.Linear(
            self.gpt2.config.hidden_size + self.vit.config.hidden_size, 2
        )

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return (batch, 2) logits for the fused text + image input."""
        gpt2_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        # Last token summarizes the left-to-right GPT-2 encoding of the text.
        text_features = gpt2_outputs.last_hidden_state[:, -1, :]
        vit_outputs = self.vit(pixel_values=pixel_values)
        # [CLS] token (position 0) summarizes the image.
        image_features = vit_outputs.last_hidden_state[:, 0, :]
        fused_features = torch.cat((text_features, image_features), dim=1)
        return self.classifier(fused_features)


def load_model():
    """Download the fine-tuned weights from the Hugging Face Hub and return the model in eval mode."""
    model_name = "Muhusjf/ViT-GPT2-multimodal-model"
    model = MultiModalModel()
    model.load_state_dict(torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{model_name}/resolve/main/pytorch_model.bin",
        map_location=torch.device('cpu')
    ))
    model.eval()
    return model


def convert_tensor_to_pil(tensor_image):
    """Convert a CHW tensor/ndarray into a PIL RGB image.

    Float inputs are rescaled into [0, 255]. Values outside [0, 1] — e.g.
    mean/std-normalized ViT pixel values, which can be negative — are
    min-max rescaled first so they do not wrap around when cast to uint8
    (the original code produced garbage colors for negative pixels).
    """
    if isinstance(tensor_image, torch.Tensor):
        # detach() guards against tensors that still carry a grad history.
        tensor_image = tensor_image.detach().cpu().numpy()
    image_np = np.transpose(tensor_image, (1, 2, 0))  # CHW -> HWC
    if image_np.dtype != np.uint8:
        lo, hi = float(image_np.min()), float(image_np.max())
        if lo < 0.0 or hi > 1.0:
            # Normalized (possibly negative) values: min-max rescale to [0, 1].
            image_np = (image_np - lo) / (hi - lo + 1e-8)
        image_np = (image_np * 255).astype(np.uint8)
    return Image.fromarray(image_np)


# Model, tokenizer and image processor are created once at import time and
# shared by every Gradio request.
model = load_model()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")


def custom_forward(pixel_values, input_ids, attention_mask):
    """Forward wrapper for Captum: the image tensor comes first so it is the attributed input."""
    return model(input_ids=input_ids, attention_mask=attention_mask,
                 pixel_values=pixel_values)


integrated_gradients = IntegratedGradients(custom_forward)


def _prepare_inputs(image, text):
    """Tokenize the question and preprocess the image.

    Returns (input_ids, attention_mask, pixel_values) as batched tensors,
    matching the model's forward() signature order.
    """
    image_features = feature_extractor(images=image, return_tensors="pt")
    inputs = tokenizer.encode_plus(
        f"Question: {text} Answer:",
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    return (
        inputs["input_ids"].long(),
        inputs["attention_mask"].long(),
        image_features["pixel_values"],
    )


def predict_text(image, text):
    """Answer a yes/no question about *image*; returns the string "yes" or "no"."""
    input_ids, attention_mask, pixel_values = _prepare_inputs(image, text)
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()
    return "yes" if prediction == 1 else "no"


def generate_attribution(image, text):
    """Visualize which image regions drove the model's answer.

    Runs Integrated Gradients on the image pixels with respect to the
    predicted class, and returns a pair of PIL images:
    (attribution heat map, original image with contours drawn around
    high-attribution regions).
    """
    input_ids, attention_mask, pixel_values = _prepare_inputs(image, text)

    # Pick the target class first (no gradients needed for this pass).
    with torch.no_grad():
        logits = model(input_ids, attention_mask, pixel_values)
    prediction = torch.argmax(logits, dim=1).item()

    # NOTE(review): n_steps=1 is a speed shortcut — the IG path integral is
    # very coarse at a single step; raise it if attribution quality matters.
    attributions, _ = integrated_gradients.attribute(
        inputs=pixel_values,
        target=prediction,
        additional_forward_args=(input_ids, attention_mask),
        n_steps=1,
        return_convergence_delta=True,
    )

    attribution_image = attributions.squeeze().cpu().numpy()
    # Min-max normalize; epsilon avoids division by zero when the
    # attribution map is constant.
    lo, hi = attribution_image.min(), attribution_image.max()
    attribution_image = (attribution_image - lo) / (hi - lo + 1e-8)
    attribution_image = np.uint8(255 * attribution_image)
    attribution_image_real = convert_tensor_to_pil(attribution_image)

    # Threshold the attribution map and outline the salient regions.
    attribution_gray = cv2.cvtColor(np.array(attribution_image_real), cv2.COLOR_RGB2GRAY)
    _, binary_mask = cv2.threshold(attribution_gray, 128, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    original_image = convert_tensor_to_pil(pixel_values.squeeze(0))
    original_image_np = np.array(original_image)
    cv2.drawContours(original_image_np, contours, -1, (255, 0, 0), 2)

    return attribution_image_real, Image.fromarray(original_image_np)


# Build the Gradio interface.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", interactive=True, height=400)
            question_input = gr.Textbox(label="Question", lines=3, max_lines=3)
            clear_button = gr.Button("Clear")
        with gr.Column():
            predict_button = gr.Button("Answer")
            prediction_output = gr.Textbox(label="Answer", lines=2, interactive=False)
            attribution_button = gr.Button("Generate Attribution")
    with gr.Row():
        attribution_image_1 = gr.Image(label="Attribution Image",
                                       interactive=False, height=400)
        attribution_image_2 = gr.Image(label="Attribution with Contours",
                                       interactive=False, height=400)

    # Wire up the button events.
    predict_button.click(predict_text,
                         inputs=[input_image, question_input],
                         outputs=prediction_output)
    attribution_button.click(generate_attribution,
                             inputs=[input_image, question_input],
                             outputs=[attribution_image_1, attribution_image_2])
    clear_button.click(lambda: (None, "", ""),
                       outputs=[input_image, question_input, prediction_output])

# Launched at import time, matching the original script's behavior
# (e.g. when run directly as a Hugging Face Space entry point).
demo.launch()