Sadjad Alikhani
commited on
Update inference.py
Browse files- inference.py +11 -11
inference.py
CHANGED
@@ -23,19 +23,18 @@ import numpy as np
|
|
23 |
import warnings
|
24 |
warnings.filterwarnings('ignore')
|
25 |
|
26 |
-
|
27 |
-
def set_random_seed(seed=42):
|
28 |
torch.manual_seed(seed)
|
29 |
np.random.seed(seed)
|
30 |
-
#random.seed(seed)
|
31 |
-
if torch.cuda.is_available():
|
32 |
-
torch.cuda.manual_seed_all(seed)
|
33 |
-
# Ensures deterministic behavior
|
34 |
-
torch.backends.cudnn.deterministic = True
|
35 |
-
torch.backends.cudnn.benchmark = False
|
36 |
|
37 |
-
#
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# Device configuration
|
41 |
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
@@ -45,7 +44,8 @@ if torch.cuda.is_available():
|
|
45 |
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
46 |
|
47 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
48 |
-
|
|
|
49 |
# Process data through LWM
|
50 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
51 |
print(f'LWM loss: {lwm_loss:.4f}')
|
|
|
23 |
import warnings
|
24 |
warnings.filterwarnings('ignore')
|
25 |
|
26 |
+
def set_seed(seed=42):
|
|
|
27 |
torch.manual_seed(seed)
|
28 |
np.random.seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
# Use this function at the start of your code
|
31 |
+
set_seed(42)
|
32 |
+
|
33 |
+
# Force model weights and data to float32 precision
|
34 |
+
def cast_model_weights_to_float32(model):
|
35 |
+
for param in model.parameters():
|
36 |
+
param.data = param.data.float() # Cast all weights to float32
|
37 |
+
return model
|
38 |
|
39 |
# Device configuration
|
40 |
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
|
|
44 |
def lwm_inference(preprocessed_chs, input_type, lwm_model):
|
45 |
|
46 |
dataset = prepare_for_LWM(preprocessed_chs, device)
|
47 |
+
|
48 |
+
lwm_model = cast_model_weights_to_float32(lwm_model)
|
49 |
# Process data through LWM
|
50 |
lwm_loss, embedding_data = evaluate(lwm_model, dataset)
|
51 |
print(f'LWM loss: {lwm_loss:.4f}')
|