Sadjad Alikhani committed on
Commit
e69b52d
1 Parent(s): 23274c4

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +106 -96
inference.py CHANGED
@@ -1,96 +1,106 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- Created on Sun Sep 15 18:27:17 2024
4
-
5
- @author: salikha4
6
- """
7
-
8
- import os
9
- import csv
10
- import json
11
- import shutil
12
- import random
13
- import argparse
14
- from datetime import datetime
15
- import pandas as pd
16
- import time
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- from torch.utils.data import Dataset, DataLoader, TensorDataset
21
- from torch.optim import Adam
22
- import numpy as np
23
- #from lwm_model import LWM, load_model
24
- import warnings
25
- warnings.filterwarnings('ignore')
26
-
27
- # Device configuration
28
- device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
29
- if torch.cuda.is_available():
30
- torch.cuda.empty_cache()
31
-
32
-
33
- def lwm_inference(preprocessed_chs, input_type, lwm_model):
34
-
35
- dataset = prepare_for_LWM(preprocessed_chs, device)
36
-
37
- # Process data through LWM
38
- lwm_loss, embedding_data = evaluate(lwm_model, dataset)
39
- print(f'LWM loss: {lwm_loss:.4f}')
40
-
41
- if input_type == 'cls_emb':
42
- embedding_data = embedding_data[:, 0]
43
- elif input_type == 'channel_emb':
44
- embedding_data = embedding_data[:, 1:]
45
-
46
- dataset = embedding_data.float()
47
-
48
- return dataset
49
-
50
-
51
- def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
52
-
53
- input_ids, masked_tokens, masked_pos = zip(*data)
54
-
55
- input_ids_tensor = torch.tensor(input_ids, device=device).float()
56
- masked_tokens_tensor = torch.tensor(masked_tokens, device=device).float()
57
- masked_pos_tensor = torch.tensor(masked_pos, device=device).long()
58
-
59
- dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
60
-
61
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
62
-
63
-
64
- def evaluate(model, dataloader):
65
-
66
- model.eval()
67
- running_loss = 0.0
68
- outputs = []
69
- criterionMCM = nn.MSELoss()
70
-
71
- with torch.no_grad():
72
- for batch in dataloader:
73
- input_ids = batch[0]
74
- masked_tokens = batch[1]
75
- masked_pos = batch[2]
76
-
77
- logits_lm, output = model(input_ids, masked_pos)
78
-
79
- output_batch_preproc = output
80
- outputs.append(output_batch_preproc)
81
-
82
- loss_lm = criterionMCM(logits_lm, masked_tokens)
83
- loss = loss_lm/torch.var(masked_tokens)
84
- running_loss += loss.item()
85
-
86
- average_loss = running_loss / len(dataloader)
87
- output_total = torch.cat(outputs, dim=0)
88
-
89
- return average_loss, output_total
90
-
91
- def create_raw_dataset(data, device):
92
- """Create a dataset for raw channel data."""
93
- input_ids, _, _ = zip(*data)
94
- input_data = torch.tensor(input_ids, device=device)[:, 1:]
95
- return input_data.float()
96
-
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Sep 15 18:27:17 2024
4
+
5
+ @author: salikha4
6
+ """
7
+
8
+ import os
9
+ import csv
10
+ import json
11
+ import shutil
12
+ import random
13
+ import argparse
14
+ from datetime import datetime
15
+ import pandas as pd
16
+ import time
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.utils.data import Dataset, DataLoader, TensorDataset
21
+ from torch.optim import Adam
22
+ import numpy as np
23
+ import warnings
24
+ warnings.filterwarnings('ignore')
25
+
26
+ # Set random seeds for reproducibility across CPU and GPU
27
def set_random_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # The three CPU-side generators are independent, so seed them in one pass.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
36
+
37
# Seed all RNGs once, at import time, using the default seed.
set_random_seed()

