File size: 2,462 Bytes
e69b52d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b4f2449
e69b52d
fedc36d
e69b52d
 
 
 
 
 
 
 
 
 
 
 
fedc36d
e69b52d
 
 
b4f2449
 
e69b52d
 
 
 
 
 
 
 
 
 
 
 
 
 
5d9edcb
e69b52d
 
 
5d9edcb
e69b52d
 
 
 
 
 
b4f2449
e69b52d
 
 
 
 
 
 
 
 
 
b4f2449
e69b52d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 15 18:27:17 2024

@author: salikha4
"""

import os
import csv
import json
import shutil
import random
import argparse
from datetime import datetime
import pandas as pd
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import Adam
import numpy as np
import warnings
warnings.filterwarnings('ignore')

def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
    """Embed preprocessed channel data with the LWM model.

    Args:
        preprocessed_chs: iterable of (input_ids, masked_tokens, masked_pos)
            triples, as produced by the preprocessing pipeline.
        input_type: 'cls_emb' to keep only the CLS-token embedding,
            'channel_emb' to keep everything except the CLS token;
            any other value returns the full embedding sequence.
        lwm_model: the LWM model to run inference with.
        device: torch device the input tensors are placed on.

    Returns:
        Float tensor of embeddings, sliced according to ``input_type``.
    """
    loader = prepare_for_lwm(preprocessed_chs, device)

    # Forward pass through the LWM; the loss is printed for monitoring only.
    lwm_loss, embedding_data = evaluate(lwm_model, loader)
    print(f'LWM loss: {lwm_loss:.4f}')

    # Select which tokens of the embedding sequence to keep.
    if input_type == 'cls_emb':
        embedding_data = embedding_data[:, 0]    # CLS token only
    elif input_type == 'channel_emb':
        embedding_data = embedding_data[:, 1:]   # everything but the CLS token

    return embedding_data.float()

def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
    """Pack (input_ids, masked_tokens, masked_pos) triples into a DataLoader.

    Args:
        data: iterable of (input_ids, masked_tokens, masked_pos) triples,
            one per sample.
        device: torch device (or device string) the tensors are placed on.
        batch_size: number of samples per batch (default 64).
        shuffle: whether the DataLoader reshuffles every epoch (default False).

    Returns:
        DataLoader over a TensorDataset of (float32 ids, float32 tokens,
        int64 positions).
    """
    input_ids, masked_tokens, masked_pos = zip(*data)

    # Stack each field into a single ndarray before building the tensor:
    # torch.tensor() on a sequence of per-sample numpy arrays is extremely
    # slow and emits a UserWarning in recent PyTorch versions.
    input_ids_tensor = torch.as_tensor(np.asarray(input_ids), device=device).float()
    masked_tokens_tensor = torch.as_tensor(np.asarray(masked_tokens), device=device).float()
    masked_pos_tensor = torch.as_tensor(np.asarray(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def evaluate(model, dataloader):

    model.eval()
    running_loss = 0.0
    outputs = []
    criterionMCM = nn.MSELoss()
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            input_ids = batch[0]
            masked_tokens = batch[1]
            masked_pos = batch[2]

            logits_lm, output = model(input_ids, masked_pos)
            
            output_batch_preproc = output 
            outputs.append(output_batch_preproc)

            loss_lm = criterionMCM(logits_lm, masked_tokens)
            loss = loss_lm / torch.var(masked_tokens)  
            running_loss += loss.item()
            
    average_loss = running_loss / len(dataloader)
    output_total = torch.cat(outputs, dim=0)
    
    return average_loss, output_total

def create_raw_dataset(data, device):
    """Create a dataset of raw channel tensors (no LWM embedding).

    Args:
        data: iterable of (input_ids, masked_tokens, masked_pos) triples;
            only the input_ids are used.
        device: torch device (or device string) the tensor is placed on.

    Returns:
        Float tensor of the raw inputs with the first token removed
        (presumably the CLS token — mirrors the 'channel_emb' slicing
        in lwm_inference).
    """
    input_ids, _, _ = zip(*data)
    # Stack into one ndarray first: torch.tensor() on a sequence of
    # per-sample numpy arrays is extremely slow and emits a UserWarning
    # in recent PyTorch versions.
    input_data = torch.as_tensor(np.asarray(input_ids), device=device)[:, 1:]
    return input_data.float()