Upload 65 files
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +30 -0
- 11522105.pdf +3 -0
- data_load.py +48 -0
- data_processing.py +265 -0
- resnet1d.py +297 -0
- resnet1d_multitask.py +155 -0
- results/loss_curves_modelA_20241207_145904.png +3 -0
- results/loss_curves_modelB_20241207_151513.png +3 -0
- results/loss_curves_modelC_20241207_153102.png +3 -0
- results/metrics_modelA_20241207_145904.txt +12 -0
- results/metrics_modelB_20241207_151513.txt +12 -0
- results/metrics_modelC_20241207_153102.txt +12 -0
- results/modelC_scatter_CEC_20241207_153102.png +0 -0
- results/modelC_scatter_CaCO3_20241207_153102.png +0 -0
- results/modelC_scatter_K_20241207_153102.png +0 -0
- results/modelC_scatter_N_20241207_153102.png +0 -0
- results/modelC_scatter_OC_20241207_153102.png +0 -0
- results/modelC_scatter_P_20241207_153102.png +0 -0
- results/modelC_scatter_pH.in.CaCl2_20241207_153102.png +3 -0
- results/modelC_scatter_pH.in.H2O_20241207_153102.png +0 -0
- results/training_metrics_modelA_20241207_145904.png +3 -0
- results/training_metrics_modelB_20241207_151513.png +3 -0
- results/training_metrics_modelC_20241207_153102.png +3 -0
- results1/loss_curves_modelC_20241207_155007_Abs-SG0.png +3 -0
- results1/loss_curves_modelC_20241207_160632_Abs-SG0-SNV.png +3 -0
- results1/loss_curves_modelC_20241207_162218_Abs-SG1.png +3 -0
- results1/loss_curves_modelC_20241207_163811_Abs-SG1-SNV.png +3 -0
- results1/loss_curves_modelC_20241207_165344_Abs-SG2.png +3 -0
- results1/loss_curves_modelC_20241207_170911_Abs-SG2-SNV.png +3 -0
- results1/metrics_modelC_20241210_142058_Abs-SG0.txt +12 -0
- results1/metrics_modelC_20241210_143517_Abs-SG0-SNV.txt +12 -0
- results1/metrics_modelC_20241210_144926_Abs-SG1.txt +12 -0
- results1/metrics_modelC_20241210_150332_Abs-SG1-SNV.txt +12 -0
- results1/metrics_modelC_20241210_151845_Abs-SG2.txt +12 -0
- results1/metrics_modelC_20241210_153333_Abs-SG2-SNV.txt +12 -0
- results1/training_metrics_modelC_20241207_155007_Abs-SG0.png +3 -0
- results1/training_metrics_modelC_20241207_160632_Abs-SG0-SNV.png +3 -0
- results1/training_metrics_modelC_20241207_162218_Abs-SG1.png +3 -0
- results1/training_metrics_modelC_20241207_163811_Abs-SG1-SNV.png +3 -0
- results1/training_metrics_modelC_20241207_165344_Abs-SG2.png +3 -0
- results1/training_metrics_modelC_20241207_170911_Abs-SG2-SNV.png +3 -0
- results2/loss_curves_modelC_20241213_202047_5.png +3 -0
- results2/loss_curves_modelC_20241213_203438_10.png +3 -0
- results2/loss_curves_modelC_20241213_204103_15.png +3 -0
- results2/metrics_modelC_20241213_202047_5.txt +4 -0
- results2/metrics_modelC_20241213_203438_10.txt +4 -0
- results2/metrics_modelC_20241213_204103_15.txt +4 -0
- results2/training_metrics_modelC_20241213_202047_5.png +3 -0
- results2/training_metrics_modelC_20241213_203438_10.png +3 -0
- results2/training_metrics_modelC_20241213_204103_15.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,33 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
11522105.pdf filter=lfs diff=lfs merge=lfs -text
|
37 |
+
results/loss_curves_modelA_20241207_145904.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
results/loss_curves_modelB_20241207_151513.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
results/loss_curves_modelC_20241207_153102.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
results/modelC_scatter_pH.in.CaCl2_20241207_153102.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
results/training_metrics_modelA_20241207_145904.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
results/training_metrics_modelB_20241207_151513.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
results/training_metrics_modelC_20241207_153102.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
results1/loss_curves_modelC_20241207_155007_Abs-SG0.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
results1/loss_curves_modelC_20241207_160632_Abs-SG0-SNV.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
results1/loss_curves_modelC_20241207_162218_Abs-SG1.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
results1/loss_curves_modelC_20241207_163811_Abs-SG1-SNV.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
results1/loss_curves_modelC_20241207_165344_Abs-SG2.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
results1/loss_curves_modelC_20241207_170911_Abs-SG2-SNV.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
results1/training_metrics_modelC_20241207_155007_Abs-SG0.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
results1/training_metrics_modelC_20241207_160632_Abs-SG0-SNV.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
results1/training_metrics_modelC_20241207_162218_Abs-SG1.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
results1/training_metrics_modelC_20241207_163811_Abs-SG1-SNV.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
results1/training_metrics_modelC_20241207_165344_Abs-SG2.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
results1/training_metrics_modelC_20241207_170911_Abs-SG2-SNV.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
results2/loss_curves_modelC_20241213_202047_5.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
results2/loss_curves_modelC_20241213_203438_10.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
results2/loss_curves_modelC_20241213_204103_15.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
results2/training_metrics_modelC_20241213_202047_5.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
results2/training_metrics_modelC_20241213_203438_10.png filter=lfs diff=lfs merge=lfs -text
|
61 |
+
results2/training_metrics_modelC_20241213_204103_15.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
results3/loss_curves_modelC_20241213_211327_15.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
results3/training_metrics_modelC_20241213_211327_15.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
shap_summary_plot.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
shap_top10_wavelengths.png filter=lfs diff=lfs merge=lfs -text
|
11522105.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d350619292928420963d55ea260c286edc0ccad267f2bbefa9b97cee4ee3661
|
3 |
+
size 2284006
|
data_load.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.model_selection import train_test_split
|
4 |
+
|
5 |
+
def load_soil_data(file_path, target_columns, n_wavelengths=4200):
    """Load LUCAS soil spectra from a CSV file and split into train/test sets.

    Parameters
    ----------
    file_path : str or file-like
        Path to (or buffer of) the CSV file.  The first ``n_wavelengths``
        columns must hold the spectra, with headers like ``spc.<wavelength>``.
    target_columns : list of str
        Names of the target soil-property columns (8 indicators in this project).
    n_wavelengths : int, optional
        Number of leading spectral columns (default 4200, the LUCAS 2009 layout).

    Returns
    -------
    X_train, X_test : np.ndarray
        Spectra reshaped to (n_samples, 1, n_wavelengths) for 1-D conv nets.
    y_train, y_test : np.ndarray
        Target values, shape (n_samples, len(target_columns)).
    wavelengths : pandas Float64 index
        Wavelength values parsed from the spectral column headers.
    """
    data = pd.read_csv(file_path)

    # Wavelengths are encoded in the spectral column headers ("spc.<nm>").
    wavelengths = data.columns[:n_wavelengths].str.replace('spc.', '').astype(float)

    X = data.iloc[:, :n_wavelengths].values  # spectral features
    y = data[target_columns].values          # regression targets

    # Ensure the features are float32 (saves memory, matches torch defaults).
    X = X.astype('float32')

    # 80/20 train/test split with a fixed seed for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # 1-D convolution expects (batch_size, channels, sequence_length).
    X_train = X_train.reshape(-1, 1, n_wavelengths)
    X_test = X_test.reshape(-1, 1, n_wavelengths)

    return X_train, X_test, y_train, y_test, wavelengths
