Spaces:
Sleeping
Sleeping
Upload 8 files
Browse files- README.md +69 -8
- app.py +402 -0
- best_val_r2_model.pth +3 -0
- data.xlsx +0 -0
- feature_names.json +1 -0
- requirements.txt +8 -0
- scaler_X.pkl +3 -0
- scaler_y.pkl +3 -0
README.md
CHANGED
@@ -1,14 +1,75 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license: mit
|
11 |
-
short_description: SoilResistivity
|
12 |
---
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: Soil Resistivity Prediction
|
3 |
+
emoji: 🚗
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: green
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: "1.29.0"
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
|
|
10 |
---
|
11 |
|
12 |
+
|
13 |
+
# Soil Resistivity Prediction Tool
|
14 |
+
|
15 |
+
This application predicts soil resistivity based on various soil properties using a deep learning model with a transformer-based architecture.
|
16 |
+
|
17 |
+
## Features
|
18 |
+
|
19 |
+
- Interactive web interface for inputting soil properties
|
20 |
+
- Real-time prediction of soil resistivity
|
21 |
+
- SHAP-based explanation of feature importance
|
22 |
+
- Visualization of prediction results
|
23 |
+
|
24 |
+
## Installation
|
25 |
+
|
26 |
+
1. Make sure you have Python 3.8+ installed
|
27 |
+
2. Install the required packages:
|
28 |
+
|
29 |
+
```bash
|
30 |
+
pip install -r requirements.txt
|
31 |
+
```
|
32 |
+
|
33 |
+
## Usage
|
34 |
+
|
35 |
+
1. Run the Streamlit application:
|
36 |
+
|
37 |
+
```bash
|
38 |
+
streamlit run app.py
|
39 |
+
```
|
40 |
+
|
41 |
+
2. Open your web browser and navigate to the URL shown in the terminal (typically http://localhost:8501)
|
42 |
+
3. Enter the soil properties in the input fields
|
43 |
+
4. Click the "Predict Resistivity" button to get a prediction
|
44 |
+
5. View the prediction result and feature importance explanation
|
45 |
+
|
46 |
+
## Files
|
47 |
+
|
48 |
+
- `app.py`: The main Streamlit application
|
49 |
+
- `data.xlsx`: Sample data used for scaling and background for SHAP explanations
|
50 |
+
- `best_val_r2_model.pth`: The trained PyTorch model
|
51 |
+
- `scaler_X.pkl`: Scaler for input features
|
52 |
+
- `scaler_y.pkl`: Scaler for target variable
|
53 |
+
- `feature_names.json`: Names of the input features
|
54 |
+
- `temp/`: Directory for temporary files (created automatically)
|
55 |
+
|
56 |
+
## Model Architecture
|
57 |
+
|
58 |
+
The model uses a transformer-based architecture with:
|
59 |
+
- Feature embeddings for each input feature
|
60 |
+
- Multi-head self-attention mechanisms
|
61 |
+
- Both feature-wise and sample-wise attention
|
62 |
+
- Residual connections and layer normalization
|
63 |
+
|
64 |
+
## Troubleshooting
|
65 |
+
|
66 |
+
If you encounter any issues:
|
67 |
+
|
68 |
+
1. Make sure all required files are in the same directory
|
69 |
+
2. Check that all dependencies are installed correctly
|
70 |
+
3. Ensure you have sufficient permissions to create the temp directory
|
71 |
+
4. If you see warnings about feature names, these can be safely ignored
|
72 |
+
|
73 |
+
## License
|
74 |
+
|
75 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
app.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from sklearn.preprocessing import PowerTransformer
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import shap
|
9 |
+
import os
|
10 |
+
import json
|
11 |
+
import pickle
|
12 |
+
import sys
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
# Silence two known-noisy warning families: OpenMP runtime notices and the
# PowerTransformer "has feature names" notice.
warnings.filterwarnings(action="ignore", message=".*OpenMP.*")
warnings.filterwarnings(action="ignore", message=".*has feature names.*")

# Absolute directory that holds this script; bundled assets are resolved
# relative to it.
current_dir = os.path.abspath(os.path.dirname(__file__))

# Scratch directory for generated plot images (no-op when it already exists).
os.makedirs(os.path.join(current_dir, "temp"), exist_ok=True)
|
25 |
+
|
26 |
+
# Define the model classes from 2wayembed.py
|
27 |
+
class FeatureEmbedding(nn.Module):
    """Small MLP (Linear -> ReLU -> Linear) that embeds one feature slice.

    Args:
        input_dim: Size of the last axis of the incoming tensor.
        embedding_dim: Size of the hidden layer and of the returned embedding.
    """

    def __init__(self, input_dim=1, embedding_dim=32):
        super().__init__()
        # The position of each layer inside the Sequential fixes the parameter
        # names in the state dict ("embedding.0.*", "embedding.2.*"), so the
        # order here must not change.
        stages = [
            nn.Linear(input_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
        ]
        self.embedding = nn.Sequential(*stages)

    def forward(self, x):
        """Return the embedding of ``x`` (last axis becomes ``embedding_dim``)."""
        embedded = self.embedding(x)
        return embedded
|
38 |
+
|
39 |
+
class TabularTransformerWithEmbedding(nn.Module):
    """Attention-based regressor over per-feature embeddings of a tabular row.

    Each input column is embedded by its own ``FeatureEmbedding``. The stacked
    embeddings are passed through two independent attention branches
    (``feature_attention`` and ``sample_attention``), mean-pooled over the
    feature axis, concatenated, and mapped to the output by an MLP head.

    NOTE: both branches receive the tensor laid out as
    ``(num_features, batch_size, embedding_dim)`` and ``nn.MultiheadAttention``
    is built with its default ``batch_first=False``, so *both* branches attend
    across the features of one row. Despite its name, ``sample_attention``
    does not mix information between rows of the batch. The computation is
    kept exactly as it was (the earlier code applied ``permute(1, 0, 2)``
    twice, which cancels out) -- presumably the shipped checkpoint was trained
    with it, so changing it would invalidate the weights.

    Args:
        num_features: Number of input columns (one embedding module each).
        embedding_dim: Width of every embedding / attention vector.
        output_dim: Width of the model output.
        num_attention_heads: Heads used by both attention modules.
        num_attention_layers: How many times each branch re-applies its
            (weight-shared) attention + residual + layer-norm step.
            Defaults to 4, the value that used to be hard-coded.
    """

    def __init__(self, num_features=6, embedding_dim=32, output_dim=1,
                 num_attention_heads=4, num_attention_layers=4):
        super().__init__()
        self.num_features = num_features
        self.embedding_dim = embedding_dim
        self.num_attention_layers = num_attention_layers

        # One independent embedding network per input column.
        self.feature_embeddings = nn.ModuleList([
            FeatureEmbedding(input_dim=1, embedding_dim=embedding_dim)
            for _ in range(num_features)
        ])

        # First attention branch and its layer norm.
        self.feature_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=num_attention_heads
        )
        self.feature_norm = nn.LayerNorm(embedding_dim)

        # Second attention branch (separate weights, same input layout).
        self.sample_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=num_attention_heads
        )
        self.sample_norm = nn.LayerNorm(embedding_dim)

        # Fuses the two pooled branch outputs back to embedding_dim.
        self.combine_layer = nn.Linear(embedding_dim * 2, embedding_dim)
        self.combine_activation = nn.ReLU()

        # Regression head.
        self.output_layers = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_dim)
        )

    def _attend(self, tokens, attention, norm):
        """Apply self-attention + residual + layer norm ``num_attention_layers`` times.

        The same ``attention`` and ``norm`` modules (i.e. shared weights) are
        reused on every iteration.
        """
        for _ in range(self.num_attention_layers):
            attended, _weights = attention(tokens, tokens, tokens)
            tokens = norm(attended + tokens)
        return tokens

    def forward(self, x):
        """Compute predictions for a batch.

        Args:
            x: Tensor of shape ``(batch_size, num_features)``.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        # Embed each column separately: slice keeps a trailing axis of size 1.
        embedded_features = [
            self.feature_embeddings[i](x[:, i:i + 1])
            for i in range(self.num_features)
        ]

        # Shape: (num_features, batch_size, embedding_dim)
        embeddings = torch.stack(embedded_features)

        # Both branches start from the same stacked embeddings.
        feature_attended = self._attend(
            embeddings, self.feature_attention, self.feature_norm
        )
        sample_attended = self._attend(
            embeddings, self.sample_attention, self.sample_norm
        )

        # Put batch first, then average over the feature axis.
        # Shape after pooling: (batch_size, embedding_dim)
        feature_pooled = feature_attended.permute(1, 0, 2).mean(dim=1)
        sample_pooled = sample_attended.permute(1, 0, 2).mean(dim=1)

        # Shape: (batch_size, embedding_dim * 2)
        combined = torch.cat([feature_pooled, sample_pooled], dim=1)

        # Project back to embedding_dim and apply the non-linearity.
        combined = self.combine_activation(self.combine_layer(combined))

        # Shape: (batch_size, output_dim)
        return self.output_layers(combined)
|
137 |
+
|
138 |
+
class ShapModel:
    """Callable adapter: array-like / DataFrame in, NumPy array out.

    Wraps a torch model so it can be invoked like a plain function on
    tabular data without gradient tracking.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, X):
        """Run the wrapped model on ``X`` and return its output as a NumPy array.

        Args:
            X: A pandas DataFrame (its ``.values`` are used) or anything
                ``torch.FloatTensor`` accepts directly.
        """
        if isinstance(X, pd.DataFrame):
            raw = X.values
        else:
            raw = X
        with torch.no_grad():
            result = self.model(torch.FloatTensor(raw))
        return result.numpy()
