Initial commit
- .DS_Store +0 -0
- app.py +164 -0
- data/.DS_Store +0 -0
- data/sample/.DS_Store +0 -0
- data/sample/sample_0.npy +3 -0
- data/sample/sample_1.npy +3 -0
- data/sample/sample_2.npy +3 -0
- data_preparation.py +210 -0
- model.py +313 -0
- training.py +104 -0
- weights.ckpt +3 -0
.DS_Store
ADDED
Binary file (6.15 kB)
app.py
ADDED
@@ -0,0 +1,164 @@
# Gradio app that takes a seismic waveform as input and marks the P and S phases on the waveform as output.

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt  # used by mark_phases below; missing in the original and a likely cause of the Space's runtime error
import torch
from scipy.stats import gaussian_kde  # scipy.stats.kde is deprecated; gaussian_kde lives in scipy.stats

# NOTE: model.py and data_preparation.py sit at the repo root in this commit,
# so they are imported directly; the original `phasehunter.*` imports would fail here.
from model import Onset_picker, Updated_onset_picker
from data_preparation import prepare_waveform

import obspy
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.header import FDSNNoDataException, FDSNTimeoutException, FDSNInternalServerException
from obspy.geodetics.base import locations2degrees
from obspy.taup import TauPyModel
from obspy.taup.helper_classes import SlownessModelError

from obspy.clients.fdsn.header import URL_MAPPINGS

def make_prediction(waveform):
    waveform = np.load(waveform)
    processed_input = prepare_waveform(waveform)

    # Make prediction
    with torch.no_grad():
        output = model(processed_input)

    p_phase = output[:, 0]
    s_phase = output[:, 1]

    return processed_input, p_phase, s_phase

def mark_phases(waveform):
    processed_input, p_phase, s_phase = make_prediction(waveform)

    # Create a plot of the waveform with the phases marked
    if (processed_input[0][2] == 0).all():  # third channel all zeros => input is 1C
        fig, ax = plt.subplots(nrows=2, figsize=(10, 2), sharex=True)

        ax[0].plot(processed_input[0][0])
        ax[0].set_ylabel('Norm. Ampl.')

    else:  # input is 3C
        fig, ax = plt.subplots(nrows=4, figsize=(10, 6), sharex=True)
        ax[0].plot(processed_input[0][0])
        ax[1].plot(processed_input[0][1])
        ax[2].plot(processed_input[0][2])

        ax[0].set_ylabel('Z')
        ax[1].set_ylabel('N')
        ax[2].set_ylabel('E')

    p_phase_plot = p_phase * processed_input.shape[-1]
    p_kde = gaussian_kde(p_phase_plot)
    p_dist_space = np.linspace(min(p_phase_plot) - 10, max(p_phase_plot) + 10, 500)
    ax[-1].plot(p_dist_space, p_kde(p_dist_space), color='r')

    s_phase_plot = s_phase * processed_input.shape[-1]
    s_kde = gaussian_kde(s_phase_plot)
    s_dist_space = np.linspace(min(s_phase_plot) - 10, max(s_phase_plot) + 10, 500)
    ax[-1].plot(s_dist_space, s_kde(s_dist_space), color='b')

    for a in ax:
        a.axvline(p_phase.mean() * processed_input.shape[-1], color='r', linestyle='--', label='P')
        a.axvline(s_phase.mean() * processed_input.shape[-1], color='b', linestyle='--', label='S')

    ax[-1].set_xlabel('Time, samples')
    ax[-1].set_ylabel('Uncert.')
    ax[-1].legend()

    plt.subplots_adjust(hspace=0., wspace=0.)

    # Convert the plot to an image and return it
    fig.canvas.draw()
    image = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)
    return image

def download_data(timestamp, eq_lat, eq_lon, client_name, radius_km):
    client = Client(client_name)
    window = radius_km / 111.2  # km -> degrees

    assert eq_lat - window > -90 and eq_lat + window < 90, "Latitude out of bounds"
    assert eq_lon - window > -180 and eq_lon + window < 180, "Longitude out of bounds"

    # starttime = catalog['DateTime'].apply(lambda x: pd.Timestamp(x)).min()
    # endtime = catalog['DateTime'].apply(lambda x: pd.Timestamp(x)).max()

    return 0

model = Onset_picker.load_from_checkpoint("./weights.ckpt",
                                          picker=Updated_onset_picker(),
                                          learning_rate=3e-4)
model.eval()


# # Create the Gradio interface
# gr.Interface(mark_phases, inputs, outputs, title='PhaseHunter').launch()

with gr.Blocks() as demo:
    gr.Markdown("# PhaseHunter")
    with gr.Tab("Default example"):
        # Define the input and output types for Gradio
        inputs = gr.Dropdown(
            ["data/sample/sample_0.npy",
             "data/sample/sample_1.npy",
             "data/sample/sample_2.npy"],
            label="Sample waveform",
            info="Select one of the samples",
            value="data/sample/sample_0.npy"
        )

        button = gr.Button("Predict phases")
        outputs = gr.Image(label='Waveform with Phases Marked', type='numpy')  # gr.outputs.Image is deprecated; gr.Image takes the same label/type

        button.click(mark_phases, inputs=inputs, outputs=outputs)

    with gr.Tab("Select earthquake from catalogue"):
        gr.Markdown('TEST')

        client_inputs = gr.Dropdown(
            choices=list(URL_MAPPINGS.keys()),
            label="FDSN Client",
            info="Select one of the available FDSN clients",
            value="IRIS",
            interactive=True
        )
        with gr.Row():
            timestamp_inputs = gr.Textbox(value='2019-07-04 17:33:49',
                                          placeholder='YYYY-MM-DD HH:MM:SS',
                                          label="Timestamp",
                                          info="Timestamp of the earthquake",
                                          max_lines=1,
                                          interactive=True)

            eq_lat_inputs = gr.Number(value=35.766,
                                      label="Latitude",
                                      info="Latitude of the earthquake",
                                      interactive=True)

            eq_lon_inputs = gr.Number(value=117.605,
                                      label="Longitude",
                                      info="Longitude of the earthquake",
                                      interactive=True)

            radius_inputs = gr.Slider(minimum=1,
                                      maximum=150,
                                      value=50, label="Radius (km)",
                                      info="Select the radius around the earthquake to download data from",
                                      interactive=True)

        button = gr.Button("Predict phases")

    with gr.Tab("Predict on your own waveform"):
        gr.Markdown("""
        Please upload your waveform in .npy (numpy) format.
        Your waveform should be sampled at 100 sps and have 3 (Z, N, E) or 1 (Z) channels.
        """)

        # NOTE: this reuses the button from the catalogue tab and the dropdown from
        # the first tab; the upload input itself is not wired up in this commit.
        button.click(mark_phases, inputs=inputs, outputs=outputs)

demo.launch()
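The upload tab above expects a NumPy file. Below is a minimal sketch (not part of the commit) of producing a compatible input; the file name and random data are placeholders:

import numpy as np

# 3 channels (Z, N, E) sampled at 100 sps; prepare_waveform pads or crops
# the time axis to 6000 samples, so any reasonable length works
waveform = np.random.randn(3, 6000).astype(np.float32)  # placeholder data
np.save("my_waveform.npy", waveform)                     # hypothetical file name

For a single-channel record, save a (1, n_samples) array instead; prepare_waveform embeds it into the Z row of a zero-padded 3-channel array.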
data/.DS_Store
ADDED
Binary file (8.2 kB)

data/sample/.DS_Store
ADDED
Binary file (6.15 kB)
data/sample/sample_0.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bcacc216bea4273debc0244dce60def2413e4642100029fa0c0ce83416ba71c8
size 144128
data/sample/sample_1.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:245f7bd592643f8573e8051aa593bdcdffa7e497de3e56aa76423cd96e44ae03
size 113776
data/sample/sample_2.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c3eb64a6b01caba32464d950aafcbc4ed32f724f062c9c91b1c7f36671ce2b8e
size 136344
data_preparation.py
ADDED
@@ -0,0 +1,210 @@
import torch
import numpy as np

from scipy import signal
from scipy.signal import butter, lfilter, detrend

# Make bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs  # Nyquist frequency
    low = lowcut / nyq  # Normalized frequency
    high = highcut / nyq
    b, a = butter(order, [low, high], btype="band")  # Bandpass filter
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y


def rotate_waveform(waveform, angle):
    fft_waveform = np.fft.fft(waveform)  # Compute the Fourier transform of the waveform
    rotate_factor = np.exp(1j * angle)  # Create a complex exponential with the specified rotation angle
    rotated_fft_waveform = fft_waveform * rotate_factor  # Multiply the Fourier transform by the rotation factor
    rotated_waveform = np.fft.ifft(rotated_fft_waveform)  # Inverse transform back to the time domain

    return rotated_waveform


def augment(sample):
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120
    test = False

    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]

    if meta["split"] != "train":
        test = True

    target_sample_P = meta["trace_p_arrival_sample"]
    target_sample_S = meta["trace_s_arrival_sample"]

    if target_sample_P is None:
        target_sample_P = 0
    if target_sample_S is None:
        target_sample_S = 0

    # Randomly select a phase to start the crop
    current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
    phase_selector = np.random.randint(0, len(current_phases))
    first_phase = current_phases[phase_selector]

    # Shuffle
    if first_phase - (crop_length - padding) > padding:
        start_indx = int(
            first_phase
            - torch.randint(low=padding, high=(crop_length - padding), size=(1,))
        )
        if test:
            start_indx = int(first_phase - 2 * padding)

    elif int(first_phase - padding) > 0:
        start_indx = int(
            first_phase
            - torch.randint(low=0, high=(int(first_phase - padding)), size=(1,))
        )
        if test:
            start_indx = int(first_phase - padding)

    else:
        start_indx = padding

    end_indx = start_indx + crop_length

    if (waveform.shape[-1] - end_indx) < 0:
        start_indx += waveform.shape[-1] - end_indx
        end_indx = start_indx + crop_length

    # Update target
    new_target_P = target_sample_P - start_indx
    new_target_S = target_sample_S - start_indx

    # Cut
    waveform_cropped = waveform[:, start_indx:end_indx]

    # Preprocess
    waveform_cropped = detrend(waveform_cropped)
    waveform_cropped = butter_bandpass_filter(
        waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform_cropped[-1].shape[0], alpha=0.1)
    waveform_cropped = waveform_cropped * window
    waveform_cropped = detrend(waveform_cropped)

    if np.isnan(waveform_cropped).any():
        waveform_cropped = np.zeros(shape=waveform_cropped.shape)
        new_target_P = 0
        new_target_S = 0

    if np.sum(waveform_cropped) == 0:
        new_target_P = 0
        new_target_S = 0

    # Normalize data
    max_val = np.max(np.abs(waveform_cropped))
    waveform_cropped_norm = waveform_cropped / max_val

    # Added Z component only
    if len(waveform_cropped_norm) < 3:
        zeros = np.zeros((3, waveform_cropped_norm.shape[-1]))
        zeros[0] = waveform_cropped_norm
        waveform_cropped_norm = zeros

    if not test:
        ##### Rotate waveform #####
        probability = torch.randint(0, 2, size=(1,)).item()
        angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
        if probability == 1:
            waveform_cropped_norm = rotate_waveform(waveform_cropped_norm, angle).real

        #### Channel DropOUT #####
        probability = torch.randint(0, 2, size=(1,)).item()
        channel = torch.randint(1, 3, size=(1,)).item()
        if probability == 1:
            waveform_cropped_norm[channel, :] = 1e-6

    # Normalize target
    new_target_P = new_target_P / crop_length
    new_target_S = new_target_S / crop_length

    if (new_target_P <= 0) or (new_target_P >= 1) or (np.isnan(new_target_P)):
        new_target_P = 0
    if (new_target_S <= 0) or (new_target_S >= 1) or (np.isnan(new_target_S)):
        new_target_S = 0

    return waveform_cropped_norm, new_target_P, new_target_S


def collation_fn(sample):
    waveforms = np.stack([x[0] for x in sample])
    targets_P = np.stack([x[1] for x in sample])
    targets_S = np.stack([x[2] for x in sample])

    return (
        torch.tensor(waveforms, dtype=torch.float),
        torch.tensor(targets_P, dtype=torch.float),
        torch.tensor(targets_S, dtype=torch.float),
    )


def my_split_by_node(urls):
    node_id, node_count = (
        torch.distributed.get_rank(),
        torch.distributed.get_world_size(),
    )
    return list(urls)[node_id::node_count]


def prepare_waveform(waveform):
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120

    assert waveform.shape[0] <= 3, "Waveform has more than 3 channels"

    if waveform.shape[-1] < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - waveform.shape[-1])),
            mode="constant",
            constant_values=0,
        )
    if waveform.shape[-1] > crop_length:
        waveform = waveform[:, :crop_length]

    # Preprocess
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform[-1].shape[0], alpha=0.1)
    waveform = waveform * window
    waveform = detrend(waveform)

    assert not np.isnan(waveform).any(), "Nan in waveform"
    assert np.sum(waveform) != 0, "Sum of waveform sample is zero"

    # Normalize data
    max_val = np.max(np.abs(waveform))
    waveform = waveform / max_val

    # Added Z component only
    if len(waveform) < 3:
        zeros = np.zeros((3, waveform.shape[-1]))
        zeros[0] = waveform
        waveform = zeros

    return torch.tensor([waveform] * 128, dtype=torch.float)
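A quick sanity check of prepare_waveform (a sketch, not part of the commit). The function returns 128 copies of the processed waveform, one per Masksembles mask in the model, which is what lets app.py turn the spread of the 128 picks into the uncertainty densities it plots:

import numpy as np
from data_preparation import prepare_waveform

wave = np.random.randn(3, 7000)   # longer than crop_length; truncated to 6000 samples
batch = prepare_waveform(wave)
print(batch.shape)                # torch.Size([128, 3, 6000])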
model.py
ADDED
@@ -0,0 +1,313 @@
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchmetrics import MeanAbsoluteError
from torch.optim.lr_scheduler import ReduceLROnPlateau

import lightning as pl

class BlurPool1D(nn.Module):
    def __init__(self, channels, pad_type="reflect", filt_size=3, stride=2, pad_off=0):
        super(BlurPool1D, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [
            int(1.0 * (filt_size - 1) / 2),
            int(np.ceil(1.0 * (filt_size - 1) / 2)),
        ]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.0)
        self.channels = channels

        # print('Filter size [%i]' % filt_size)
        if self.filt_size == 1:
            a = np.array([1.0])
        elif self.filt_size == 2:
            a = np.array([1.0, 1.0])
        elif self.filt_size == 3:
            a = np.array([1.0, 2.0, 1.0])
        elif self.filt_size == 4:
            a = np.array([1.0, 3.0, 3.0, 1.0])
        elif self.filt_size == 5:
            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        elif self.filt_size == 6:
            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
        elif self.filt_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])

        filt = torch.Tensor(a)
        filt = filt / torch.sum(filt)
        self.register_buffer("filt", filt[None, None, :].repeat((self.channels, 1, 1)))

        self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if self.filt_size == 1:
            if self.pad_off == 0:
                return inp[:, :, :: self.stride]
            else:
                return self.pad(inp)[:, :, :: self.stride]
        else:
            return F.conv1d(
                self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]
            )


def get_pad_layer_1d(pad_type):
    if pad_type in ["refl", "reflect"]:
        PadLayer = nn.ReflectionPad1d
    elif pad_type in ["repl", "replicate"]:
        PadLayer = nn.ReplicationPad1d
    elif pad_type == "zero":
        PadLayer = nn.ZeroPad1d
    else:
        print("Pad type [%s] not recognized" % pad_type)
    return PadLayer


from masksembles import common


class Masksembles1D(nn.Module):
    def __init__(self, channels: int, n: int, scale: float):
        super().__init__()

        self.channels = channels
        self.n = n
        self.scale = scale

        masks = common.generation_wrapper(channels, n, scale)
        masks = torch.from_numpy(masks)

        self.masks = torch.nn.Parameter(masks, requires_grad=False)

    def forward(self, inputs):
        batch = inputs.shape[0]
        x = torch.split(inputs.unsqueeze(1), batch // self.n, dim=0)
        x = torch.cat(x, dim=1).permute([1, 0, 2, 3])
        x = x * self.masks.unsqueeze(1).unsqueeze(-1)
        x = torch.cat(torch.split(x, 1, dim=0), dim=1)

        return x.squeeze(0).type(inputs.dtype)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, kernel_size=7, groups=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            stride=stride,
            padding="same",
            bias=False,
        )
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(
            planes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding="same",
            bias=False,
        )
        self.bn2 = nn.BatchNorm1d(planes)

        self.shortcut = nn.Sequential(
            nn.Conv1d(
                in_planes,
                self.expansion * planes,
                kernel_size=1,
                stride=stride,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm1d(self.expansion * planes),
        )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Updated_onset_picker(nn.Module):
    def __init__(self):
        super().__init__()

        # self.activation = nn.ReLU()
        # self.maxpool = nn.MaxPool1d(2)

        self.n_masks = 128

        self.block1 = nn.Sequential(
            BasicBlock(3, 8, kernel_size=7, groups=1),
            nn.GELU(),
            BlurPool1D(8, filt_size=3, stride=2),
            nn.GroupNorm(2, 8),
        )

        self.block2 = nn.Sequential(
            BasicBlock(8, 16, kernel_size=7, groups=8),
            nn.GELU(),
            BlurPool1D(16, filt_size=3, stride=2),
            nn.GroupNorm(2, 16),
        )

        self.block3 = nn.Sequential(
            BasicBlock(16, 32, kernel_size=7, groups=16),
            nn.GELU(),
            BlurPool1D(32, filt_size=3, stride=2),
            nn.GroupNorm(2, 32),
        )

        self.block4 = nn.Sequential(
            BasicBlock(32, 64, kernel_size=7, groups=32),
            nn.GELU(),
            BlurPool1D(64, filt_size=3, stride=2),
            nn.GroupNorm(2, 64),
        )

        self.block5 = nn.Sequential(
            BasicBlock(64, 128, kernel_size=7, groups=64),
            nn.GELU(),
            BlurPool1D(128, filt_size=3, stride=2),
            nn.GroupNorm(2, 128),
        )

        self.block6 = nn.Sequential(
            Masksembles1D(128, self.n_masks, 2.0),
            BasicBlock(128, 256, kernel_size=7, groups=128),
            nn.GELU(),
            BlurPool1D(256, filt_size=3, stride=2),
            nn.GroupNorm(2, 256),
        )

        self.block7 = nn.Sequential(
            Masksembles1D(256, self.n_masks, 2.0),
            BasicBlock(256, 512, kernel_size=7, groups=256),
            BlurPool1D(512, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 512),
        )

        self.block8 = nn.Sequential(
            Masksembles1D(512, self.n_masks, 2.0),
            BasicBlock(512, 1024, kernel_size=7, groups=512),
            BlurPool1D(1024, filt_size=3, stride=2),
            nn.GELU(),
            nn.GroupNorm(2, 1024),
        )

        self.block9 = nn.Sequential(
            Masksembles1D(1024, self.n_masks, 2.0),
            BasicBlock(1024, 128, kernel_size=7, groups=128),
            # BlurPool1D(512, filt_size=3, stride=2),
            # nn.GELU(),
            # nn.GroupNorm(2,512),
        )

        self.out = nn.Sequential(nn.Linear(3072, 2), nn.Sigmoid())

    def forward(self, x):
        # Feature extraction
        x = self.block1(x)
        x = self.block2(x)

        x = self.block3(x)
        x = self.block4(x)

        x = self.block5(x)
        x = self.block6(x)

        x = self.block7(x)
        x = self.block8(x)

        x = self.block9(x)

        # Regressor
        x = x.flatten(start_dim=1)
        x = self.out(x)

        return x


class Onset_picker(pl.LightningModule):
    def __init__(self, picker, learning_rate):
        super().__init__()
        self.picker = picker
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=['picker'])
        self.mae = MeanAbsoluteError()

    def compute_loss(self, y, pick, mae_name=False):
        y_filt = y[y != 0]
        pick_filt = pick[y != 0]
        if len(y_filt) > 0:
            loss = F.l1_loss(y_filt, pick_filt.flatten())
            if mae_name:
                mae_phase = self.mae(y_filt, pick_filt.flatten()) * 60
                self.log(f'MAE/{mae_name}_val', mae_phase, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        else:
            loss = 0
        return loss

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y_p, y_s = batch
        # x, y_p, y_s, y_pg, y_sg, y_pn, y_sn = batch

        picks = self.picker(x)

        p_pick = picks[:, 0]
        s_pick = picks[:, 1]

        p_loss = self.compute_loss(y_p, p_pick)
        s_loss = self.compute_loss(y_s, s_pick)

        loss = (p_loss + s_loss) / 2

        self.log('Loss/train', loss, on_step=True, on_epoch=False, prog_bar=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y_p, y_s = batch

        picks = self.picker(x)

        p_pick = picks[:, 0]
        s_pick = picks[:, 1]

        p_loss = self.compute_loss(y_p, p_pick, mae_name='P')
        s_loss = self.compute_loss(y_s, s_pick, mae_name='S')

        loss = (p_loss + s_loss) / 2

        self.log('Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, cooldown=10, threshold=1e-3)
        # scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, 3e-4, epochs=300, steps_per_epoch=len(train_loader))
        monitor = 'Loss/train'
        return {"optimizer": optimizer, "lr_scheduler": scheduler, 'monitor': monitor}

    def forward(self, x):
        picks = self.picker(x)
        return picks
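A smoke test for the picker (a sketch, not part of the commit). The batch size must be divisible by n_masks=128 because Masksembles1D splits the batch across its masks; with 6000 input samples, the eight BlurPool1D stages (stride 2, blocks 1-8) leave 24 samples x 128 channels = 3072 features for the final linear layer:

import torch
from model import Updated_onset_picker

net = Updated_onset_picker().eval()
x = torch.rand(128, 3, 6000)      # (batch, channels, samples); batch divisible by 128
with torch.no_grad():
    y = net(x)
print(y.shape)                    # torch.Size([128, 2]): P and S picks as fractions of the window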
training.py
ADDED
@@ -0,0 +1,104 @@
import torch

from data_preparation import augment, collation_fn, my_split_by_node
from model import Onset_picker, Updated_onset_picker

import webdataset as wds

from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
from lightning import seed_everything
import lightning as pl

seed_everything(42, workers=False)
torch.set_float32_matmul_precision('medium')

batch_size = 256
num_workers = 16  # int(os.cpu_count())
n_iters_in_epoch = 5000

train_dataset = (
    wds.WebDataset("data/sample/shard-00{0000..0001}.tar",
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .shuffle(5000)
    .batched(batchsize=batch_size,
             collation_fn=collation_fn,
             partial=False)
).with_epoch(n_iters_in_epoch // num_workers)


val_dataset = (
    wds.WebDataset("data/sample/shard-00{0000..0000}.tar",
                   # splitter=my_split_by_worker,
                   nodesplitter=my_split_by_node)
    .decode()
    .map(augment)
    .repeat()
    .batched(batchsize=batch_size,
             collation_fn=collation_fn,
             partial=False)
).with_epoch(100)


train_loader = wds.WebLoader(train_dataset,
                             num_workers=num_workers,
                             shuffle=False,
                             pin_memory=True,
                             batch_size=None)

val_loader = wds.WebLoader(val_dataset,
                           num_workers=0,
                           shuffle=False,
                           pin_memory=True,
                           batch_size=None)


# model
model = Onset_picker(picker=Updated_onset_picker(),
                     learning_rate=3e-4)
# model = torch.compile(model, mode="reduce-overhead")

logger = TensorBoardLogger("tensorboard_logdir", name="FAST")

checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="Loss/val", filename="chkp-{epoch:02d}")
lr_callback = LearningRateMonitor(logging_interval='epoch')
# swa_callback = StochasticWeightAveraging(swa_lrs=0.05)

# train model
trainer = pl.Trainer(
    precision='16-mixed',
    callbacks=[checkpoint_callback, lr_callback],
    devices='auto',
    accelerator='auto',
    strategy=DDPStrategy(find_unused_parameters=False,
                         static_graph=True,
                         gradient_as_bucket_view=True),
    benchmark=True,
    gradient_clip_val=0.5,
    # ckpt_path='path/to/saved/checkpoints/chkp.ckpt',
    # fast_dev_run=True,
    logger=logger,
    log_every_n_steps=50,
    enable_progress_bar=True,
    max_epochs=300,
)

trainer.fit(model=model,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader,
            )
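training.py expects WebDataset shards containing a "waveform.npy" array and a "meta.json" record per sample, the two keys augment reads. A hypothetical sketch of writing such shards; the counts, paths, trace lengths, and pick indices are placeholders:

import numpy as np
import webdataset as wds

# writes data/sample/shard-000000.tar, shard-000001.tar, matching the
# brace pattern "shard-00{0000..0001}.tar" used above
with wds.ShardWriter("data/sample/shard-%06d.tar", maxcount=1000) as sink:
    for i in range(2000):
        sink.write({
            "__key__": f"sample{i:06d}",
            "waveform.npy": np.random.randn(3, 12000).astype(np.float32),  # placeholder traces
            "meta.json": {
                "split": "train",
                "trace_p_arrival_sample": 3000,  # placeholder picks
                "trace_s_arrival_sample": 5000,
            },
        })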
weights.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80255c65b749559f7c5c3f2bb993a25cc666d9a63a0d3050024679dd8064dcec
size 200977197