37 |
+
|
38 |
+
if __name__ == "__main__":
    # Example usage: load the LUCAS 2009 absorbance dataset and report shapes.
    file_path = 'LUCAS.2009_abs.csv'
    # The eight soil property targets predicted by the multi-task model.
    target_columns = ['pH.in.CaCl2', 'pH.in.H2O', 'OC', 'CaCO3', 'N', 'P', 'K', 'CEC']
    X_train, X_test, y_train, y_test, wavelengths = load_soil_data(file_path, target_columns)

    print("X_train shape:", X_train.shape)
    print("X_test shape:", X_test.shape)
    print("y_train shape:", y_train.shape)
    print("y_test shape:", y_test.shape)
    print("wavelengths shape:", wavelengths.shape)
|
data_processing.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy.signal import savgol_filter
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from data_load import load_soil_data
|
5 |
+
|
6 |
+
def apply_sg_filter(spectra, window_length=15, polyorder=2, deriv=0):
    """Apply a Savitzky-Golay filter for spectral smoothing or differentiation.

    Parameters
    ----------
    spectra : array-like, shape (n_samples, n_wavelengths)
        Input spectra, one row per sample.
    window_length : int, optional
        Filter window length; must be odd (default 15).
    polyorder : int, optional
        Order of the fitted polynomial (default 2).
    deriv : int, optional
        Derivative order: 0 = smoothing, 1 = first derivative, 2 = second.

    Returns
    -------
    np.ndarray
        Filtered spectra with the same shape as the input.
    """
    # savgol_filter can operate along an axis, so filter every sample in one
    # vectorized C-level call instead of a Python loop per spectrum.
    return savgol_filter(np.asarray(spectra), window_length, polyorder,
                         deriv=deriv, axis=-1)