|
147 |
+
|
148 |
+
@st.cache_resource
def load_model_and_scalers():
    """Load the trained model, the fitted scalers and the reference data.

    All files are resolved relative to ``current_dir``. The first six columns
    of ``data.xlsx`` are treated as features and the seventh as the target.

    Returns:
        Tuple ``(model, scaler_X, scaler_y, feature_names, X)`` where ``model``
        is in eval mode, ``feature_names`` is the list of feature column names
        and ``X`` is the feature DataFrame read from ``data.xlsx``.

    Raises:
        Whatever ``pd.read_excel`` / ``torch.load`` raise when the data or
        checkpoint file is missing or unreadable.

    NOTE(security): the scaler files are read with ``pickle.load``, which can
    execute arbitrary code. Only load files that ship with the application.
    """
    # Set paths relative to the current file
    model_path = os.path.join(current_dir, "best_val_r2_model.pth")
    data_path = os.path.join(current_dir, "data.xlsx")
    scaler_x_path = os.path.join(current_dir, "scaler_X.pkl")
    scaler_y_path = os.path.join(current_dir, "scaler_y.pkl")

    # Load data
    df = pd.read_excel(data_path)
    X = df.iloc[:, 0:6]  # First 6 columns for features
    y = df.iloc[:, 6]  # 7th column for target (Y)
    feature_names = X.columns.tolist()

    # Initialize model
    model = TabularTransformerWithEmbedding(
        num_features=6, embedding_dim=32, output_dim=1, num_attention_heads=4
    )

    # map_location="cpu" lets a checkpoint that was saved from a GPU load on a
    # CPU-only host instead of failing while restoring CUDA tensors.
    state_dict = torch.load(model_path, map_location="cpu")

    # Remove feature_weights if present in the state dict but not in the model
    if 'feature_weights' in state_dict and not hasattr(model, 'feature_weights'):
        del state_dict['feature_weights']

    # strict=False tolerates key mismatches, but a mismatch means some layers
    # keep their random initialisation, so surface it instead of hiding it.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        st.warning(
            "Model checkpoint does not fully match the model definition. "
            f"Missing keys: {list(load_result.missing_keys)}; "
            f"unexpected keys: {list(load_result.unexpected_keys)}"
        )
    model.eval()

    # Load saved scalers with error handling
    try:
        with open(scaler_x_path, 'rb') as f:
            scaler_X = pickle.load(f)
        with open(scaler_y_path, 'rb') as f:
            scaler_y = pickle.load(f)
    except (FileNotFoundError, pickle.UnpicklingError) as e:
        # If saved scalers not found or unpickling error, create new ones
        st.warning(f"Issue with saved scalers: {str(e)}. Creating new scalers.")
        scaler_X = PowerTransformer(method='yeo-johnson')
        scaler_y = PowerTransformer(method='yeo-johnson')

        # Fit scalers
        scaler_X.fit(X)
        scaler_y.fit(y.values.reshape(-1, 1))

        # Persisting the refit scalers is only a convenience: the in-memory
        # objects are already usable, so a failed write must not abort loading.
        try:
            with open(scaler_x_path, 'wb') as f:
                pickle.dump(scaler_X, f)
            with open(scaler_y_path, 'wb') as f:
                pickle.dump(scaler_y, f)
        except OSError as save_error:
            st.warning(f"Could not save new scalers: {save_error}")

    # Save feature names for later use (best effort; nothing below needs the file)
    try:
        with open(os.path.join(current_dir, 'feature_names.json'), 'w') as f:
            json.dump(feature_names, f)
    except OSError as save_error:
        st.warning(f"Could not save feature names: {save_error}")

    return model, scaler_X, scaler_y, feature_names, X