# Device configuration: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
44
+
45
def lwm_inference(preprocessed_chs, input_type, lwm_model):
    """Run preprocessed channels through the LWM and return its embeddings.

    Args:
        preprocessed_chs: Samples forwarded to ``prepare_for_LWM``, which
            unpacks each one as ``(input_ids, masked_tokens, masked_pos)``.
        input_type: ``'cls_emb'`` keeps index 0 along dimension 1 of the
            model output; ``'channel_emb'`` keeps indices 1 onward. Any
            other value returns the output unsliced.
        lwm_model: Model forwarded to ``evaluate``.

    Returns:
        The selected embeddings as a float32 tensor.
    """
    loader = prepare_for_LWM(preprocessed_chs, device)

    # Process data through LWM; the loss is only reported, not returned.
    lwm_loss, embeddings = evaluate(lwm_model, loader)
    print(f'LWM loss: {lwm_loss:.4f}')

    if input_type == 'cls_emb':
        return embeddings[:, 0].float()
    if input_type == 'channel_emb':
        return embeddings[:, 1:].float()
    return embeddings.float()
61
+
62
def prepare_for_LWM(data, device, batch_size=64, shuffle=False):
    """Pack preprocessed samples into a ``DataLoader`` of tensors on ``device``.

    Args:
        data: Non-empty iterable of ``(input_ids, masked_tokens, masked_pos)``
            triples. Each field must have the same shape in every sample so
            the samples can be stacked along a new leading dimension.
        device: Device on which the stacked tensors are created.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the samples.

    Returns:
        A ``DataLoader`` whose batches are ``(input_ids, masked_tokens,
        masked_pos)`` with dtypes float32, float32 and int64 respectively.

    Raises:
        ValueError: If ``data`` is empty or its samples are not triples.
    """
    try:
        input_ids, masked_tokens, masked_pos = zip(*data)
    except ValueError as err:
        raise ValueError(
            "data must be a non-empty iterable of "
            "(input_ids, masked_tokens, masked_pos) triples"
        ) from err

    # Stack each field into a single ndarray before handing it to torch:
    # torch.tensor() on a sequence of ndarrays converts element by element,
    # which is very slow and triggers a PyTorch warning.
    input_ids_tensor = torch.tensor(np.asarray(input_ids), device=device).float()
    masked_tokens_tensor = torch.tensor(np.asarray(masked_tokens), device=device).float()
    masked_pos_tensor = torch.tensor(np.asarray(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)

    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
73
+
74
def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` without gradients.

    Each batch is indexed as ``(input_ids, masked_tokens, masked_pos)`` and
    the model is called as ``model(input_ids, masked_pos)``, returning a
    ``(logits_lm, output)`` pair.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        dataloader: Sized iterable of batches.

    Returns:
        A tuple ``(average_loss, output_total)``: the unweighted mean over
        batches of the MSE between ``logits_lm`` and ``masked_tokens``
        divided by the variance of ``masked_tokens``, and every batch's
        ``output`` concatenated along dimension 0.
    """
    model.eval()
    mse = nn.MSELoss()
    collected = []
    loss_sum = 0.0

    with torch.no_grad():
        for batch in dataloader:
            input_ids, masked_tokens, masked_pos = batch[0], batch[1], batch[2]

            logits_lm, output = model(input_ids, masked_pos)
            collected.append(output)

            # Use variance for normalization of the reconstruction error.
            normalized = mse(logits_lm, masked_tokens) / torch.var(masked_tokens)
            loss_sum += normalized.item()

    return loss_sum / len(dataloader), torch.cat(collected, dim=0)
100
+
101
def create_raw_dataset(data, device):
    """Create a dataset for raw channel data.

    Args:
        data: Non-empty iterable of three-field samples; only the first
            field (``input_ids``) is used. It must have the same shape in
            every sample so the samples can be stacked.
        device: Device on which the resulting tensor is created.

    Returns:
        A float32 tensor of the stacked ``input_ids`` with index 0 along
        dimension 1 dropped (presumably a leading special token -- confirm
        against the preprocessing code).

    Raises:
        ValueError: If ``data`` is empty or its samples are not triples.
    """
    try:
        input_ids, _, _ = zip(*data)
    except ValueError as err:
        raise ValueError(
            "data must be a non-empty iterable of "
            "(input_ids, masked_tokens, masked_pos) triples"
        ) from err

    # Stack into one ndarray first; torch.tensor() on a sequence of ndarrays
    # is very slow. A single float cast is enough.
    stacked = torch.tensor(np.asarray(input_ids), device=device).float()
    return stacked[:, 1:]
106
+