|
19 |
+
|
20 |
+
|
21 |
+
def apply_snv(spectra):
    """Apply the Standard Normal Variate (SNV) transform to each spectrum.

    Each sample is centered by its own mean and scaled by its own standard
    deviation, removing multiplicative scatter effects.

    Parameters
    ----------
    spectra : array-like, shape (n_samples, n_wavelengths)

    Returns
    -------
    np.ndarray
        SNV-transformed spectra (float), same shape as the input.

    Notes
    -----
    The original per-row loop wrote into ``np.zeros_like(spectra)``; for an
    integer-dtype input that silently truncated the standardized values to
    ints.  Broadcasting the row statistics fixes that and avoids the loop.
    """
    spectra = np.asarray(spectra)
    # Per-sample statistics; keepdims lets broadcasting do the row-wise work.
    mean = spectra.mean(axis=1, keepdims=True)
    std = spectra.std(axis=1, keepdims=True)
    return (spectra - mean) / std
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
def process_spectra(spectra, method='Abs-SG0'):
    """Preprocess spectra according to a named method.

    Each method combines a Savitzky-Golay step (smoothing or a derivative)
    with an optional SNV normalization:

    - 'Abs-SG0'      : SG smoothing
    - 'Abs-SG0-SNV'  : SG smoothing + SNV
    - 'Abs-SG1'      : SG first derivative
    - 'Abs-SG1-SNV'  : SG first derivative + SNV
    - 'Abs-SG2'      : SG second derivative
    - 'Abs-SG2-SNV'  : SG second derivative + SNV

    Parameters
    ----------
    spectra : np.ndarray, shape (n_samples, n_wavelengths)
    method : str
        One of the method names above.

    Returns
    -------
    np.ndarray
        Processed spectra.

    Raises
    ------
    ValueError
        If ``method`` is not a supported name.
    """
    # Map each method name to (derivative order, apply SNV afterwards).
    recipes = {
        'Abs-SG0': (0, False),
        'Abs-SG0-SNV': (0, True),
        'Abs-SG1': (1, False),
        'Abs-SG1-SNV': (1, True),
        'Abs-SG2': (2, False),
        'Abs-SG2-SNV': (2, True),
    }
    if method not in recipes:
        raise ValueError(f"Unsupported method: {method}")
    deriv, use_snv = recipes[method]
    filtered = apply_sg_filter(spectra, deriv=deriv)
    return apply_snv(filtered) if use_snv else filtered
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
def remove_wavelength_bands(spectra, wavelengths):
    """Drop the 400-499.5 nm and 2450-2499.5 nm bands from the spectra.

    Parameters
    ----------
    spectra : np.ndarray, shape (n_samples, n_wavelengths)
    wavelengths : np.ndarray
        Wavelength value for each spectral column.

    Returns
    -------
    tuple of np.ndarray
        (filtered spectra, matching wavelengths) with both bands removed.
    """
    # Boolean masks for the two excluded bands, then keep everything else.
    low_band = (wavelengths >= 400) & (wavelengths <= 499.5)
    high_band = (wavelengths >= 2450) & (wavelengths <= 2499.5)
    keep = ~(low_band | high_band)

    return spectra[:, keep], wavelengths[keep]
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
def downsample_spectra(spectra, wavelengths, bin_size):
    """Downsample spectra by averaging within fixed-width wavelength bins.

    Parameters
    ----------
    spectra : np.ndarray, shape (n_samples, n_wavelengths)
    wavelengths : np.ndarray
        Wavelength value of each column (assumed ascending).
    bin_size : float
        Bin width in nm (e.g. 5, 10 or 15).

    Returns
    -------
    tuple of np.ndarray
        (binned spectra, bin-center wavelengths).  A bin that contains no
        wavelengths is left at zero in both outputs.
    """
    # Bin edges covering the full wavelength range.
    edges = np.arange(wavelengths[0], wavelengths[-1] + bin_size, bin_size)
    n_bins = len(edges) - 1

    binned = np.zeros((spectra.shape[0], n_bins))
    centers = np.zeros(n_bins)

    # Average every column whose wavelength falls in [lo, hi).
    for idx, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        in_bin = (wavelengths >= lo) & (wavelengths < hi)
        if np.any(in_bin):
            binned[:, idx] = spectra[:, in_bin].mean(axis=1)
            centers[idx] = (lo + hi) / 2.0

    return binned, centers
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def preprocess_with_downsampling(spectra, wavelengths, bin_size=5):
    """Full pipeline: remove the excluded bands, then downsample.

    Parameters
    ----------
    spectra : np.ndarray, shape (n_samples, n_wavelengths)
    wavelengths : np.ndarray
        Wavelength value of each column.
    bin_size : float, optional
        Downsampling bin width in nm (default 5; 10 and 15 also used).

    Returns
    -------
    tuple of np.ndarray
        (processed spectra, corresponding wavelengths).
    """
    # Band removal first, then bin-averaging on what remains.
    trimmed, trimmed_wl = remove_wavelength_bands(spectra, wavelengths)
    return downsample_spectra(trimmed, trimmed_wl, bin_size)