|
204 |
+
|
205 |
+
def explain_prediction(model, input_df, X_background, scaler_X, scaler_y, feature_names):
    """Generate a SHAP waterfall explanation for a single prediction.

    A ``shap.KernelExplainer`` is built on a 10-centroid k-means summary of
    ``X_background`` and evaluated on ``input_df``. The waterfall plot is
    written to ``temp/shap_explanation.png`` under ``current_dir``.

    Args:
        model: Torch model taking scaled features and returning scaled targets.
        input_df: DataFrame whose first row is the sample to explain
            (unscaled feature values).
        X_background: DataFrame of unscaled reference samples.
        scaler_X: Fitted transformer applied to features before the model.
        scaler_y: Fitted transformer whose ``inverse_transform`` maps model
            output back to the original target scale.
        feature_names: Feature labels, one per SHAP value.

    Returns:
        ``(expected_value, shap_values, temp_file)``: the explainer's expected
        value, a 1-D array with one SHAP value per feature, and the path of
        the saved plot. On any failure an error is shown in the UI and
        ``(0, zeros, None)`` is returned instead of raising.
    """
    try:
        # Prediction function for SHAP: unscaled features in, unscaled target out.
        def predict_fn(X):
            # Work on a plain array to avoid the scaler's feature-names warning.
            X_array = X.values if isinstance(X, pd.DataFrame) else np.asarray(X)
            try:
                X_tensor = torch.FloatTensor(scaler_X.transform(X_array))
                with torch.no_grad():
                    scaled_pred = model(X_tensor).numpy()
                return scaler_y.inverse_transform(scaled_pred)
            except Exception as e:
                st.error(f"Error in prediction function: {str(e)}")
                # Return zeros as fallback
                return np.zeros((X_array.shape[0], 1))

        # Summarise the background data so KernelExplainer stays fast.
        background = shap.kmeans(X_background.values, 10)
        explainer = shap.KernelExplainer(predict_fn, background)

        # Get SHAP values for the input (as an array, see note above).
        shap_values = explainer.shap_values(input_df.values)

        # Depending on the SHAP version the result is a list with one array per
        # output or a single array; either way a single-row, single-output
        # explanation flattens to one value per feature.
        if isinstance(shap_values, list):
            shap_values = shap_values[0]
        shap_values = np.asarray(shap_values).reshape(-1)
        if shap_values.shape[0] != len(feature_names):
            raise ValueError(
                f"Expected {len(feature_names)} SHAP values, got {shap_values.shape[0]}"
            )

        # expected_value may be a scalar, a 0-d array or a 1-element array.
        base_value = float(np.ravel(explainer.expected_value)[0])

        # Create waterfall plot; the figure is always closed, even on failure,
        # so repeated errors do not accumulate open matplotlib figures.
        fig = plt.figure(figsize=(10, 6))
        try:
            shap.plots.waterfall(
                shap.Explanation(
                    values=shap_values,
                    base_values=base_value,
                    data=input_df.iloc[0].values,
                    feature_names=feature_names
                ),
                show=False
            )
            plt.title('Feature Contributions to Prediction')
            plt.tight_layout()

            # Save the plot to a temporary file
            temp_dir = os.path.join(current_dir, 'temp')
            os.makedirs(temp_dir, exist_ok=True)
            temp_file = os.path.join(temp_dir, 'shap_explanation.png')
            plt.savefig(temp_file, dpi=300, bbox_inches='tight')
        finally:
            # Close whatever figure is current, then ours (a no-op if they
            # are the same figure).
            plt.close()
            plt.close(fig)

        return explainer.expected_value, shap_values, temp_file
    except Exception as e:
        st.error(f"Error generating explanation: {str(e)}")
        return 0, np.zeros(len(feature_names)), None
