import os

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from transformers import BertConfig, BertLMHeadModel, BertTokenizerFast, ViTModel

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define constants
VIT_MODEL_NAME = "google/vit-base-patch16-224"
BERT_MODEL_NAME = "dbmdz/bert-base-turkish-cased"
MAX_LENGTH = 128


class ImageCaptioningModel(nn.Module):
    """ViT image encoder + Turkish BERT decoder joined by a linear projection."""

    def __init__(self, vit_model, bert_model):
        super().__init__()
        self.vit = vit_model
        self.bert = bert_model
        # Project ViT hidden states into the BERT hidden size so the decoder's
        # cross-attention layers can attend over the image features.
        self.linear = nn.Linear(self.vit.config.hidden_size, self.bert.config.hidden_size)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        image_features = self.vit(pixel_values).last_hidden_state
        image_features = self.linear(image_features)
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_features,
            labels=labels,
            return_dict=True,
        )
        # outputs.loss is None during inference, when no labels are given.
        return outputs.loss, outputs.logits


def load_model(model_path):
    # Initialize the model components
    vit_model = ViTModel.from_pretrained(VIT_MODEL_NAME)
    bert_config = BertConfig.from_pretrained(BERT_MODEL_NAME)
    bert_config.is_decoder = True
    bert_config.add_cross_attention = True
    bert_model = BertLMHeadModel.from_pretrained(BERT_MODEL_NAME, config=bert_config)

    # Create the combined model and restore the trained weights
    model = ImageCaptioningModel(vit_model, bert_model)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def generate_caption(model, image_path, tokenizer):
    # Prepare the image: resize to the ViT input size and normalize with
    # ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    # Greedy decoding: start from [CLS] and append the argmax token until
    # [SEP] is produced or MAX_LENGTH is reached.
    with torch.no_grad():
        input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)
        attention_mask = torch.ones_like(input_ids)

        for _ in range(MAX_LENGTH):
            _, logits = model(image, input_ids, attention_mask)
            next_token = logits[:, -1, :].argmax(dim=-1)
            if next_token.item() == tokenizer.sep_token_id:
                break
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            attention_mask = torch.cat(
                [attention_mask, torch.ones((1, 1), dtype=torch.long, device=device)], dim=1
            )

    caption = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return caption


def main():
    model_path = "./models/TeLVE_v1.1.pth"
    tokenizer_path = "./tokenizer"

    # Check that the model and tokenizer exist before loading
    if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
        print("Model or tokenizer not found. Please make sure you have trained the model and saved it correctly.")
        return

    # Load the model and tokenizer
    model = load_model(model_path)
    tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)

    # Generate captions for every image in the specified directory
    image_dir = "./images"  # Change this to the directory containing your test images
    for image_file in os.listdir(image_dir):
        if image_file.lower().endswith((".png", ".jpg", ".jpeg")):
            image_path = os.path.join(image_dir, image_file)
            caption = generate_caption(model, image_path, tokenizer)
            print(f"Image: {image_file}")
            print(f"Generated Caption: {caption}")
            print("---")


if __name__ == "__main__":
    main()
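

# Optional: a minimal sketch of top-k sampling as an alternative to the greedy
# argmax loop in generate_caption(). This variant is not part of the original
# pipeline; the function name and the k/temperature defaults are assumptions,
# so tune them for your model. Sampling trades the determinism of greedy
# decoding for more varied captions.
def generate_caption_topk(model, image_path, tokenizer, k=10, temperature=1.0):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)
        attention_mask = torch.ones_like(input_ids)
        for _ in range(MAX_LENGTH):
            _, logits = model(image, input_ids, attention_mask)
            # Keep only the k most likely next tokens, then sample among them.
            next_logits = logits[:, -1, :] / temperature
            topk_logits, topk_indices = next_logits.topk(k, dim=-1)
            probs = torch.softmax(topk_logits, dim=-1)
            next_token = topk_indices.gather(-1, torch.multinomial(probs, num_samples=1))
            if next_token.item() == tokenizer.sep_token_id:
                break
            input_ids = torch.cat([input_ids, next_token], dim=1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=1)

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)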