|
158 |
+
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
def plot_processed_spectra_with_range(original_spectra, wavelengths=None):
    """Plot each preprocessing method's mean spectrum with its min/max range.

    One subplot per method (2x3 grid), each showing the average curve on top
    of the shaded min-to-max envelope across samples.

    Parameters
    ----------
    original_spectra : np.ndarray, shape (n_samples, n_wavelengths)
        Raw spectra to preprocess and visualize.
    wavelengths : np.ndarray, optional
        X-axis values; column indices are used when omitted.
    """
    methods = ['Abs-SG0', 'Abs-SG0-SNV', 'Abs-SG1',
               'Abs-SG1-SNV', 'Abs-SG2', 'Abs-SG2-SNV']

    if wavelengths is None:
        wavelengths = np.arange(original_spectra.shape[1])

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))  # 2-row, 3-column grid
    panels = axes.ravel()

    for idx, method in enumerate(methods):
        ax = panels[idx]
        transformed = process_spectra(original_spectra, method)
        avg = np.mean(transformed, axis=0)    # mean spectrum
        lower = np.min(transformed, axis=0)   # envelope bottom
        upper = np.max(transformed, axis=0)   # envelope top

        # Shaded range behind the average curve.
        ax.fill_between(wavelengths, lower, upper, color='skyblue', alpha=0.3, label='Range')
        ax.plot(wavelengths, avg, color='steelblue', label='Average Curve')

        # Panel labels run (a), (b), (c), ...
        ax.set_title(f'({chr(97 + idx)}) {method}', loc='center', fontsize=12)
        ax.set_xlabel('Wavelength/nm', fontsize=10)
        ax.set_ylabel('Absorbance', fontsize=10)
        ax.legend()
        ax.grid(True)

    plt.tight_layout(h_pad=2.5, w_pad=3.0)
    plt.show()
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
# 示例调用
|
208 |
+
# Example driver: load data, visualize preprocessing, and run the pipeline.
if __name__ == '__main__':
    # 1. Load the dataset (spectra plus the 8 soil-property targets).
    file_path = 'LUCAS.2009_abs.csv'
    target_columns = ['pH.in.CaCl2', 'pH.in.H2O', 'OC', 'CaCO3', 'N', 'P', 'K', 'CEC']
    X_train, X_test, y_train, y_test ,wavelengths= load_soil_data(file_path, target_columns)

    # 2. Flatten the (n, 1, L) conv-net input back to 2-D (n, L) for preprocessing.
    X_train_2d = X_train.reshape(X_train.shape[0], -1)

    # 4. Show the spectral preprocessing results on the raw data.
    print("\n=== 光谱预处理结果 ===")
    plot_processed_spectra_with_range(X_train_2d, wavelengths)

    # 5. Remove the excluded bands and downsample at several resolutions.
    print("\n=== 波段移除和降采样结果 ===")
    bin_sizes = [5, 10, 15]  # candidate downsampling bin widths (nm)

    # One figure holding all three downsampling results side by side.
    plt.figure(figsize=(15, 5))

    for i, bin_size in enumerate(bin_sizes):
        # Run band removal + bin averaging at this resolution.
        processed_spectra, processed_wavelengths = preprocess_with_downsampling(
            X_train_2d, wavelengths, bin_size)

        # Report resulting shapes.
        print(f"\n使用 {bin_size}nm 降采样:")
        print(f"处理后的光谱形状: {processed_spectra.shape}")
        print(f"波长数量: {len(processed_wavelengths)}")

        # Plot mean spectrum with a +/- 1 std envelope.
        plt.subplot(1, 3, i+1)
        mean_curve = np.mean(processed_spectra, axis=0)
        std_curve = np.std(processed_spectra, axis=0)

        plt.plot(processed_wavelengths, mean_curve, 'b-', label=f'Mean ({bin_size}nm)')
        plt.fill_between(processed_wavelengths,
                         mean_curve - std_curve,
                         mean_curve + std_curve,
                         color='skyblue', alpha=0.2, label='Standard Deviation Range')
        plt.title(f'Downsampling {bin_size}nm\n(Wavelengths: {len(processed_wavelengths)})')
        plt.xlabel('Wavelength (nm)')
        plt.ylabel('Absorbance')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

    # 6. Demonstrate the complete preprocessing pipeline end to end.
    print("\n=== 完整预处理流程示例 ===")
    # Spectral preprocessing first (SG smoothing + SNV) ...
    processed_spectra = process_spectra(X_train_2d, method='Abs-SG0-SNV')
    # ... then band removal and 10 nm downsampling.
    final_spectra, final_wavelengths = preprocess_with_downsampling(
        processed_spectra, wavelengths, bin_size=10)
    print(f"最终处理后的数据形状: {final_spectra.shape}")
    print(f"最终波长数量: {len(final_wavelengths)}")
|
resnet1d.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
resnet for 1-d signal data, pytorch version
|
3 |
+
|
4 |
+
Shenda Hong, Oct 2019
|
5 |
+
"""
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from collections import Counter
|
9 |
+
from tqdm import tqdm
|
10 |
+
from matplotlib import pyplot as plt
|
11 |
+
from sklearn.metrics import classification_report
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.optim as optim
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torch.utils.data import Dataset, DataLoader
|
18 |
+
|
19 |
+
class MyDataset(Dataset):
    """Wrap paired arrays of signals and labels as a torch ``Dataset``.

    Signals are returned as float tensors, labels as long tensors
    (integer classification-style targets, matching the original code).
    """

    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = torch.tensor(self.data[index], dtype=torch.float)
        target = torch.tensor(self.label[index], dtype=torch.long)
        return (sample, target)
|
29 |
+
|
30 |
+
class MyConv1dPadSame(nn.Module):
    """``nn.Conv1d`` with TensorFlow-style SAME padding.

    The input is zero-padded so the output length equals
    ``ceil(input_length / stride)`` regardless of kernel size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super(MyConv1dPadSame, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.conv = torch.nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            groups=self.groups)

    def forward(self, x):
        # SAME padding: total pad makes out_dim == ceil(in_dim / stride);
        # when the total is odd the extra zero goes on the right.
        length = x.shape[-1]
        target_len = (length + self.stride - 1) // self.stride
        total_pad = max(0, (target_len - 1) * self.stride + self.kernel_size - length)
        left = total_pad // 2
        right = total_pad - left
        padded = F.pad(x, (left, right), "constant", 0)
        return self.conv(padded)
|
63 |
+
|
64 |
+
class MyMaxPool1dPadSame(nn.Module):
    """``nn.MaxPool1d`` with SAME-style zero padding.

    NOTE: the padding amount is computed with ``self.stride = 1``, but
    ``nn.MaxPool1d`` defaults its stride to ``kernel_size``, so this module
    actually downsamples the sequence to ``ceil(length / kernel_size)`` —
    which is how it is used to shrink the residual identity path.
    """

    def __init__(self, kernel_size):
        super(MyMaxPool1dPadSame, self).__init__()
        self.kernel_size = kernel_size
        self.stride = 1
        self.max_pool = torch.nn.MaxPool1d(kernel_size=self.kernel_size)

    def forward(self, x):
        # Pad as for stride-1 SAME: total pad = kernel_size - 1,
        # split left/right (extra zero on the right).
        length = x.shape[-1]
        target_len = (length + self.stride - 1) // self.stride
        total_pad = max(0, (target_len - 1) * self.stride + self.kernel_size - length)
        left = total_pad // 2
        right = total_pad - left
        padded = F.pad(x, (left, right), "constant", 0)
        return self.max_pool(padded)
