|
import torch |
|
import numpy as np |
|
import shap |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
from data_load import load_soil_data |
|
from data_processing import preprocess_with_downsampling, process_spectra |
|
from resnet1d_multitask import get_model |
|
|
|
|
|
# Global matplotlib text settings: SimHei as the sans-serif font, and the
# Unicode minus glyph disabled for axis tick labels.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
def load_model_and_data(data_path='LUCAS.2009_abs.csv',
                        model_path='code/models/set.pth',
                        model_name='C'):
    """Load the trained model and the preprocessed test spectra.

    The defaults reproduce the previously hard-coded configuration, so a
    call without arguments behaves exactly as before.

    Args:
        data_path: Path handed to ``load_soil_data`` (defaults to the LUCAS CSV).
        model_path: File containing the saved ``state_dict`` of the model.
        model_name: Key handed to ``get_model`` to select the architecture.

    Returns:
        A ``(model, X_test, wavelengths)`` tuple: the model switched to eval
        mode with weights mapped to the CPU, the test spectra reshaped to
        ``(n_samples, 1, n_bands)``, and the wavelength axis returned by
        ``preprocess_with_downsampling`` for the test split.
    """
    target_columns = ['pH.in.CaCl2', 'pH.in.H2O', 'OC', 'CaCO3', 'N', 'P', 'K', 'CEC']

    # Only the spectra and the wavelength axis are used here; the targets are ignored.
    X_train, X_test, _y_train, _y_test, wavelengths = load_soil_data(data_path, target_columns)

    # Drop singleton axes so the preprocessing helpers receive 2-D spectra.
    X_train, X_test = X_train.squeeze(), X_test.squeeze()

    X_train = process_spectra(X_train, 'Abs-SG1-SNV')
    X_test = process_spectra(X_test, 'Abs-SG1-SNV')

    # NOTE(review): the training split is not returned. It is still passed
    # through the helpers in case they keep state; drop these calls once the
    # helpers are confirmed to be stateless.
    X_train, _ = preprocess_with_downsampling(X_train, wavelengths, 15)
    X_test, test_wavelengths = preprocess_with_downsampling(X_test, wavelengths, 15)

    # Insert the channel axis: (n_samples, n_bands) -> (n_samples, 1, n_bands).
    X_test = X_test.reshape(X_test.shape[0], 1, X_test.shape[1])

    device = torch.device('cpu')
    model = get_model(model_name)
    # SECURITY: torch.load unpickles the file, so only load checkpoints from a
    # trusted source. NOTE(review): consider weights_only=True if the installed
    # PyTorch version supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    return model, X_test, test_wavelengths
|
|
|
def explain_predictions(model, X_test, wavelengths, n_background=50, n_samples=100,
                        kernel_nsamples=100):
    """Compute absolute SHAP values with ``shap.KernelExplainer``.

    KernelExplainer is model-agnostic and runs on the CPU. Both the
    background set and the explained rows are taken from the start of
    ``X_test``, so they overlap.

    Args:
        model: Callable network. It is fed ``(rows, 1, bands)`` float tensors;
            if it returns a tuple, only the first element is explained.
        X_test: Test spectra, indexable by row slice.
        wavelengths: Unused; kept so existing callers keep working.
        n_background: Number of leading rows used as the background set.
        n_samples: Number of leading rows whose predictions are explained.
        kernel_nsamples: Value passed as ``nsamples`` to
            ``explainer.shap_values``. The default matches the previously
            hard-coded value.

    Returns:
        A 2-D array of absolute SHAP values, one row per explained sample and
        one column per band. For multi-output models the absolute values are
        averaged over the outputs.

    Raises:
        ValueError: If the selected background or test slice is empty.
    """

    def model_predict(x):
        """Run the model on a NumPy batch and return its predictions as NumPy."""
        with torch.no_grad():
            batch = torch.as_tensor(x, dtype=torch.float32)
            if batch.dim() == 2:
                # KernelExplainer supplies (rows, bands); add the channel axis.
                batch = batch.reshape(batch.shape[0], 1, -1)
            output = model(batch)
            if isinstance(output, tuple):
                output = output[0]
            return output.numpy()

    def _as_2d(rows):
        """Flatten each row to 1-D while always keeping the row axis."""
        rows = np.asarray(rows)
        if rows.shape[0] == 0:
            raise ValueError("X_test slice is empty; nothing to explain")
        # reshape, unlike squeeze(), does not drop the row axis for a single row.
        return rows.reshape(rows.shape[0], -1)

    background_data = _as_2d(X_test[:n_background])
    test_data = _as_2d(X_test[:n_samples])

    print("创建SHAP解释器...")
    explainer = shap.KernelExplainer(model_predict, background_data)

    print("计算SHAP值(这可能需要几分钟时间)...")
    shap_values = explainer.shap_values(test_data, nsamples=kernel_nsamples)

    if isinstance(shap_values, list):
        # List form: one (rows, bands) array per model output.
        return np.mean([np.abs(sv) for sv in shap_values], axis=0)

    abs_values = np.abs(np.asarray(shap_values))
    if abs_values.ndim == 3:
        # Stacked form: a single array with the outputs on the trailing axis.
        # NOTE(review): confirm the axis order against the installed shap version.
        return abs_values.mean(axis=-1)
    return abs_values
