import os
import torch
import torch.nn as nn
from torch.optim import AdamW  # transformers' AdamW is deprecated; use torch's implementation
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTModel, BertTokenizerFast, BertConfig, BertLMHeadModel
from PIL import Image, ImageFile
import pandas as pd
from tqdm import tqdm

# Increase the maximum image size limit to avoid DecompressionBombWarning
Image.MAX_IMAGE_PIXELS = None
# Allow loading truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define constants
VIT_MODEL_NAME = "google/vit-base-patch16-224"
BERT_MODEL_NAME = "dbmdz/bert-base-turkish-cased"  # Using a Turkish BERT model
MODEL_NAME = "TeLVE_v1.0.pth"
MAX_LENGTH = 128
BATCH_SIZE = 8
EPOCHS = 5
LEARNING_RATE = 2e-5


class ImageCaptioningDataset(Dataset):
    def __init__(self, dataframe, img_dir, tokenizer):
        self.dataframe = dataframe
        self.img_dir = img_dir
        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_path = os.path.join(self.img_dir, row['photo_id'] + ".jpg")
        try:
            image = Image.open(img_path).convert('RGB')
            image = self.transform(image)
        except (FileNotFoundError, IOError):
            # Skip the example if the image is missing or cannot be opened
            return None

        caption = row['ai_description']
        # Skip the example if the caption is not a valid string
        if not isinstance(caption, str):
            return None

        encoding = self.tokenizer(
            caption,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        # Use input_ids as labels, masking padding positions with -100 so the
        # cross-entropy loss ignores them
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'pixel_values': image,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class ImageCaptioningModel(nn.Module):
    def __init__(self, vit_model, bert_model):
        super().__init__()
        self.vit = vit_model
        self.bert = bert_model
        # Project ViT features into BERT's hidden size for cross-attention
        self.linear = nn.Linear(self.vit.config.hidden_size, self.bert.config.hidden_size)

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        image_features = self.vit(pixel_values).last_hidden_state
        image_features = self.linear(image_features)
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            encoder_hidden_states=image_features,
                            labels=labels,
                            return_dict=True)
        return outputs.loss, outputs.logits


def collate_fn(batch):
    # Filter out None values (skipped examples)
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0]}


def train_vlm_model():
    # Load the dataset, trying several encodings common for Turkish text
    encodings = ['utf-8', 'iso-8859-9', 'windows-1254']
    for encoding in encodings:
        try:
            df = pd.read_csv('./datasets/' + MODEL_NAME + '.tsv000', sep='\t', encoding=encoding)
            print(f"Successfully read the file with {encoding} encoding.")
            break
        except UnicodeDecodeError:
            print(f"Failed to read with {encoding} encoding. Trying next...")
    else:
        raise ValueError("Could not read the file with any of the specified encodings.")

    # Initialize the tokenizer
    tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL_NAME)

    # Create the dataset and dataloader
    dataset = ImageCaptioningDataset(df, '../download/images', tokenizer)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    # Initialize the model components: a ViT image encoder and a Turkish BERT
    # configured as a decoder with cross-attention over the image features
    vit_model = ViTModel.from_pretrained(VIT_MODEL_NAME)
    bert_config = BertConfig.from_pretrained(BERT_MODEL_NAME)
    bert_config.is_decoder = True
    bert_config.add_cross_attention = True
    bert_model = BertLMHeadModel.from_pretrained(BERT_MODEL_NAME, config=bert_config)

    # Create the combined model
    model = ImageCaptioningModel(vit_model, bert_model)
    model.to(device)

    # Define optimizer
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    # Training loop
    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch in progress_bar:
            if batch is None:
                continue

            pixel_values = batch['pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            loss, _ = model(pixel_values, input_ids, attention_mask, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

        print(f"Epoch {epoch+1}/{EPOCHS}, Average Loss: {total_loss/len(dataloader)}")

    # Save the model weights and tokenizer
    os.makedirs("./models", exist_ok=True)
    torch.save(model.state_dict(), "./models/" + MODEL_NAME)
    tokenizer.save_pretrained("./tokenizer")
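

# --- Inference sketch (not part of the original training pipeline) ---
# A minimal, hypothetical example of how the weights saved by
# train_vlm_model() could be used to caption a single image. The helper name
# `generate_caption` and the greedy decoding loop are illustrative
# assumptions; beam search or sampling would be reasonable alternatives.
def generate_caption(image_path, model_path="./models/" + MODEL_NAME):
    tokenizer = BertTokenizerFast.from_pretrained("./tokenizer")

    # Rebuild the same architecture used during training
    vit_model = ViTModel.from_pretrained(VIT_MODEL_NAME)
    bert_config = BertConfig.from_pretrained(BERT_MODEL_NAME)
    bert_config.is_decoder = True
    bert_config.add_cross_attention = True
    bert_model = BertLMHeadModel.from_pretrained(BERT_MODEL_NAME, config=bert_config)

    model = ImageCaptioningModel(vit_model, bert_model)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Same preprocessing as ImageCaptioningDataset
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    pixel_values = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        # Encode the image once and project into BERT's hidden size
        image_features = model.linear(model.vit(pixel_values).last_hidden_state)

        # Greedy decoding: start from [CLS], append the argmax token until
        # [SEP] or MAX_LENGTH is reached
        input_ids = torch.tensor([[tokenizer.cls_token_id]], device=device)
        for _ in range(MAX_LENGTH - 1):
            outputs = model.bert(input_ids=input_ids,
                                 encoder_hidden_states=image_features,
                                 return_dict=True)
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.sep_token_id:
                break

    return tokenizer.decode(input_ids[0], skip_special_tokens=True)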


if __name__ == "__main__":
    train_vlm_model()
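    # Example usage of the inference sketch above (the image path is illustrative):
    # print(generate_caption("../download/images/example_photo.jpg"))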