# lwm / inference.py
# Sadjad Alikhani
# Update inference.py
# 21ca5ef verified
# raw
# history blame
# 3.17 kB
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024
@author: salikha4
"""
import os
import csv
import json
import shutil
import random
import argparse
from datetime import datetime
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import Adam
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# Set random seeds for reproducibility across CPU and GPU
def set_random_seed(seed=42):
    """Seed the PyTorch, NumPy and Python RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed to every random number generator (default 42).
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        # Seed the generators of all visible GPUs as well.
        torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning in exchange for repeatable kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# Seed every RNG once, at import time
set_random_seed()
# Device configuration: use CUDA when it is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device("cpu")
if device.type == 'cuda':
    # Hand unused cached GPU memory back before inference starts
    torch.cuda.empty_cache()
def lwm_inference(preprocessed_chs, input_type, lwm_model):
    """Run the LWM model over preprocessed channels and return its embeddings.

    The samples are batched on the module-level ``device``, evaluated with
    ``evaluate`` and the resulting loss is printed.

    Args:
        preprocessed_chs: Samples accepted by ``prepare_for_LWM``.
        input_type: ``'cls_emb'`` keeps index 0 along dimension 1 of the model
            output, ``'channel_emb'`` keeps indices 1 onward; any other value
            leaves the output unsliced.
        lwm_model: Model passed to ``evaluate``.

    Returns:
        The selected embeddings as a float32 tensor.
    """
    loader = prepare_for_LWM(preprocessed_chs, device)
    # Forward pass over every batch; also yields the normalised loss
    lwm_loss, embeddings = evaluate(lwm_model, loader)
    print(f'LWM loss: {lwm_loss:.4f}')

    if input_type == 'cls_emb':
        selected = embeddings[:, 0]
    elif input_type == 'channel_emb':
        selected = embeddings[:, 1:]
    else:
        selected = embeddings

    result = selected.float()
    # Preview the start of the first sample's embedding
    print(result[0][:4])
    return result
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
    """Pack preprocessed samples into a ``DataLoader`` of tensors on ``device``.

    Args:
        data: Iterable of ``(input_ids, masked_tokens, masked_pos)`` samples.
            Each field must have the same shape in every sample so it can be
            stacked along a new leading dimension.
        device: Device on which the three tensors are created.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples on every pass.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` whose batches are
        ``(input_ids, masked_tokens, masked_pos)`` with dtypes float32,
        float32 and int64 respectively.

    Raises:
        ValueError: If ``data`` is empty or the samples do not have exactly
            three fields.
    """
    columns = list(zip(*data))
    if not columns:
        raise ValueError("prepare_for_LWM received no samples")
    # Stacking ndarray samples first lets torch copy a single contiguous
    # buffer instead of walking a Python sequence of arrays item by item.
    input_ids, masked_tokens, masked_pos = (
        np.stack(column) if isinstance(column[0], np.ndarray) else column
        for column in columns
    )
    input_ids_tensor = torch.tensor(input_ids, device=device).float()
    masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
    masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` without gradients and gather its outputs.

    Args:
        model: Module called as ``model(input_ids, masked_pos)`` and returning
            ``(logits_lm, output)``. It is switched to eval mode first.
        dataloader: Iterable of batches whose first three items are
            ``input_ids``, ``masked_tokens`` and ``masked_pos``. It only has
            to be iterable; ``len()`` support is not required.

    Returns:
        Tuple ``(average_loss, output_total)``. ``average_loss`` is the mean
        over batches of the MSE between ``logits_lm`` and ``masked_tokens``
        divided by the variance of that batch's ``masked_tokens``.
        ``output_total`` is every batch's ``output`` concatenated along
        dimension 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    criterion = nn.MSELoss()
    running_loss = 0.0
    num_batches = 0
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids, masked_tokens, masked_pos = batch[0], batch[1], batch[2]
            logits_lm, output = model(input_ids, masked_pos)
            outputs.append(output)
            # Use variance for normalization of the reconstruction error
            loss = criterion(logits_lm, masked_tokens) / torch.var(masked_tokens)
            running_loss += loss.item()
            # Count batches here instead of calling len() on the loader
            num_batches += 1
    if num_batches == 0:
        raise ValueError("evaluate received a dataloader that yields no batches")
    return running_loss / num_batches, torch.cat(outputs, dim=0)
def create_raw_dataset(data, device):
    """Create a dataset for raw channel data.

    Args:
        data: Iterable of three-field samples; only the first field
            (``input_ids``) is used. Every sample's ``input_ids`` must have
            the same shape.
        device: Device on which the tensor is created.

    Returns:
        A float32 tensor of the stacked ``input_ids`` with index 0 along
        dimension 1 removed (presumably the CLS position; confirm against
        the preprocessing code).

    Raises:
        ValueError: If ``data`` is empty or the samples do not have exactly
            three fields.
    """
    columns = list(zip(*data))
    if not columns:
        raise ValueError("create_raw_dataset received no samples")
    input_ids, _, _ = columns
    if isinstance(input_ids[0], np.ndarray):
        # One contiguous array converts far faster than a tuple of ndarrays.
        input_ids = np.stack(input_ids)
    # .float() already yields float32, so no further cast is needed.
    return torch.tensor(input_ids, device=device).float()[:, 1:]