|
|
|
def plot_top10_wavelengths(shap_values, wavelengths, top_n=10):
    """Save a horizontal bar chart of the bands with the largest mean |SHAP|.

    The figure is written to ``shap_top10_wavelengths.png``.

    Args:
        shap_values: 2-D array of SHAP values, one row per sample and one
            column per band.
        wavelengths: Wavelength of each band, aligned with the columns of
            ``shap_values``.
        top_n: Maximum number of bands to show. Fewer bars are drawn when
            there are fewer bands than ``top_n``.
    """
    mean_shap = np.mean(np.abs(shap_values), axis=0)
    wavelengths = np.asarray(wavelengths)  # fancy indexing below needs an array

    # Clamp to the number of bands so the bar positions and values stay the same length.
    n_bars = max(0, min(top_n, mean_shap.shape[0]))

    # Descending order of importance, truncated to the requested count.
    top_idx = np.argsort(mean_shap)[::-1][:n_bars]
    top_wavelengths = wavelengths[top_idx]
    top_values = mean_shap[top_idx]

    positions = range(n_bars)
    plt.figure(figsize=(8, 6))
    plt.barh(positions, top_values, color='skyblue')
    plt.yticks(positions, [f'{w:.1f}' for w in top_wavelengths])
    plt.xlabel('mean(|SHAP value|) (average impact on model output magnitude)')
    plt.title(f'Top {n_bars} Wavelengths by SHAP Value')
    # Put the most important band at the top of the chart.
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig('shap_top10_wavelengths.png', dpi=300, bbox_inches='tight')
    plt.close()
|
|
|
|
|
def plot_shap_summary(shap_values, wavelengths, features=None):
    """Save a SHAP beeswarm (dot summary) plot to ``shap_summary_plot.png``.

    Args:
        shap_values: 2-D array of SHAP values, one row per sample and one
            column per band.
        wavelengths: Wavelength of each band; used only to label the features.
        features: Optional feature matrix with the same shape as
            ``shap_values``, forwarded to ``shap.summary_plot``. Leave it as
            ``None`` when the matrix is not available.
    """
    band_labels = [f'{w:.1f}' for w in wavelengths]

    plt.figure(figsize=(10, 8))
    # The wavelength vector is 1-D and only names the columns. It must not be
    # passed as `features`, which has to be a (samples, bands) matrix.
    shap.summary_plot(
        shap_values,
        features=features,
        feature_names=band_labels,
        plot_type='dot',
        show=False,
    )
    plt.tight_layout()
    plt.savefig('shap_summary_plot.png', dpi=300, bbox_inches='tight')
    plt.close()
|
|
|
def main():
    """Run the analysis end to end: load, explain, then save both plots."""
    # Stage 1: trained model plus preprocessed test spectra.
    print("加载模型和数据...")
    net, spectra, band_axis = load_model_and_data()

    # Stage 2: absolute SHAP values for the leading test rows.
    print("正在计算SHAP值...")
    importance = explain_predictions(net, spectra, band_axis)

    # Stage 3: beeswarm plot first, then the top-10 bar chart.
    print("正在生成图表...")
    for draw in (plot_shap_summary, plot_top10_wavelengths):
        draw(importance, band_axis)

    print("分析完成!图表已保存为 'shap_summary_plot.png' 和 'shap_top10_wavelengths.png'")
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|