# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

@author: salikha4
"""

import os
import csv
import json
import shutil
import random
import argparse
from datetime import datetime
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import Adam
import numpy as np
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility across CPU and GPU
def set_random_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Ensures deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply random seed
set_random_seed()

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.empty_cache()

def lwm_inference(preprocessed_chs, input_type, lwm_model):
    """Run the LWM model on preprocessed channels and return the requested embeddings.

    input_type selects the output: 'cls_emb' keeps only the CLS-token embedding,
    'channel_emb' keeps the per-patch channel embeddings (CLS token dropped);
    any other value returns the full embedding sequence.
    """
    dataset = prepare_for_LWM(preprocessed_chs, device)

    # Process data through LWM and report the masked-channel reconstruction loss
    lwm_loss, embedding_data = evaluate(lwm_model, dataset)
    print(f'LWM loss: {lwm_loss:.4f}')

    if input_type == 'cls_emb':
        # Keep only the CLS-token embedding (first position)
        embedding_data = embedding_data[:, 0]
    elif input_type == 'channel_emb':
        # Drop the CLS token, keep the channel (patch) embeddings
        embedding_data = embedding_data[:, 1:]

    dataset = embedding_data.float()

    return dataset
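
# Example call (sketch; commented out because it needs a trained LWM model and
# project-specific channel preprocessing, neither of which is defined in this file):
#
#   cls_embeddings = lwm_inference(preprocessed_chs, 'cls_emb', lwm_model)          # one CLS embedding per sample
#   channel_embeddings = lwm_inference(preprocessed_chs, 'channel_emb', lwm_model)  # per-patch embeddings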

def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
    """Pack (input_ids, masked_tokens, masked_pos) samples into a DataLoader."""
    input_ids, masked_tokens, masked_pos = zip(*data)

    # Stack once with NumPy, then cast: float32 for the channel inputs/targets,
    # int64 for the masked positions (used as indices inside the model)
    input_ids_tensor = torch.tensor(np.stack(input_ids), device=device).float()
    masked_tokens_tensor = torch.tensor(np.stack(masked_tokens), device=device).float()
    masked_pos_tensor = torch.tensor(np.stack(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate(model, dataloader):
    """Run the model in eval mode; return the average normalized MCM loss and the stacked outputs."""
    model.eval()
    running_loss = 0.0
    outputs = []
    criterionMCM = nn.MSELoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids, masked_tokens, masked_pos = batch

            logits_lm, output = model(input_ids, masked_pos)
            outputs.append(output)

            # MSE on the masked tokens, normalized by the variance of the targets
            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)
            running_loss += loss.item()

    average_loss = running_loss / len(dataloader)
    output_total = torch.cat(outputs, dim=0)

    return average_loss, output_total

def create_raw_dataset(data, device):
    """Create a dataset of raw channel data (float32, with the CLS token removed)."""
    input_ids, _, _ = zip(*data)
    # Cast to float32 and drop the first (CLS) token from every sequence
    input_data = torch.tensor(np.stack(input_ids), device=device).float()[:, 1:]
    return input_data
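
# Quick self-check with synthetic data (a sketch, not part of the original pipeline):
# it exercises prepare_for_LWM and create_raw_dataset using the assumed per-sample
# format (input_ids, masked_tokens, masked_pos). The shapes below are illustrative
# assumptions only; lwm_inference is not called here because it needs a trained LWM model.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo_samples = [
        (
            rng.standard_normal((33, 64)).astype(np.float32),  # input_ids: CLS token + channel patches
            rng.standard_normal((4, 64)).astype(np.float32),   # masked_tokens: targets for masked patches
            np.array([3, 7, 12, 20]),                          # masked_pos: indices of masked patches
        )
        for _ in range(8)
    ]
    loader = prepare_for_LWM(demo_samples, device, batch_size=4)
    raw = create_raw_dataset(demo_samples, device)
    print(f'batches: {len(loader)}, raw channel data shape: {tuple(raw.shape)}')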