|
89 |
+
|
90 |
+
class BasicBlock(nn.Module):
    """
    ResNet Basic Block.

    Two convolutions in pre-activation order (BN -> ReLU -> Dropout -> Conv)
    plus an identity shortcut.  When ``downsample`` is true the first conv
    strides to shrink the sequence and the identity is max-pooled to match;
    when ``out_channels`` exceeds ``in_channels`` the identity is zero-padded
    along the channel dimension before the residual add.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super(BasicBlock, self).__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.stride = stride
        self.groups = groups
        self.downsample = downsample
        # Effective stride: only downsampling blocks actually stride;
        # all others keep the sequence length (stride 1).
        if self.downsample:
            self.stride = stride
        else:
            self.stride = 1
        self.is_first_block = is_first_block
        self.use_bn = use_bn
        self.use_do = use_do

        # the first conv (pre-activation: BN/ReLU/Dropout come before it,
        # except in the very first block of the network)
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=self.stride,
            groups=self.groups)

        # the second conv (always stride 1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            groups=self.groups)

        # Pool used to shrink the identity path when downsampling; note
        # MyMaxPool1dPadSame downsamples by a factor of its kernel_size.
        self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

    def forward(self, x):
        # Keep the raw input for the residual shortcut.
        identity = x

        # the first conv (the very first block skips its pre-activation,
        # since the stem conv already applied BN/ReLU)
        out = x
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.relu1(out)
            if self.use_do:
                out = self.do1(out)
        out = self.conv1(out)

        # the second conv
        if self.use_bn:
            out = self.bn2(out)
        out = self.relu2(out)
        if self.use_do:
            out = self.do2(out)
        out = self.conv2(out)

        # if downsample, also downsample identity
        if self.downsample:
            identity = self.max_pool(identity)

        # if expand channel, also pad zeros to identity
        # (transpose so the channel axis is last, pad it, transpose back)
        if self.out_channels != self.in_channels:
            identity = identity.transpose(-1,-2)
            ch1 = (self.out_channels-self.in_channels)//2
            ch2 = self.out_channels-self.in_channels-ch1
            identity = F.pad(identity, (ch1, ch2), "constant", 0)
            identity = identity.transpose(-1,-2)

        # shortcut
        out += identity

        return out
|
173 |
+
|
174 |
+
class ResNet1D(nn.Module):
    """
    1-D ResNet for signal data.

    Input:
        X: (n_samples, n_channel, n_length)
    Output:
        out: (n_samples, n_classes) logits from the final dense layer

    Parameters:
        in_channels: dim of input, the same as n_channel
        base_filters: number of filters in the first several Conv layers; it
            doubles every ``increasefilter_gap`` blocks
        kernel_size: width of kernel
        stride: stride of kernel moving (applied only in downsampling blocks)
        groups: grouped-convolution count; set larger than 1 for ResNeXt-style blocks
        n_block: number of BasicBlocks
        n_classes: number of output units of the final dense layer
        downsample_gap: downsample every this many blocks (2 for base model)
        increasefilter_gap: double the filters every this many blocks (4 for base model)
        use_bn / use_do: toggle batch-norm / dropout inside the blocks
        verbose: print tensor shapes during forward for debugging
    """

    def __init__(self, in_channels, base_filters, kernel_size, stride, groups, n_block, n_classes, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, verbose=False):
        super(ResNet1D, self).__init__()

        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do

        self.downsample_gap = downsample_gap # 2 for base model
        self.increasefilter_gap = increasefilter_gap # 4 for base model

        # first block (stem): plain conv + BN + ReLU, no striding
        self.first_block_conv = MyConv1dPadSame(in_channels=in_channels, out_channels=base_filters, kernel_size=self.kernel_size, stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()
        out_channels = base_filters

        # residual blocks
        self.basicblock_list = nn.ModuleList()
        for i_block in range(self.n_block):
            # is_first_block
            if i_block == 0:
                is_first_block = True
            else:
                is_first_block = False
            # downsample at every self.downsample_gap blocks
            if i_block % self.downsample_gap == 1:
                downsample = True
            else:
                downsample = False
            # in_channels and out_channels
            if is_first_block:
                in_channels = base_filters
                out_channels = in_channels
            else:
                # increase filters at every self.increasefilter_gap blocks
                in_channels = int(base_filters*2**((i_block-1)//self.increasefilter_gap))
                if (i_block % self.increasefilter_gap == 0) and (i_block != 0):
                    out_channels = in_channels * 2
                else:
                    out_channels = in_channels

            tmp_block = BasicBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride = self.stride,
                groups = self.groups,
                downsample=downsample,
                use_bn = self.use_bn,
                use_do = self.use_do,
                is_first_block=is_first_block)
            self.basicblock_list.append(tmp_block)

        # final prediction: BN + ReLU, global average pool, dense
        self.final_bn = nn.BatchNorm1d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)
        # NOTE(review): self.do is constructed but the dropout call below is
        # commented out, so it is currently unused.
        self.do = nn.Dropout(p=0.3)
        self.dense = nn.Linear(out_channels, n_classes)
        # self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the stem, residual blocks, pooling and dense head."""
        out = x

        # first conv
        if self.verbose:
            print('input shape', out.shape)
        out = self.first_block_conv(out)
        if self.verbose:
            print('after first conv', out.shape)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)

        # residual blocks, every block has two conv
        for i_block in range(self.n_block):
            net = self.basicblock_list[i_block]
            if self.verbose:
                print('i_block: {0}, in_channels: {1}, out_channels: {2}, downsample: {3}'.format(i_block, net.in_channels, net.out_channels, net.downsample))
            out = net(out)
            if self.verbose:
                print(out.shape)

        # final prediction: global average pooling over the time axis
        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        out = out.mean(-1)
        if self.verbose:
            print('final pooling', out.shape)
        # out = self.do(out)
        out = self.dense(out)
        if self.verbose:
            print('dense', out.shape)
        # out = self.softmax(out)
        if self.verbose:
            print('softmax', out.shape)

        return out