|
270 |
+
|
271 |
+
def model_predict(model, input_df, scaler_X, scaler_y):
    """Predict the target for the rows of ``input_df`` in original units.

    Args:
        model: Torch model mapping scaled features to scaled targets.
        input_df: DataFrame of unscaled feature values, one row per sample.
        scaler_X: Fitted transformer applied to the features.
        scaler_y: Fitted transformer whose ``inverse_transform`` restores the
            original target scale.

    Returns:
        1-D NumPy array with one prediction per input row. If anything fails,
        the error is shown in the UI and ``np.array([nan])`` is returned.
        NaN is used rather than 0.0 so a failure can never be mistaken for a
        real prediction by the caller or the reader.
    """
    try:
        # Pass a plain array to the scaler to avoid its feature-names warning.
        X_scaled = scaler_X.transform(input_df.values)
        X_tensor = torch.FloatTensor(X_scaled)

        # Inference only: no gradients needed.
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()

        # Inverse transform to get original scale prediction
        prediction = scaler_y.inverse_transform(scaled_pred)
        return prediction.flatten()
    except Exception as e:
        st.error(f"Error making prediction: {str(e)}")
        return np.array([np.nan])
|
290 |
+
|
291 |
+
# Set page title and description
|
292 |
+
st.set_page_config(
    page_title="Soil Resistivity Predictor",
    page_icon="🧪",
    layout="wide"
)

