crimeacs committed
Commit 8646273 · 1 Parent(s): e5e1367

Initial commit

.DS_Store ADDED
Binary file (6.15 kB).
 
app.py ADDED
@@ -0,0 +1,164 @@
+ # Gradio app that takes a seismic waveform as input and marks two phases (P and S) on the waveform as output.
+
+ import gradio as gr
+ import numpy as np
+ import matplotlib.pyplot as plt  # used by mark_phases below; missing from the original imports
+ from phasehunter.model import Onset_picker, Updated_onset_picker
+ from phasehunter.data_preparation import prepare_waveform
+ import torch
+ from scipy.stats import gaussian_kde  # scipy.stats.kde is a deprecated import path
+
+ import obspy
+ from obspy.clients.fdsn import Client
+ from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
+ from obspy.geodetics.base import locations2degrees
+ from obspy.taup import TauPyModel
+ from obspy.taup.helper_classes import SlownessModelError
+
+ from obspy.clients.fdsn.header import URL_MAPPINGS
+
+ def make_prediction(waveform):
+     waveform = np.load(waveform)
+     processed_input = prepare_waveform(waveform)
+
+     # Make prediction
+     with torch.no_grad():
+         output = model(processed_input)
+
+     p_phase = output[:, 0]
+     s_phase = output[:, 1]
+
+     return processed_input, p_phase, s_phase
+
+ def mark_phases(waveform):
+     processed_input, p_phase, s_phase = make_prediction(waveform)
+
+     # Create a plot of the waveform with the phases marked
+     if bool((processed_input[0][2] == 0).all()):  # input is 1C: prepare_waveform pads the N/E channels with zeros
+         fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)
+
+         ax[0].plot(processed_input[0][0])
+         ax[0].set_ylabel('Norm. Ampl.')
+
+     else:  # input is 3C
+         fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
+         ax[0].plot(processed_input[0][0])
+         ax[1].plot(processed_input[0][1])
+         ax[2].plot(processed_input[0][2])
+
+         ax[0].set_ylabel('Z')
+         ax[1].set_ylabel('N')
+         ax[2].set_ylabel('E')
+
+     # KDE of the per-mask picks visualizes the picking uncertainty
+     p_phase_plot = p_phase * processed_input.shape[-1]
+     p_kde = gaussian_kde(p_phase_plot)
+     p_dist_space = np.linspace(min(p_phase_plot) - 10, max(p_phase_plot) + 10, 500)
+     ax[-1].plot(p_dist_space, p_kde(p_dist_space), color='r')
+
+     s_phase_plot = s_phase * processed_input.shape[-1]
+     s_kde = gaussian_kde(s_phase_plot)
+     s_dist_space = np.linspace(min(s_phase_plot) - 10, max(s_phase_plot) + 10, 500)
+     ax[-1].plot(s_dist_space, s_kde(s_dist_space), color='b')
+
+     for a in ax:
+         a.axvline(p_phase.mean() * processed_input.shape[-1], color='r', linestyle='--', label='P')
+         a.axvline(s_phase.mean() * processed_input.shape[-1], color='b', linestyle='--', label='S')
+
+     ax[-1].set_xlabel('Time, samples')
+     ax[-1].set_ylabel('Uncert.')
+     ax[-1].legend()
+
+     plt.subplots_adjust(hspace=0., wspace=0.)
+
+     # Convert the plot to an image and return it
+     fig.canvas.draw()
+     image = np.array(fig.canvas.renderer.buffer_rgba())
+     plt.close(fig)
+     return image
+
+ def download_data(timestamp, eq_lat, eq_lon, client_name, radius_km):
+     client = Client(client_name)
+     window = radius_km / 111.2  # km to degrees (~111.2 km per degree)
+
+     assert eq_lat - window > -90 and eq_lat + window < 90, "Latitude out of bounds"
+     assert eq_lon - window > -180 and eq_lon + window < 180, "Longitude out of bounds"
+
+     # starttime = catalog['DateTime'].apply(lambda x: pd.Timestamp(x)).min()
+     # endtime = catalog['DateTime'].apply(lambda x: pd.Timestamp(x)).max()
+
+     return 0  # placeholder: the catalogue download is not implemented yet
+ model = Onset_picker.load_from_checkpoint("./weights.ckpt",
+                                           picker=Updated_onset_picker(),
+                                           learning_rate=3e-4)
+ model.eval()
+
+ # # Create the Gradio interface
+ # gr.Interface(mark_phases, inputs, outputs, title='PhaseHunter').launch()
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# PhaseHunter")
+     with gr.Tab("Default example"):
+         # Define the input and output types for Gradio
+         inputs = gr.Dropdown(
+             ["data/sample/sample_0.npy",
+              "data/sample/sample_1.npy",
+              "data/sample/sample_2.npy"],
+             label="Sample waveform",
+             info="Select one of the samples",
+             value="data/sample/sample_0.npy"
+         )
+
+         button = gr.Button("Predict phases")
+         outputs = gr.Image(label='Waveform with Phases Marked', type='numpy')  # gr.outputs.Image is deprecated in Gradio 3+
+
+         button.click(mark_phases, inputs=inputs, outputs=outputs)
+
+     with gr.Tab("Select earthquake from catalogue"):
+         gr.Markdown('TEST')
+
+         client_inputs = gr.Dropdown(
+             choices=list(URL_MAPPINGS.keys()),
+             label="FDSN Client",
+             info="Select one of the available FDSN clients",
+             value="IRIS",
+             interactive=True
+         )
+         with gr.Row():
+
+             timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
+                                           placeholder='YYYY-MM-DD HH:MM:SS',
+                                           label="Timestamp",
+                                           info="Timestamp of the earthquake",
+                                           max_lines=1,
+                                           interactive=True)
+
+             eq_lat_inputs = gr.Number(value=35.766,
+                                       label="Latitude",
+                                       info="Latitude of the earthquake",
+                                       interactive=True)
+
+             eq_lon_inputs = gr.Number(value=-117.605,  # 2019 Ridgecrest event; western longitudes are negative
+                                       label="Longitude",
+                                       info="Longitude of the earthquake",
+                                       interactive=True)
+
+             radius_inputs = gr.Slider(minimum=1,
+                                       maximum=150,
+                                       value=50,
+                                       label="Radius (km)",
+                                       info="Select the radius around the earthquake to download data from",
+                                       interactive=True)
+
+         button = gr.Button("Predict phases")
+
+     with gr.Tab("Predict on your own waveform"):
+         gr.Markdown("""
+         Please upload your waveform in .npy (numpy) format.
+         Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
+         """)
+
+         # NOTE: this reuses `button` from the catalogue tab and `inputs`/`outputs`
+         # from the first tab; the upload tab has no file component of its own yet.
+         button.click(mark_phases, inputs=inputs, outputs=outputs)
+
+ demo.launch()
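
A minimal sketch of how an input file for the upload tab could be produced with ObsPy (the stream and output name are illustrative; any 100 sps recording saved as a (channels, samples) array works the same way):

    # make_input.py -- hypothetical helper, not part of this commit
    import numpy as np
    import obspy

    st = obspy.read()                        # ObsPy's bundled 3-channel example stream (100 sps)
    data = np.stack([tr.data for tr in st])  # shape (3, n_samples), Z/N/E order
    np.save("my_waveform.npy", data.astype(np.float32))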
data/.DS_Store ADDED
Binary file (8.2 kB).
 
data/sample/.DS_Store ADDED
Binary file (6.15 kB).
 
data/sample/sample_0.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bcacc216bea4273debc0244dce60def2413e4642100029fa0c0ce83416ba71c8
+ size 144128
data/sample/sample_1.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:245f7bd592643f8573e8051aa593bdcdffa7e497de3e56aa76423cd96e44ae03
+ size 113776
data/sample/sample_2.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3eb64a6b01caba32464d950aafcbc4ed32f724f062c9c91b1c7f36671ce2b8e
+ size 136344
data_preparation.py ADDED
@@ -0,0 +1,210 @@
+ import torch
+ import numpy as np
+
+ from scipy import signal
+ from scipy.signal import butter, lfilter, detrend
+
+ # Make bandpass filter
+ def butter_bandpass(lowcut, highcut, fs, order=5):
+     nyq = 0.5 * fs  # Nyquist frequency
+     low = lowcut / nyq  # Normalized frequency
+     high = highcut / nyq
+     b, a = butter(order, [low, high], btype="band")  # Bandpass filter
+     return b, a
+
+
+ def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
+     b, a = butter_bandpass(lowcut, highcut, fs, order=order)
+     y = lfilter(b, a, data)
+     return y
+
+
+ def rotate_waveform(waveform, angle):
+     fft_waveform = np.fft.fft(waveform)  # Compute the Fourier transform of the waveform
+     rotate_factor = np.exp(1j * angle)  # Complex exponential with the specified rotation angle
+     rotated_fft_waveform = fft_waveform * rotate_factor  # Multiply the Fourier transform by the rotation factor
+     rotated_waveform = np.fft.ifft(rotated_fft_waveform)  # Inverse transform back to the time domain
+
+     return rotated_waveform
+
+
+ def augment(sample):
+     # SET PARAMETERS:
+     crop_length = 6000
+     padding = 120
+     test = False
+
+     waveform = sample["waveform.npy"]
+     meta = sample["meta.json"]
+
+     if meta["split"] != "train":
+         test = True
+
+     target_sample_P = meta["trace_p_arrival_sample"]
+     target_sample_S = meta["trace_s_arrival_sample"]
+
+     if target_sample_P is None:
+         target_sample_P = 0
+     if target_sample_S is None:
+         target_sample_S = 0
+
+     # Randomly select a phase to start the crop
+     current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
+     phase_selector = np.random.randint(0, len(current_phases))
+     first_phase = current_phases[phase_selector]
+
+     # Shuffle the crop window around the selected phase
+     if first_phase - (crop_length - padding) > padding:
+         start_indx = int(
+             first_phase
+             - torch.randint(low=padding, high=(crop_length - padding), size=(1,))
+         )
+         if test:
+             start_indx = int(first_phase - 2 * padding)
+
+     elif int(first_phase - padding) > 0:
+         start_indx = int(
+             first_phase
+             - torch.randint(low=0, high=(int(first_phase - padding)), size=(1,))
+         )
+         if test:
+             start_indx = int(first_phase - padding)
+
+     else:
+         start_indx = padding
+
+     end_indx = start_indx + crop_length
+
+     if (waveform.shape[-1] - end_indx) < 0:
+         start_indx += waveform.shape[-1] - end_indx
+         end_indx = start_indx + crop_length
+
+     # Update target
+     new_target_P = target_sample_P - start_indx
+     new_target_S = target_sample_S - start_indx
+
+     # Cut
+     waveform_cropped = waveform[:, start_indx:end_indx]
+
+     # Preprocess
+     waveform_cropped = detrend(waveform_cropped)
+     waveform_cropped = butter_bandpass_filter(
+         waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
+     )
+     window = signal.windows.tukey(waveform_cropped[-1].shape[0], alpha=0.1)
+     waveform_cropped = waveform_cropped * window
+     waveform_cropped = detrend(waveform_cropped)
+
+     if np.isnan(waveform_cropped).any():
+         waveform_cropped = np.zeros(shape=waveform_cropped.shape)
+         new_target_P = 0
+         new_target_S = 0
+
+     if np.sum(waveform_cropped) == 0:
+         new_target_P = 0
+         new_target_S = 0
+
+     # Normalize data
+     max_val = np.max(np.abs(waveform_cropped))
+     if max_val == 0:  # all-zero trace (e.g. the NaN fallback above); avoid division by zero
+         max_val = 1.0
+     waveform_cropped_norm = waveform_cropped / max_val
+
+     # If only the Z component is present, pad the N/E channels with zeros
+     if len(waveform_cropped_norm) < 3:
+         zeros = np.zeros((3, waveform_cropped_norm.shape[-1]))
+         zeros[0] = waveform_cropped_norm
+         waveform_cropped_norm = zeros
+
+     if not test:
+         ##### Rotate waveform #####
+         probability = torch.randint(0, 2, size=(1,)).item()
+         angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
+         if probability == 1:
+             waveform_cropped_norm = rotate_waveform(waveform_cropped_norm, angle).real
+
+         #### Channel DropOUT #####
+         probability = torch.randint(0, 2, size=(1,)).item()
+         channel = torch.randint(1, 3, size=(1,)).item()
+         if probability == 1:
+             waveform_cropped_norm[channel, :] = 1e-6
+
+     # Normalize target to a fraction of the crop window
+     new_target_P = new_target_P / crop_length
+     new_target_S = new_target_S / crop_length
+
+     if (new_target_P <= 0) or (new_target_P >= 1) or (np.isnan(new_target_P)):
+         new_target_P = 0
+     if (new_target_S <= 0) or (new_target_S >= 1) or (np.isnan(new_target_S)):
+         new_target_S = 0
+
+     return waveform_cropped_norm, new_target_P, new_target_S
+
+
+ def collation_fn(sample):
+     waveforms = np.stack([x[0] for x in sample])
+     targets_P = np.stack([x[1] for x in sample])
+     targets_S = np.stack([x[2] for x in sample])
+
+     return (
+         torch.tensor(waveforms, dtype=torch.float),
+         torch.tensor(targets_P, dtype=torch.float),
+         torch.tensor(targets_S, dtype=torch.float),
+     )
+
+
+ def my_split_by_node(urls):
+     node_id, node_count = (
+         torch.distributed.get_rank(),
+         torch.distributed.get_world_size(),
+     )
+     return list(urls)[node_id::node_count]
+
+
+ def prepare_waveform(waveform):
+     # SET PARAMETERS:
+     crop_length = 6000
+     padding = 120
+
+     assert waveform.shape[0] <= 3, "Waveform has more than 3 channels"
+
+     # Pad or crop to exactly crop_length samples
+     if waveform.shape[-1] < crop_length:
+         waveform = np.pad(
+             waveform,
+             ((0, 0), (0, crop_length - waveform.shape[-1])),
+             mode="constant",
+             constant_values=0,
+         )
+     if waveform.shape[-1] > crop_length:
+         waveform = waveform[:, :crop_length]
+
+     # Preprocess
+     waveform = detrend(waveform)
+     waveform = butter_bandpass_filter(
+         waveform, lowcut=0.2, highcut=40, fs=100, order=5
+     )
+     window = signal.windows.tukey(waveform[-1].shape[0], alpha=0.1)
+     waveform = waveform * window
+     waveform = detrend(waveform)
+
+     assert not np.isnan(waveform).any(), "NaN in waveform"
+     assert np.sum(waveform) != 0, "Sum of waveform sample is zero"
+
+     # Normalize data
+     max_val = np.max(np.abs(waveform))
+     waveform = waveform / max_val
+
+     # If only the Z component is present, pad the N/E channels with zeros
+     if len(waveform) < 3:
+         zeros = np.zeros((3, waveform.shape[-1]))
+         zeros[0] = waveform
+         waveform = zeros
+
+     # Repeat the waveform 128 times to fill the masksembles ensemble batch
+     return torch.tensor([waveform] * 128, dtype=torch.float)
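
A quick sanity check of the preprocessing above (a sketch; assumes this file is importable as data_preparation and substitutes synthetic noise for a real recording):

    import numpy as np
    from data_preparation import butter_bandpass_filter, prepare_waveform

    wave = np.random.randn(1, 6000)  # synthetic 1-component trace, 60 s at 100 sps

    # The 0.2-40 Hz band-pass applied inside prepare_waveform
    filtered = butter_bandpass_filter(wave, lowcut=0.2, highcut=40, fs=100, order=5)
    assert filtered.shape == wave.shape

    batch = prepare_waveform(wave)
    print(batch.shape)  # torch.Size([128, 3, 6000]): padded to 3C, repeated for the 128 masks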
model.py ADDED
@@ -0,0 +1,313 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torchmetrics import MeanAbsoluteError
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+
+ import lightning as pl
+
+ from masksembles import common  # moved up from mid-file to sit with the other imports
+
+
+ class BlurPool1D(nn.Module):
+     def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
+         super(BlurPool1D, self).__init__()
+         self.filt_size = filt_size
+         self.pad_off = pad_off
+         self.pad_sizes = [
+             int(1.0 * (filt_size - 1) / 2),
+             int(np.ceil(1.0 * (filt_size - 1) / 2)),
+         ]
+         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+         self.stride = stride
+         self.off = int((self.stride - 1) / 2.0)
+         self.channels = channels
+
+         # Binomial (Pascal's triangle) low-pass filter coefficients
+         if self.filt_size == 1:
+             a = np.array([1.0])
+         elif self.filt_size == 2:
+             a = np.array([1.0, 1.0])
+         elif self.filt_size == 3:
+             a = np.array([1.0, 2.0, 1.0])
+         elif self.filt_size == 4:
+             a = np.array([1.0, 3.0, 3.0, 1.0])
+         elif self.filt_size == 5:
+             a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
+         elif self.filt_size == 6:
+             a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
+         elif self.filt_size == 7:
+             a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])
+
+         filt = torch.Tensor(a)
+         filt = filt / torch.sum(filt)
+         self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))
+
+         self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
+
+     def forward(self, inp):
+         if self.filt_size == 1:
+             if self.pad_off == 0:
+                 return inp[:, :, :: self.stride]
+             else:
+                 return self.pad(inp)[:, :, :: self.stride]
+         else:
+             return F.conv1d(
+                 self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
+             )
+
+
+ def get_pad_layer_1d(pad_type):
+     if pad_type in ["refl", "reflect"]:
+         PadLayer = nn.ReflectionPad1d
+     elif pad_type in ["repl", "replicate"]:
+         PadLayer = nn.ReplicationPad1d
+     elif pad_type == "zero":
+         PadLayer = nn.ZeroPad1d
+     else:
+         # the original printed a warning and then returned an unbound name
+         raise ValueError("Pad type [%s] not recognized" % pad_type)
+     return PadLayer
+
+
+ class Masksembles1D(nn.Module):
+     def __init__(self, channels: int, n: int, scale: float):
+         super().__init__()
+
+         self.channels = channels
+         self.n = n
+         self.scale = scale
+
+         masks = common.generation_wrapper(channels, n, scale)
+         masks = torch.from_numpy(masks)
+
+         self.masks = torch.nn.Parameter(masks, requires_grad=False)
+
+     def forward(self, inputs):
+         # Batch size must be divisible by n: each group of batch // n samples
+         # is multiplied by its own binary channel mask.
+         batch = inputs.shape[0]
+         x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
+         x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
+         x = x * self.masks.unsqueeze(1).unsqueeze(-1)
+         x = torch.cat(torch.split(x, 1, dim=0), dim=1)
+
+         return x.squeeze(0).type(inputs.dtype)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     # NOTE: `groups` is accepted but never passed to the convolutions below, so it
+     # currently has no effect; wiring it in would change the architecture and
+     # invalidate the shipped checkpoint. `stride` must stay at 1 because
+     # padding="same" does not support strided convolutions.
+     def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = nn.Conv1d(
+             in_planes,
+             planes,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding="same",
+             bias=False,
+         )
+         self.bn1 = nn.BatchNorm1d(planes)
+         self.conv2 = nn.Conv1d(
+             planes,
+             planes,
+             kernel_size=kernel_size,
+             stride=1,
+             padding="same",
+             bias=False,
+         )
+         self.bn2 = nn.BatchNorm1d(planes)
+
+         self.shortcut = nn.Sequential(
+             nn.Conv1d(
+                 in_planes,
+                 self.expansion * planes,
+                 kernel_size=1,
+                 stride=stride,
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm1d(self.expansion * planes),
+         )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = self.bn2(self.conv2(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ class Updated_onset_picker(nn.Module):
+     def __init__(
+         self,
+     ):
+         super().__init__()
+
+         # self.activation = nn.ReLU()
+         # self.maxpool = nn.MaxPool1d(2)
+
+         self.n_masks = 128
+
+         self.block1 = nn.Sequential(
+             BasicBlock(3, 8, kernel_size=7, groups=1),
+             nn.GELU(),
+             BlurPool1D(8, filt_size=3, stride=2),
+             nn.GroupNorm(2, 8),
+         )
+
+         self.block2 = nn.Sequential(
+             BasicBlock(8, 16, kernel_size=7, groups=8),
+             nn.GELU(),
+             BlurPool1D(16, filt_size=3, stride=2),
+             nn.GroupNorm(2, 16),
+         )
+
+         self.block3 = nn.Sequential(
+             BasicBlock(16, 32, kernel_size=7, groups=16),
+             nn.GELU(),
+             BlurPool1D(32, filt_size=3, stride=2),
+             nn.GroupNorm(2, 32),
+         )
+
+         self.block4 = nn.Sequential(
+             BasicBlock(32, 64, kernel_size=7, groups=32),
+             nn.GELU(),
+             BlurPool1D(64, filt_size=3, stride=2),
+             nn.GroupNorm(2, 64),
+         )
+
+         self.block5 = nn.Sequential(
+             BasicBlock(64, 128, kernel_size=7, groups=64),
+             nn.GELU(),
+             BlurPool1D(128, filt_size=3, stride=2),
+             nn.GroupNorm(2, 128),
+         )
+
+         self.block6 = nn.Sequential(
+             Masksembles1D(128, self.n_masks, 2.0),
+             BasicBlock(128, 256, kernel_size=7, groups=128),
+             nn.GELU(),
+             BlurPool1D(256, filt_size=3, stride=2),
+             nn.GroupNorm(2, 256),
+         )
+
+         self.block7 = nn.Sequential(
+             Masksembles1D(256, self.n_masks, 2.0),
+             BasicBlock(256, 512, kernel_size=7, groups=256),
+             BlurPool1D(512, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 512),
+         )
+
+         self.block8 = nn.Sequential(
+             Masksembles1D(512, self.n_masks, 2.0),
+             BasicBlock(512, 1024, kernel_size=7, groups=512),
+             BlurPool1D(1024, filt_size=3, stride=2),
+             nn.GELU(),
+             nn.GroupNorm(2, 1024),
+         )
+
+         self.block9 = nn.Sequential(
+             Masksembles1D(1024, self.n_masks, 2.0),
+             BasicBlock(1024, 128, kernel_size=7, groups=128),
+             # BlurPool1D(512, filt_size=3, stride=2),
+             # nn.GELU(),
+             # nn.GroupNorm(2,512),
+         )
+
+         # 128 channels x 24 time samples after eight stride-2 BlurPool stages
+         # on a 6000-sample input: 128 * 24 = 3072
+         self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())
+
+     def forward(self, x):
+         # Feature extraction
+         x = self.block1(x)
+         x = self.block2(x)
+
+         x = self.block3(x)
+         x = self.block4(x)
+
+         x = self.block5(x)
+         x = self.block6(x)
+
+         x = self.block7(x)
+         x = self.block8(x)
+
+         x = self.block9(x)
+
+         # Regressor
+         x = x.flatten(start_dim=1)
+         x = self.out(x)
+
+         return x
+
+
+ class Onset_picker(pl.LightningModule):
+     def __init__(self, picker, learning_rate):
+         super().__init__()
+         self.picker = picker
+         self.learning_rate = learning_rate
+         self.save_hyperparameters(ignore=['picker'])
+         self.mae = MeanAbsoluteError()
+
+     def compute_loss(self, y, pick, mae_name=False):
+         # Only samples with a labelled (non-zero) arrival contribute to the loss
+         y_filt = y[y != 0]
+         pick_filt = pick[y != 0]
+         if len(y_filt) > 0:
+             loss = F.l1_loss(y_filt, pick_filt.flatten())
+             if mae_name:
+                 # targets are fractions of the 60 s window, so *60 converts MAE to seconds
+                 mae_phase = self.mae(y_filt, pick_filt.flatten()) * 60
+                 self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+         else:
+             loss = 0
+         return loss
+
+     def training_step(self, batch, batch_idx):
+         # training_step defines the train loop.
+         x, y_p, y_s = batch
+         # x, y_p, y_s, y_pg, y_sg, y_pn, y_sn = batch
+
+         picks = self.picker(x)
+
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         p_loss = self.compute_loss(y_p, p_pick)
+         s_loss = self.compute_loss(y_s, s_pick)
+
+         loss = (p_loss + s_loss) / 2
+
+         self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)
+
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y_p, y_s = batch
+
+         picks = self.picker(x)
+
+         p_pick = picks[:, 0]
+         s_pick = picks[:, 1]
+
+         p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
+         s_loss = self.compute_loss(y_s, s_pick, mae_name='S')
+
+         loss = (p_loss + s_loss) / 2
+
+         self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
+
+         return loss
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+         scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
+         # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 3e-4, epochs=300, steps_per_epoch=len(train_loader))
+         monitor = 'Loss/train'
+         return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}
+
+     def forward(self, x):
+         picks = self.picker(x)
+         return picks
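
A minimal forward-pass shape check for the network above (a sketch; requires the masksembles package, and the batch size must be divisible by n_masks=128):

    import torch
    from model import Updated_onset_picker

    net = Updated_onset_picker().eval()
    x = torch.randn(128, 3, 6000)  # one 60 s, 3-channel window repeated per mask
    with torch.no_grad():
        picks = net(x)
    print(picks.shape)  # torch.Size([128, 2]): (P, S) picks as fractions of the window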
training.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+
+ from data_preparation import augment, collation_fn, my_split_by_node
+ from model import Onset_picker, Updated_onset_picker
+
+ import webdataset as wds
+
+ from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
+ from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
+ from lightning.pytorch.strategies import DDPStrategy
+ from lightning import seed_everything
+ import lightning as pl
+
+ seed_everything(42, workers=False)
+ torch.set_float32_matmul_precision('medium')
+
+ batch_size = 256
+ num_workers = 16  # int(os.cpu_count())
+ n_iters_in_epoch = 5000
+
+ train_dataset = (
+     wds.WebDataset("data/sample/shard-00{0000..0001}.tar",
+                    # splitter=my_split_by_worker,
+                    nodesplitter=my_split_by_node)
+     .decode()
+     .map(augment)
+     .shuffle(5000)
+     .batched(batchsize=batch_size,
+              collation_fn=collation_fn,
+              partial=False)
+ ).with_epoch(n_iters_in_epoch // num_workers)
+
+
+ val_dataset = (
+     wds.WebDataset("data/sample/shard-00{0000..0000}.tar",
+                    # splitter=my_split_by_worker,
+                    nodesplitter=my_split_by_node)
+     .decode()
+     .map(augment)
+     .repeat()
+     .batched(batchsize=batch_size,
+              collation_fn=collation_fn,
+              partial=False)
+ ).with_epoch(100)
+
+
+ train_loader = wds.WebLoader(train_dataset,
+                              num_workers=num_workers,
+                              shuffle=False,
+                              pin_memory=True,
+                              batch_size=None)
+
+ val_loader = wds.WebLoader(val_dataset,
+                            num_workers=0,
+                            shuffle=False,
+                            pin_memory=True,
+                            batch_size=None)
+
+
+ # model
+ model = Onset_picker(picker=Updated_onset_picker(),
+                      learning_rate=3e-4)
+ # model = torch.compile(model, mode="reduce-overhead")
+
+ logger = TensorBoardLogger("tensorboard_logdir", name="FAST")
+
+ checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="Loss/val", filename="chkp-{epoch:02d}")
+ lr_callback = LearningRateMonitor(logging_interval='epoch')
+ # swa_callback = StochasticWeightAveraging(swa_lrs=0.05)
+
+ # train model
+ trainer = pl.Trainer(
+     precision='16-mixed',
+     callbacks=[checkpoint_callback, lr_callback],
+     devices='auto',
+     accelerator='auto',
+     strategy=DDPStrategy(find_unused_parameters=False,
+                          static_graph=True,
+                          gradient_as_bucket_view=True),
+     benchmark=True,
+     gradient_clip_val=0.5,
+     # ckpt_path='path/to/saved/checkpoints/chkp.ckpt',
+     # fast_dev_run=True,
+     logger=logger,
+     log_every_n_steps=50,
+     enable_progress_bar=True,
+     max_epochs=300,
+ )
+
+ trainer.fit(model=model,
+             train_dataloaders=train_loader,
+             val_dataloaders=val_loader,
+             )
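
After training, the saved checkpoint can be restored for inference along the same lines as app.py (a sketch; the checkpoint path is illustrative):

    import torch
    from model import Onset_picker, Updated_onset_picker

    model = Onset_picker.load_from_checkpoint("path/to/checkpoint.ckpt",
                                              picker=Updated_onset_picker(),
                                              learning_rate=3e-4)
    model.eval()
    with torch.no_grad():
        picks = model(torch.randn(128, 3, 6000))  # (P, S) picks per mask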
weights.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80255c65b749559f7c5c3f2bb993a25cc666d9a63a0d3050024679dd8064dcec
+ size 200977197