|
resnet1d_multitask.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.models as models
|
4 |
+
import resnet1d
|
5 |
+
|
6 |
+
__all__ = ['ResNet1D_MultiTask', 'get_model']
|
7 |
+
|
8 |
+
class ResNet1D_MultiTask(resnet1d.ResNet1D):
    """ResNet1D backbone with a multi-task regression head.

    Replaces the single dense layer of ``ResNet1D`` with a small MLP that
    predicts all 8 soil-property values at once.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Feature width coming out of the backbone's global average pooling.
        feat_dim = self.dense.in_features

        # Drop the original single-task prediction layer.
        delattr(self, 'dense')

        # Multi-task head: feat -> feat/2 -> feat/4 -> 8 outputs.
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim//2),
            nn.BatchNorm1d(feat_dim//2),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            nn.Linear(feat_dim//2, feat_dim//4),
            nn.BatchNorm1d(feat_dim//4),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            nn.Linear(feat_dim//4, 8)
        )

    def forward(self, x):
        """Run the backbone, pool over time, and predict the 8 targets."""
        features = x

        # Stem convolution
        features = self.first_block_conv(features)
        if self.use_bn:
            features = self.first_block_bn(features)
        features = self.first_block_relu(features)

        # Residual blocks
        for block_idx in range(self.n_block):
            features = self.basicblock_list[block_idx](features)

        # Feature aggregation
        if self.use_bn:
            features = self.final_bn(features)
        features = self.final_relu(features)
        features = features.mean(-1)  # global average pooling

        # 8-way multi-task prediction
        return self.prediction_head(features)
|
58 |
+
|
59 |
+
|
60 |
+
def get_model(model_type):
    """Build a ResNet1D_MultiTask variant for soil-property prediction.

    Args:
        model_type: 'A' (ResNet18-like, 8 blocks), 'B' (ResNet34-like,
            16 blocks) or 'C' (ResNet50-like, 24 blocks).

    Returns:
        A freshly constructed ResNet1D_MultiTask.

    Raises:
        ValueError: if model_type is not 'A', 'B' or 'C'.
    """
    # Depth (number of residual blocks) per variant. All other
    # hyper-parameters are shared: small base_filters keeps memory
    # usage down, and the head predicts 8 soil properties.
    n_blocks = {'A': 8, 'B': 16, 'C': 24}
    if model_type not in n_blocks:
        raise ValueError("Invalid model type. Choose 'A' for ResNet18, 'B' for ResNet34, or 'C' for ResNet50")

    return ResNet1D_MultiTask(
        in_channels=1,
        base_filters=32,   # reduced to lower memory footprint
        kernel_size=3,
        stride=2,
        groups=1,
        n_block=n_blocks[model_type],
        n_classes=8,
    )
|
93 |
+
|
94 |
+
def print_model_info():
    """Print key information about each model variant (simplified).

    Builds models A/B/C on CPU and prints depth, channel/kernel config,
    parameter counts, and every top-level child module holding more than
    5% of the total parameters.

    NOTE: an earlier version imported ``torchsummary.summary`` behind a
    try/except gate but never called it, refusing to run when the package
    was missing; that dead dependency check has been removed.
    """
    device = torch.device("cpu")

    model_types = ['A', 'B', 'C']
    model_names = {
        'A': 'ResNet18',
        'B': 'ResNet34',
        'C': 'ResNet50'
    }

    # Per-variant configuration info; mirrors the settings used by
    # get_model() (shown here only for display purposes).
    model_configs = {
        'A': {'n_block': 8, 'base_filters': 32, 'kernel_size': 3},
        'B': {'n_block': 16, 'base_filters': 32, 'kernel_size': 3},
        'C': {'n_block': 24, 'base_filters': 32, 'kernel_size': 3}
    }

    print("\n" + "="*50)
    print(f"{'LUCAS土壤光谱分析模型架构':^48}")
    print("="*50)
    print(f"{'输入: (batch_size=15228, channels=1, length=130)':^48}")
    print(f"{'输出: 8个土壤属性预测值':^48}")
    print("-"*50)

    for model_type in model_types:
        model = get_model(model_type).to(device)
        config = model_configs[model_type]
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"\n[Model {model_type}: {model_names[model_type]}]")
        print(f"网络深度: {config['n_block']} blocks")
        print(f"基础通道数: {config['base_filters']}")
        print(f"卷积核大小: {config['kernel_size']}")
        print(f"总参数量: {total_params:,}")
        print(f"可训练参数: {trainable_params:,}")

        # Only report top-level layers that hold >5% of all parameters.
        main_layers = {}
        for name, module in model.named_children():
            params = sum(p.numel() for p in module.parameters())
            if params > 0 and params / total_params > 0.05:
                main_layers[name] = params

        if main_layers:
            print("\n主要层结构:")
            for name, params in main_layers.items():
                print(f"  {name:15}: {params:,} ({params/total_params*100:.1f}%)")
        print("-"*50)
153 |
+
|
154 |
+
# Script entry point: print an architecture summary for every variant.
if __name__ == '__main__':
    print_model_info()
|
results/loss_curves_modelA_20241207_145904.png
ADDED
![]() |
Git LFS Details
|
results/loss_curves_modelB_20241207_151513.png
ADDED
![]() |
Git LFS Details
|
results/loss_curves_modelC_20241207_153102.png
ADDED
![]() |
Git LFS Details
|
results/metrics_modelA_20241207_145904.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model A generated at 2024-12-07 14:59:04
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.2717, R2: -0.2851
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.2939, R2: -0.5375
|
5 |
+
Indicator 3 (OC) - RMSE: 9.3670, R2: 0.0148
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 11.4724, R2: -0.1613
|
7 |
+
Indicator 5 (N) - RMSE: 1.8129, R2: 0.1423
|
8 |
+
Indicator 6 (P) - RMSE: 6.2296, R2: -0.1063
|
9 |
+
Indicator 7 (K) - RMSE: 15.0763, R2: -0.2626
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.6354, R2: -0.0210
|
11 |
+
|
12 |
+
Average Test Loss: 30.3030
|
results/metrics_modelB_20241207_151513.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model B generated at 2024-12-07 15:15:13
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.2917, R2: -0.3678
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.2833, R2: -0.4877
|
5 |
+
Indicator 3 (OC) - RMSE: 8.2344, R2: 0.4116
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 10.9131, R2: 0.0491
|
7 |
+
Indicator 5 (N) - RMSE: 1.6586, R2: 0.3991
|
8 |
+
Indicator 6 (P) - RMSE: 6.1583, R2: -0.0565
|
9 |
+
Indicator 7 (K) - RMSE: 14.5700, R2: -0.1014
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.5120, R2: 0.1108
|
11 |
+
|
12 |
+
Average Test Loss: 27.7304
|
results/metrics_modelC_20241207_153102.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-07 15:31:02
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.1903, R2: 0.0136
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.1664, R2: -0.0155
|
5 |
+
Indicator 3 (OC) - RMSE: 8.0283, R2: 0.4683
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 10.7531, R2: 0.1037
|
7 |
+
Indicator 5 (N) - RMSE: 1.6682, R2: 0.3850
|
8 |
+
Indicator 6 (P) - RMSE: 6.1934, R2: -0.0808
|
9 |
+
Indicator 7 (K) - RMSE: 14.2200, R2: 0.0007
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.4098, R2: 0.2098
|
11 |
+
|
12 |
+
Average Test Loss: 26.7518
|
results/modelC_scatter_CEC_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_CaCO3_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_K_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_N_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_OC_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_P_20241207_153102.png
ADDED
![]() |
results/modelC_scatter_pH.in.CaCl2_20241207_153102.png
ADDED
![]() |
Git LFS Details
|
results/modelC_scatter_pH.in.H2O_20241207_153102.png
ADDED
![]() |
results/training_metrics_modelA_20241207_145904.png
ADDED
![]() |
Git LFS Details
|
results/training_metrics_modelB_20241207_151513.png
ADDED
![]() |
Git LFS Details
|
results/training_metrics_modelC_20241207_153102.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_155007_Abs-SG0.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_160632_Abs-SG0-SNV.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_162218_Abs-SG1.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_163811_Abs-SG1-SNV.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_165344_Abs-SG2.png
ADDED
![]() |
Git LFS Details
|
results1/loss_curves_modelC_20241207_170911_Abs-SG2-SNV.png
ADDED
![]() |
Git LFS Details
|
results1/metrics_modelC_20241210_142058_Abs-SG0.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 14:20:58
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.1329, R2: 0.1905
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.1135, R2: 0.1565
|
5 |
+
Indicator 3 (OC) - RMSE: 8.1022, R2: 0.4485
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 10.7757, R2: 0.0961
|
7 |
+
Indicator 5 (N) - RMSE: 1.6102, R2: 0.4662
|
8 |
+
Indicator 6 (P) - RMSE: 6.1459, R2: -0.0480
|
9 |
+
Indicator 7 (K) - RMSE: 13.9761, R2: 0.0675
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.2445, R2: 0.3522
|
11 |
+
|
12 |
+
Average Test Loss: 27.7911
|
results1/metrics_modelC_20241210_143517_Abs-SG0-SNV.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 14:35:17
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.0763, R2: 0.3406
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.0515, R2: 0.3294
|
5 |
+
Indicator 3 (OC) - RMSE: 6.0624, R2: 0.8271
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 8.5423, R2: 0.6430
|
7 |
+
Indicator 5 (N) - RMSE: 1.4266, R2: 0.6711
|
8 |
+
Indicator 6 (P) - RMSE: 6.1328, R2: -0.0392
|
9 |
+
Indicator 7 (K) - RMSE: 13.8842, R2: 0.0918
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.3248, R2: 0.2857
|
11 |
+
|
12 |
+
Average Test Loss: 21.7375
|
results1/metrics_modelC_20241210_144926_Abs-SG1.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 14:49:26
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.1943, R2: 0.0002
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.1764, R2: -0.0507
|
5 |
+
Indicator 3 (OC) - RMSE: 9.0927, R2: 0.1252
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 11.3249, R2: -0.1028
|
7 |
+
Indicator 5 (N) - RMSE: 1.8572, R2: 0.0553
|
8 |
+
Indicator 6 (P) - RMSE: 6.0900, R2: -0.0105
|
9 |
+
Indicator 7 (K) - RMSE: 14.0982, R2: 0.0345
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.6072, R2: 0.0103
|
11 |
+
|
12 |
+
Average Test Loss: 30.6785
|
results1/metrics_modelC_20241210_150332_Abs-SG1-SNV.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 15:03:32
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 0.9366, R2: 0.6220
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 0.9237, R2: 0.6007
|
5 |
+
Indicator 3 (OC) - RMSE: 6.9911, R2: 0.6943
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 7.3689, R2: 0.8023
|
7 |
+
Indicator 5 (N) - RMSE: 1.6105, R2: 0.4658
|
8 |
+
Indicator 6 (P) - RMSE: 5.9795, R2: 0.0609
|
9 |
+
Indicator 7 (K) - RMSE: 13.8154, R2: 0.1097
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.5810, R2: 0.0387
|
11 |
+
|
12 |
+
Average Test Loss: 21.2949
|
results1/metrics_modelC_20241210_151845_Abs-SG2.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 15:18:45
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 1.5963, R2: -2.1907
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 1.5703, R2: -2.3361
|
5 |
+
Indicator 3 (OC) - RMSE: 9.5270, R2: -0.0543
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 11.4900, R2: -0.1685
|
7 |
+
Indicator 5 (N) - RMSE: 1.9782, R2: -0.2159
|
8 |
+
Indicator 6 (P) - RMSE: 6.6043, R2: -0.3975
|
9 |
+
Indicator 7 (K) - RMSE: 15.9916, R2: -0.5983
|
10 |
+
Indicator 8 (CEC) - RMSE: 4.1587, R2: -0.7483
|
11 |
+
|
12 |
+
Average Test Loss: 35.7320
|
results1/metrics_modelC_20241210_153333_Abs-SG2-SNV.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-10 15:33:33
|
2 |
+
--------------------------------------------------
|
3 |
+
Indicator 1 (pH.in.CaCl2) - RMSE: 0.8667, R2: 0.7227
|
4 |
+
Indicator 2 (pH.in.H2O) - RMSE: 0.8609, R2: 0.6986
|
5 |
+
Indicator 3 (OC) - RMSE: 5.8075, R2: 0.8544
|
6 |
+
Indicator 4 (CaCO3) - RMSE: 6.3981, R2: 0.8877
|
7 |
+
Indicator 5 (N) - RMSE: 1.4128, R2: 0.6836
|
8 |
+
Indicator 6 (P) - RMSE: 5.8376, R2: 0.1469
|
9 |
+
Indicator 7 (K) - RMSE: 12.9567, R2: 0.3112
|
10 |
+
Indicator 8 (CEC) - RMSE: 3.2131, R2: 0.3770
|
11 |
+
|
12 |
+
Average Test Loss: 17.8127
|
results1/training_metrics_modelC_20241207_155007_Abs-SG0.png
ADDED
![]() |
Git LFS Details
|
results1/training_metrics_modelC_20241207_160632_Abs-SG0-SNV.png
ADDED
![]() |
Git LFS Details
|
results1/training_metrics_modelC_20241207_162218_Abs-SG1.png
ADDED
![]() |
Git LFS Details
|
results1/training_metrics_modelC_20241207_163811_Abs-SG1-SNV.png
ADDED
![]() |
Git LFS Details
|
results1/training_metrics_modelC_20241207_165344_Abs-SG2.png
ADDED
![]() |
Git LFS Details
|
results1/training_metrics_modelC_20241207_170911_Abs-SG2-SNV.png
ADDED
![]() |
Git LFS Details
|
results2/loss_curves_modelC_20241213_202047_5.png
ADDED
![]() |
Git LFS Details
|
results2/loss_curves_modelC_20241213_203438_10.png
ADDED
![]() |
Git LFS Details
|
results2/loss_curves_modelC_20241213_204103_15.png
ADDED
![]() |
Git LFS Details
|
results2/metrics_modelC_20241213_202047_5.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-13 20:20:47
|
2 |
+
--------------------------------------------------
|
3 |
+
|
4 |
+
Average Test Loss: 17.4861
|
results2/metrics_modelC_20241213_203438_10.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-13 20:34:38
|
2 |
+
--------------------------------------------------
|
3 |
+
|
4 |
+
Average Test Loss: 17.7109
|
results2/metrics_modelC_20241213_204103_15.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Results for Model C generated at 2024-12-13 20:41:03
|
2 |
+
--------------------------------------------------
|
3 |
+
|
4 |
+
Average Test Loss: 17.6297
|
results2/training_metrics_modelC_20241213_202047_5.png
ADDED
![]() |
Git LFS Details
|
results2/training_metrics_modelC_20241213_203438_10.png
ADDED
![]() |
Git LFS Details
|
results2/training_metrics_modelC_20241213_204103_15.png
ADDED
![]() |
Git LFS Details
|