st.title("Soil Resistivity Prediction Tool")
st.markdown("""
This application predicts soil resistivity based on various soil properties using a deep learning model.
Enter the soil properties below and click the 'Predict Resistivity' button to get a prediction.
""")

# Ensure temp directory exists
temp_dir = os.path.join(current_dir, 'temp')
os.makedirs(temp_dir, exist_ok=True)

# On the first run of a session, clear plot images left over from earlier runs.
if 'first_run' not in st.session_state:
    st.session_state.first_run = True
    for file in os.listdir(temp_dir):
        if file.endswith('.png'):
            try:
                os.remove(os.path.join(temp_dir, file))
            except OSError:
                # Best-effort cleanup: a file that is already gone or locked
                # must not stop the app from starting.
                pass

# Load model and scalers. Only the loading step sits inside the try, so the
# "error loading" message is shown for loading failures and nothing else.
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()
except Exception as e:
    st.error(f"""
    Error loading the model and data. Please make sure:
    1. The model file 'best_val_r2_model.pth' exists in the application directory
    2. The data file 'data.xlsx' exists in the application directory
    3. The scaler files 'scaler_X.pkl' and 'scaler_y.pkl' exist in the application directory
    4. All required packages are installed

    Error details: {str(e)}
    """)

    # Show detailed error information
    st.exception(e)
