# TeLVE/imagine.py
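"""Generate Turkish image captions by pairing a ViT image encoder with a
Turkish BERT decoder (TeLVE). Loads a fine-tuned checkpoint and captions
every image in a directory."""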
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import ViTModel, BertTokenizerFast, BertConfig, BertLMHeadModel
from PIL import Image
import os

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define constants
VIT_MODEL_NAME = "google/vit-base-patch16-224"
BERT_MODEL_NAME = "dbmdz/bert-base-turkish-cased"
MAX_LENGTH = 128


class ImageCaptioningModel(nn.Module):
    """ViT image encoder + BERT decoder joined through cross-attention."""

    def __init__(self, vit_model, bert_model):
        super().__init__()
        self.vit = vit_model
        self.bert = bert_model
        # Project ViT hidden states into the BERT hidden size so they can be
        # consumed as cross-attention encoder states.
        self.linear = nn.Linear(self.vit.config.hidden_size, self.bert.config.hidden_size)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        # Encode the image and project its patch embeddings.
        image_features = self.vit(pixel_values).last_hidden_state
        image_features = self.linear(image_features)
        # Decode with BERT, attending to the projected image features.
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            encoder_hidden_states=image_features,
                            labels=labels,
                            return_dict=True)
        return outputs.loss, outputs.logits
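

# Note: when `labels` is None (as in generate_caption below), `outputs.loss`
# is None; the generation loop discards it and uses only the logits.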


def load_model(model_path):
    # Initialize the pretrained encoder and decoder components
    vit_model = ViTModel.from_pretrained(VIT_MODEL_NAME)
    bert_config = BertConfig.from_pretrained(BERT_MODEL_NAME)
    # BERT must run as a decoder with cross-attention to accept image features
    bert_config.is_decoder = True
    bert_config.add_cross_attention = True
    bert_model = BertLMHeadModel.from_pretrained(BERT_MODEL_NAME, config=bert_config)
    # Create the combined model and load the fine-tuned weights
    model = ImageCaptioningModel(vit_model, bert_model)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def generate_caption(model, image_path, tokenizer):
    # Preprocess the image to match ViT's expected input (224x224, ImageNet stats)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    # Generate the caption token by token with greedy decoding
    with torch.no_grad():
        input_ids = torch.tensor([[tokenizer.cls_token_id]]).to(device)
        attention_mask = torch.tensor([[1]]).to(device)
        for _ in range(MAX_LENGTH):
            _, logits = model(image, input_ids, attention_mask)
            # Pick the most likely next token; stop at the [SEP] end token
            next_token = logits[:, -1, :].argmax(dim=-1)
            if next_token.item() == tokenizer.sep_token_id:
                break
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            attention_mask = torch.cat([attention_mask, torch.tensor([[1]]).to(device)], dim=1)
    caption = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return caption
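

# The loop above uses plain greedy argmax decoding. A minimal sketch of a
# sampling alternative (hypothetical; `temperature` is an assumed parameter,
# not part of this script) would replace the argmax line with:
#
#     probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
#     next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)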


def main():
    model_path = "./models/TeLVE_v1.1.pth"
    tokenizer_path = "./tokenizer"
    # Check if the model and tokenizer exist
    if not os.path.exists(model_path) or not os.path.exists(tokenizer_path):
        print("Model or tokenizer not found. Please make sure you have trained the model and saved it correctly.")
        return
    # Load the model and tokenizer
    model = load_model(model_path)
    tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
    # Generate captions for images in a specified directory
    image_dir = "./images"  # Change this to the directory containing your test images
    for image_file in os.listdir(image_dir):
        if image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            image_path = os.path.join(image_dir, image_file)
            caption = generate_caption(model, image_path, tokenizer)
            print(f"Image: {image_file}")
            print(f"Generated Caption: {caption}")
            print("---")


if __name__ == "__main__":
    main()
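
# Expected layout (from the paths above): the fine-tuned checkpoint at
# ./models/TeLVE_v1.1.pth, the saved tokenizer files in ./tokenizer, and the
# test images in ./images. Run with:
#
#     python imagine.py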