Spaces:
Sleeping
Sleeping
salohnana2018
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from zlib import crc32
|
2 |
+
import struct
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import joblib
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
# Define top features
|
13 |
+
top_features = set([
|
14 |
+
'pm.vbatMV', 'stateEstimate.z', 'motor.m3', 'stateEstimate.yaw', 'yaw_cos',
|
15 |
+
'motor.m2', 'stateEstimate.y', 'stateEstimate.x', 'motor.m1', 'theta',
|
16 |
+
'motor.m4', 'position_magnitude', 'combined_orientation', 'pwm.m3_pwm',
|
17 |
+
'stateEstimate.roll', 'phi', 'pwm.m2_pwm', 'roll_cos', 'vx_cosine',
|
18 |
+
'stateEstimate.vx', 'velocity_magnitude', 'stateEstimate.vy', 'pwm.m4_pwm',
|
19 |
+
'stateEstimate.vz', 'pwm.m1_pwm'
|
20 |
+
])
|
21 |
+
|
22 |
+
# Load the median values from the CSV once
|
23 |
+
feature_medians = pd.read_csv("model/feature_medians.csv")
|
24 |
+
medians_dict = feature_medians.set_index('Feature')['Median'].to_dict()
|
25 |
+
|
26 |
+
# Load the label encoder, scaler, and saved feature names
|
27 |
+
label_encoder = joblib.load('model/label_encoder.pkl')
|
28 |
+
scaler = joblib.load('model/scaler.pkl')
|
29 |
+
saved_feature_names = joblib.load('model/feature_names.pkl')
|
30 |
+
|
31 |
+
# Define the EnhancedFaultDetectionNN model
|
32 |
+
class EnhancedFaultDetectionNN(nn.Module):
|
33 |
+
def __init__(self, input_size, output_size, dropout_prob=0.08):
|
34 |
+
super(EnhancedFaultDetectionNN, self).__init__()
|
35 |
+
|
36 |
+
self.fc1 = nn.Linear(input_size, 1024)
|
37 |
+
self.bn1 = nn.BatchNorm1d(1024)
|
38 |
+
self.fc2 = nn.Linear(1024, 512)
|
39 |
+
self.bn2 = nn.BatchNorm1d(512)
|
40 |
+
self.fc3 = nn.Linear(512, 256)
|
41 |
+
self.bn3 = nn.BatchNorm1d(256)
|
42 |
+
self.fc4 = nn.Linear(256, output_size)
|
43 |
+
self.dropout = nn.Dropout(dropout_prob)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = F.relu(self.bn1(self.fc1(x)))
|
47 |
+
x = self.dropout(x)
|
48 |
+
x = F.relu(self.bn2(self.fc2(x)))
|
49 |
+
x = self.dropout(x)
|
50 |
+
x = F.relu(self.bn3(self.fc3(x)))
|
51 |
+
x = self.dropout(x)
|
52 |
+
x = self.fc4(x)
|
53 |
+
return x
|
54 |
+
|
55 |
+
# Load the PyTorch model
|
56 |
+
model_path = 'model/best_model_without_oversampling128.pth'
|
57 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
58 |
+
input_size = len(saved_feature_names)
|
59 |
+
output_size = len(label_encoder.classes_)
|
60 |
+
|
61 |
+
model = EnhancedFaultDetectionNN(input_size, output_size).to(device)
|
62 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
63 |
+
model.eval()
|
64 |
+
|
65 |
+
# Mapping of fault types to corresponding images and comments
|
66 |
+
defect_image_map = {
|
67 |
+
|
68 |
+
"Extra Weight": {
|
69 |
+
"image": "images/weight.png",
|
70 |
+
"comment": "A weight added near the M3 motor causes lift imbalance."
|
71 |
+
},
|
72 |
+
"Propeller Cut": {
|
73 |
+
"image": "images/propeller_cut.png",
|
74 |
+
"comment": "A cut on the M2 propeller reduces thrust and causes instability."
|
75 |
+
},
|
76 |
+
"Tape on Propeller": {
|
77 |
+
"image": "images/tape.png",
|
78 |
+
"comment": "Tape on the M3 propeller leads to imbalance, drag, and vibrations, reducing stability."
|
79 |
+
},
|
80 |
+
"Normal Flight": {
|
81 |
+
"image": "images/normal_flight.png",
|
82 |
+
"comment": "The quadcopter operates normally with balanced thrust and stability."
|
83 |
+
},
|
84 |
+
}
|
85 |
+
|
86 |
+
# List of log files corresponding to the fault types
|
87 |
+
log_files = [
|
88 |
+
"Logs_Samples/add_weight_W1_near_M3_E9_log04",
|
89 |
+
"Logs_Samples/cut_M2_0.5мм_46.5мм_E9_log02",
|
90 |
+
"Logs_Samples/tape_on_propeller_M3_E9_log01",
|
91 |
+
"Logs_Samples/normal_flight_E8_log03"
|
92 |
+
]
|
93 |
+
|
94 |
+
# Mapping simplified labels to their corresponding folder names
|
95 |
+
LabelsMap = {
|
96 |
+
"Extra Weight": "add_weight_W1_near_M3",
|
97 |
+
"Propeller Cut": "cut_M2_0.5мм_46.5мм",
|
98 |
+
"Tape on Propeller": "tape_on_propeller_M3",
|
99 |
+
"Normal Flight": "normal_flight"
|
100 |
+
}
|
101 |
+
|
102 |
+
# Function to retrieve the log file path using LabelsMap and log_files
|
103 |
+
def get_log_file_path(label_key):
|
104 |
+
label_value = LabelsMap[label_key]
|
105 |
+
for log_file in log_files:
|
106 |
+
if label_value in log_file:
|
107 |
+
return log_file
|
108 |
+
return None # Return None if no matching file is found
|
109 |
+
|
110 |
+
def get_name(data, idx):
|
111 |
+
end_idx = idx
|
112 |
+
while data[end_idx] != 0:
|
113 |
+
end_idx += 1
|
114 |
+
return data[idx:end_idx].decode("utf-8"), end_idx + 1
|
115 |
+
|
116 |
+
def cfusdlog_decode(file):
|
117 |
+
data = file.read()
|
118 |
+
|
119 |
+
if data[0] != 0xBC:
|
120 |
+
raise gr.Error("Invalid file format: Magic header not found.")
|
121 |
+
|
122 |
+
crc = crc32(data[0:-4])
|
123 |
+
expected_crc, = struct.unpack('I', data[-4:])
|
124 |
+
if crc != expected_crc:
|
125 |
+
raise gr.Error("File integrity check failed: CRC mismatch.")
|
126 |
+
|
127 |
+
version, num_event_types = struct.unpack('HH', data[1:5])
|
128 |
+
if version not in [1, 2]:
|
129 |
+
raise gr.Error(f"Unsupported log file version: {version}")
|
130 |
+
|
131 |
+
result = {}
|
132 |
+
event_by_id = {}
|
133 |
+
idx = 5
|
134 |
+
|
135 |
+
for _ in range(num_event_types):
|
136 |
+
event_id, = struct.unpack('H', data[idx:idx+2])
|
137 |
+
idx += 2
|
138 |
+
event_name, idx = get_name(data, idx)
|
139 |
+
result[event_name] = {'timestamp': []}
|
140 |
+
num_variables, = struct.unpack('H', data[idx:idx+2])
|
141 |
+
idx += 2
|
142 |
+
fmt_str = "<"
|
143 |
+
variables = []
|
144 |
+
for _ in range(num_variables):
|
145 |
+
var_name_and_type, idx = get_name(data, idx)
|
146 |
+
var_name = var_name_and_type[:-3]
|
147 |
+
var_type = var_name_and_type[-2]
|
148 |
+
result[event_name][var_name] = []
|
149 |
+
fmt_str += var_type
|
150 |
+
variables.append(var_name)
|
151 |
+
event_by_id[event_id] = {
|
152 |
+
'name': event_name,
|
153 |
+
'fmt_str': fmt_str,
|
154 |
+
'num_bytes': struct.calcsize(fmt_str),
|
155 |
+
'variables': variables,
|
156 |
+
}
|
157 |
+
|
158 |
+
while idx < len(data) - 4:
|
159 |
+
if version == 1:
|
160 |
+
event_id, timestamp = struct.unpack('<HI', data[idx:idx+6])
|
161 |
+
idx += 6
|
162 |
+
elif version == 2:
|
163 |
+
event_id, timestamp = struct.unpack('<HQ', data[idx:idx+10])
|
164 |
+
timestamp /= 1000.0
|
165 |
+
idx += 10
|
166 |
+
event = event_by_id[event_id]
|
167 |
+
event_data = struct.unpack(event['fmt_str'], data[idx:idx+event['num_bytes']])
|
168 |
+
idx += event['num_bytes']
|
169 |
+
for var, value in zip(event['variables'], event_data):
|
170 |
+
result[event['name']][var].append(value)
|
171 |
+
result[event['name']]['timestamp'].append(timestamp)
|
172 |
+
|
173 |
+
for event_name, event_data in result.items():
|
174 |
+
for var_name, var_data in event_data.items():
|
175 |
+
result[event_name][var_name] = np.array(var_data)
|
176 |
+
|
177 |
+
return {k: v for k, v in result.items() if len(v['timestamp']) > 0} # Ensure that only non-empty timestamps are kept
|
178 |
+
|
179 |
+
def fix_time(log_data):
|
180 |
+
try:
|
181 |
+
timestamps = log_data["timestamp"]
|
182 |
+
if len(timestamps) == 0:
|
183 |
+
raise gr.Error("Timestamp data is empty.")
|
184 |
+
first_value = timestamps[0]
|
185 |
+
log_data["timestamp"] = [t - first_value for t in timestamps]
|
186 |
+
except KeyError:
|
187 |
+
raise gr.Error("Timestamp key not found in the log data.")
|
188 |
+
except Exception as e:
|
189 |
+
raise gr.Error(f"Failed to adjust timestamps: {e}")
|
190 |
+
|
191 |
+
def process_log_file(file):
|
192 |
+
try:
|
193 |
+
log_data = cfusdlog_decode(file)
|
194 |
+
log_data = log_data.get('fixedFrequency', {})
|
195 |
+
if not log_data:
|
196 |
+
raise gr.Warning(f"No 'fixedFrequency' data found in the log file")
|
197 |
+
|
198 |
+
fix_time(log_data)
|
199 |
+
parent_dir_name = os.path.basename(os.path.dirname(file.name))
|
200 |
+
log_data["true_label"] = [parent_dir_name] * len(log_data.get("timestamp", []))
|
201 |
+
|
202 |
+
df = pd.DataFrame(log_data)
|
203 |
+
return df
|
204 |
+
except Exception as e:
|
205 |
+
raise gr.Error(f"Failed to process log file: {e}")
|
206 |
+
|
207 |
+
def preprocess_single_data_point(single_data_point):
|
208 |
+
try:
|
209 |
+
if 'timestamp' in single_data_point.columns:
|
210 |
+
single_data_point.drop(columns=["timestamp"], inplace=True)
|
211 |
+
|
212 |
+
single_data_point.fillna(medians_dict, inplace=True)
|
213 |
+
|
214 |
+
state_x, state_y, state_z = single_data_point[['stateEstimate.x', 'stateEstimate.y', 'stateEstimate.z']].values.T
|
215 |
+
single_data_point['r'] = np.sqrt(state_x**2 + state_y**2 + state_z**2)
|
216 |
+
single_data_point['theta'] = np.arccos(np.clip(single_data_point['stateEstimate.z'] / single_data_point['r'], -1.0, 1.0)) # Clip to avoid invalid values
|
217 |
+
single_data_point['phi'] = np.arctan2(single_data_point['stateEstimate.y'], single_data_point['stateEstimate.x'])
|
218 |
+
single_data_point['position_magnitude'] = single_data_point['r']
|
219 |
+
|
220 |
+
velocity_x, velocity_y, velocity_z = single_data_point[['stateEstimate.vx', 'stateEstimate.vy', 'stateEstimate.vz']].values.T
|
221 |
+
single_data_point['velocity_magnitude'] = np.sqrt(velocity_x**2 + velocity_y**2 + velocity_z**2)
|
222 |
+
single_data_point['vx_cosine'] = np.divide(velocity_x, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_x), where=single_data_point['velocity_magnitude']!=0)
|
223 |
+
single_data_point['vy_cosine'] = np.divide(velocity_y, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_y), where=single_data_point['velocity_magnitude']!=0)
|
224 |
+
single_data_point['vz_cosine'] = np.divide(velocity_z, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_z), where=single_data_point['velocity_magnitude']!=0)
|
225 |
+
|
226 |
+
roll, yaw = single_data_point[['stateEstimate.roll', 'stateEstimate.yaw']].values.T
|
227 |
+
single_data_point['combined_orientation'] = roll + yaw
|
228 |
+
single_data_point['roll_sin'] = np.sin(np.radians(roll))
|
229 |
+
single_data_point['roll_cos'] = np.cos(np.radians(roll))
|
230 |
+
single_data_point['yaw_sin'] = np.sin(np.radians(yaw))
|
231 |
+
single_data_point['yaw_cos'] = np.cos(np.radians(yaw))
|
232 |
+
|
233 |
+
features_to_keep = list(top_features.intersection(single_data_point.columns))
|
234 |
+
return single_data_point[features_to_keep + ['true_label']]
|
235 |
+
except Exception as e:
|
236 |
+
raise gr.Error(f"Failed to preprocess single data point: {e}")
|
237 |
+
|
238 |
+
def predict(file_path):
|
239 |
+
try:
|
240 |
+
with open(file_path, 'rb') as file:
|
241 |
+
log_df = process_log_file(file)
|
242 |
+
if log_df is not None:
|
243 |
+
single_data_point = log_df.sample(1)
|
244 |
+
preprocessed_data_point = preprocess_single_data_point(single_data_point)
|
245 |
+
if preprocessed_data_point is not None:
|
246 |
+
X = preprocessed_data_point.drop(columns=['true_label'])
|
247 |
+
y = preprocessed_data_point['true_label']
|
248 |
+
|
249 |
+
X_ordered = X[saved_feature_names]
|
250 |
+
X_scaled = scaler.transform(X_ordered)
|
251 |
+
X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
|
252 |
+
|
253 |
+
with torch.no_grad():
|
254 |
+
logits = model(X_tensor)
|
255 |
+
probabilities = F.softmax(logits, dim=1)
|
256 |
+
confidence_scores, predicted_classes = torch.max(probabilities, dim=1)
|
257 |
+
|
258 |
+
predicted_labels = label_encoder.inverse_transform(predicted_classes.cpu().numpy())
|
259 |
+
confidence_scores = confidence_scores.cpu().numpy()
|
260 |
+
|
261 |
+
predicted_label_value = predicted_labels[0]
|
262 |
+
predicted_label_key = [k for k, v in LabelsMap.items() if v == predicted_label_value][0]
|
263 |
+
label_confidence_pairs = f"{predicted_label_key}: {predicted_label_value} (Confidence: {confidence_scores[0]:.4f})"
|
264 |
+
|
265 |
+
# Retrieve the corresponding image and comment using the key name
|
266 |
+
defect_info = defect_image_map.get(predicted_label_key, {"image": "images/Placeholder.png", "comment": "No information available."})
|
267 |
+
image_path = defect_info["image"]
|
268 |
+
comment = defect_info["comment"]
|
269 |
+
|
270 |
+
return image_path, f"{label_confidence_pairs}\n\nComment: {comment}"
|
271 |
+
else:
|
272 |
+
raise gr.Warning("Log file processing returned no data.")
|
273 |
+
except Exception as e:
|
274 |
+
raise gr.Error(f"Failed to process file: {e}")
|
275 |
+
|
276 |
+
return None, "Failed to process file"
|
277 |
+
|
278 |
+
# Gradio interface
|
279 |
+
with gr.Blocks() as demo:
|
280 |
+
gr.Markdown("## Fault Detection in Nano-Quadcopter")
|
281 |
+
gr.Markdown("This interface classifies faults in a nano-quadcopter using a deep neural network model.")
|
282 |
+
|
283 |
+
with gr.Row():
|
284 |
+
with gr.Column():
|
285 |
+
example_dropdown = gr.Dropdown(
|
286 |
+
choices=["Extra Weight", "Propeller Cut", "Tape on Propeller", "Normal Flight"],
|
287 |
+
label="Select Fault Type"
|
288 |
+
)
|
289 |
+
submit_btn = gr.Button("Classify")
|
290 |
+
|
291 |
+
with gr.Column():
|
292 |
+
image_output = gr.Image(type="filepath", label="Corresponding Image")
|
293 |
+
label_output = gr.Textbox(label="Predicted Label and Confidence Score")
|
294 |
+
|
295 |
+
def classify_example(example):
|
296 |
+
try:
|
297 |
+
file_path = get_log_file_path(example)
|
298 |
+
if file_path:
|
299 |
+
file_path = file_path
|
300 |
+
image_path, label_and_comment = predict(file_path)
|
301 |
+
return image_path, label_and_comment
|
302 |
+
else:
|
303 |
+
raise gr.Error("No matching log file found.")
|
304 |
+
except KeyError as e:
|
305 |
+
raise gr.Error(f"Error: {e}")
|
306 |
+
|
307 |
+
submit_btn.click(
|
308 |
+
fn=classify_example,
|
309 |
+
inputs=[example_dropdown],
|
310 |
+
outputs=[image_output, label_output],
|
311 |
+
)
|
312 |
+
|
313 |
+
# Launch the app
|
314 |
+
if __name__ == "__main__":
|
315 |
+
demo.launch(share=True, debug=True)
|