else:
    # Create input fields for features
    st.subheader("Input Features")

    # Create two columns for input fields
    col1, col2 = st.columns(2)

    # Dictionary to store input values
    input_values = {}

    # First half of the features goes to the left column, the rest to the right.
    for i, feature in enumerate(feature_names):
        # Observed range of this feature in the reference data
        min_val = float(X[feature].min())
        max_val = float(X[feature].max())

        # Widen the allowed range by 10% of each bound's magnitude. Using the
        # magnitude keeps this correct for negative bounds (plain min * 0.9
        # would *shrink* the range when min is negative).
        lower_bound = min_val - 0.1 * abs(min_val)
        upper_bound = max_val + 0.1 * abs(max_val)

        with col1 if i < len(feature_names)//2 else col2:
            # Use session state to maintain values between reruns
            if f'input_{feature}' not in st.session_state:
                st.session_state[f'input_{feature}'] = float(X[feature].mean())

            input_values[feature] = st.number_input(
                f"{feature}",
                min_value=lower_bound,
                max_value=upper_bound,
                value=st.session_state[f'input_{feature}'],
                key=f'input_widget_{feature}',
                help=f"Range: {min_val:.2f} to {max_val:.2f}"
            )
            # Update session state with current value
            st.session_state[f'input_{feature}'] = input_values[feature]

    # Add predict button
    if st.button("Predict Resistivity", type="primary"):
        try:
            # Create input DataFrame
            input_df = pd.DataFrame([input_values])

            # Make prediction
            with st.spinner("Calculating prediction..."):
                prediction = model_predict(model, input_df, scaler_X, scaler_y)

            # Display prediction
            st.subheader("Prediction Result")
            st.markdown(f"### Predicted Resistivity: {prediction[0]:.2f} Ω·m")

            # Calculate and display SHAP values
            with st.spinner("Generating explanation..."):
                st.subheader("Feature Importance Explanation")

                # Get SHAP values using the training data as background
                expected_value, shap_values, temp_file = explain_prediction(
                    model, input_df, X, scaler_X, scaler_y, feature_names
                )

                # Display the waterfall plot
                if temp_file and os.path.exists(temp_file):
                    try:
                        st.image(temp_file)
                    except Exception as img_error:
                        st.error(f"Error displaying SHAP explanation image: {str(img_error)}")
                else:
                    st.warning("Could not generate SHAP explanation plot.")
        except Exception as pred_error:
            st.error(f"Error during prediction process: {str(pred_error)}")
            st.exception(pred_error)
|
best_val_r2_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59e01841fb6c3e9bdf5a1434c3c698be3de52f261cbad3596ad56bb36eccb369
|
3 |
+
size 142400
|
data.xlsx
ADDED
Binary file (20 kB). View file
|
|
feature_names.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
["\u03c1d (g/cm3)", "w (%)", "F200 (%)", "Gs", "LL (%)", "PL (%)"]
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.32.0
|
2 |
+
pandas==2.1.4
|
3 |
+
numpy==1.26.3
|
4 |
+
torch==2.1.2
|
5 |
+
scikit-learn==1.5.2
|
6 |
+
matplotlib==3.8.2
|
7 |
+
shap==0.44.1
|
8 |
+
openpyxl==3.1.2
|
scaler_X.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0aec13e01d17d0f40db347469ebf2575ac7d60b7ca07bb8cb46b945bd2a18389
|
3 |
+
size 822
|
scaler_y.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1d17984f930fbc47f453582fb009bfef64d47fd563bf53b4310e1f162ba11e2
|
3 |
+
size 662
|