salohnana2018 commited on
Commit
9b66001
·
verified ·
1 Parent(s): 22489d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +315 -0
app.py CHANGED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from zlib import crc32
2
+ import struct
3
+ import gradio as gr
4
+ import os
5
+ import pandas as pd
6
+ import numpy as np
7
+ import joblib
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ # Define top features
13
+ top_features = set([
14
+ 'pm.vbatMV', 'stateEstimate.z', 'motor.m3', 'stateEstimate.yaw', 'yaw_cos',
15
+ 'motor.m2', 'stateEstimate.y', 'stateEstimate.x', 'motor.m1', 'theta',
16
+ 'motor.m4', 'position_magnitude', 'combined_orientation', 'pwm.m3_pwm',
17
+ 'stateEstimate.roll', 'phi', 'pwm.m2_pwm', 'roll_cos', 'vx_cosine',
18
+ 'stateEstimate.vx', 'velocity_magnitude', 'stateEstimate.vy', 'pwm.m4_pwm',
19
+ 'stateEstimate.vz', 'pwm.m1_pwm'
20
+ ])
21
+
22
+ # Load the median values from the CSV once
23
+ feature_medians = pd.read_csv("model/feature_medians.csv")
24
+ medians_dict = feature_medians.set_index('Feature')['Median'].to_dict()
25
+
26
+ # Load the label encoder, scaler, and saved feature names
27
+ label_encoder = joblib.load('model/label_encoder.pkl')
28
+ scaler = joblib.load('model/scaler.pkl')
29
+ saved_feature_names = joblib.load('model/feature_names.pkl')
30
+
31
+ # Define the EnhancedFaultDetectionNN model
32
+ class EnhancedFaultDetectionNN(nn.Module):
33
+ def __init__(self, input_size, output_size, dropout_prob=0.08):
34
+ super(EnhancedFaultDetectionNN, self).__init__()
35
+
36
+ self.fc1 = nn.Linear(input_size, 1024)
37
+ self.bn1 = nn.BatchNorm1d(1024)
38
+ self.fc2 = nn.Linear(1024, 512)
39
+ self.bn2 = nn.BatchNorm1d(512)
40
+ self.fc3 = nn.Linear(512, 256)
41
+ self.bn3 = nn.BatchNorm1d(256)
42
+ self.fc4 = nn.Linear(256, output_size)
43
+ self.dropout = nn.Dropout(dropout_prob)
44
+
45
+ def forward(self, x):
46
+ x = F.relu(self.bn1(self.fc1(x)))
47
+ x = self.dropout(x)
48
+ x = F.relu(self.bn2(self.fc2(x)))
49
+ x = self.dropout(x)
50
+ x = F.relu(self.bn3(self.fc3(x)))
51
+ x = self.dropout(x)
52
+ x = self.fc4(x)
53
+ return x
54
+
55
+ # Load the PyTorch model
56
+ model_path = 'model/best_model_without_oversampling128.pth'
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ input_size = len(saved_feature_names)
59
+ output_size = len(label_encoder.classes_)
60
+
61
+ model = EnhancedFaultDetectionNN(input_size, output_size).to(device)
62
+ model.load_state_dict(torch.load(model_path, map_location=device))
63
+ model.eval()
64
+
65
+ # Mapping of fault types to corresponding images and comments
66
+ defect_image_map = {
67
+
68
+ "Extra Weight": {
69
+ "image": "images/weight.png",
70
+ "comment": "A weight added near the M3 motor causes lift imbalance."
71
+ },
72
+ "Propeller Cut": {
73
+ "image": "images/propeller_cut.png",
74
+ "comment": "A cut on the M2 propeller reduces thrust and causes instability."
75
+ },
76
+ "Tape on Propeller": {
77
+ "image": "images/tape.png",
78
+ "comment": "Tape on the M3 propeller leads to imbalance, drag, and vibrations, reducing stability."
79
+ },
80
+ "Normal Flight": {
81
+ "image": "images/normal_flight.png",
82
+ "comment": "The quadcopter operates normally with balanced thrust and stability."
83
+ },
84
+ }
85
+
86
+ # List of log files corresponding to the fault types
87
+ log_files = [
88
+ "Logs_Samples/add_weight_W1_near_M3_E9_log04",
89
+ "Logs_Samples/cut_M2_0.5мм_46.5мм_E9_log02",
90
+ "Logs_Samples/tape_on_propeller_M3_E9_log01",
91
+ "Logs_Samples/normal_flight_E8_log03"
92
+ ]
93
+
94
+ # Mapping simplified labels to their corresponding folder names
95
+ LabelsMap = {
96
+ "Extra Weight": "add_weight_W1_near_M3",
97
+ "Propeller Cut": "cut_M2_0.5мм_46.5мм",
98
+ "Tape on Propeller": "tape_on_propeller_M3",
99
+ "Normal Flight": "normal_flight"
100
+ }
101
+
102
+ # Function to retrieve the log file path using LabelsMap and log_files
103
+ def get_log_file_path(label_key):
104
+ label_value = LabelsMap[label_key]
105
+ for log_file in log_files:
106
+ if label_value in log_file:
107
+ return log_file
108
+ return None # Return None if no matching file is found
109
+
110
+ def get_name(data, idx):
111
+ end_idx = idx
112
+ while data[end_idx] != 0:
113
+ end_idx += 1
114
+ return data[idx:end_idx].decode("utf-8"), end_idx + 1
115
+
116
+ def cfusdlog_decode(file):
117
+ data = file.read()
118
+
119
+ if data[0] != 0xBC:
120
+ raise gr.Error("Invalid file format: Magic header not found.")
121
+
122
+ crc = crc32(data[0:-4])
123
+ expected_crc, = struct.unpack('I', data[-4:])
124
+ if crc != expected_crc:
125
+ raise gr.Error("File integrity check failed: CRC mismatch.")
126
+
127
+ version, num_event_types = struct.unpack('HH', data[1:5])
128
+ if version not in [1, 2]:
129
+ raise gr.Error(f"Unsupported log file version: {version}")
130
+
131
+ result = {}
132
+ event_by_id = {}
133
+ idx = 5
134
+
135
+ for _ in range(num_event_types):
136
+ event_id, = struct.unpack('H', data[idx:idx+2])
137
+ idx += 2
138
+ event_name, idx = get_name(data, idx)
139
+ result[event_name] = {'timestamp': []}
140
+ num_variables, = struct.unpack('H', data[idx:idx+2])
141
+ idx += 2
142
+ fmt_str = "<"
143
+ variables = []
144
+ for _ in range(num_variables):
145
+ var_name_and_type, idx = get_name(data, idx)
146
+ var_name = var_name_and_type[:-3]
147
+ var_type = var_name_and_type[-2]
148
+ result[event_name][var_name] = []
149
+ fmt_str += var_type
150
+ variables.append(var_name)
151
+ event_by_id[event_id] = {
152
+ 'name': event_name,
153
+ 'fmt_str': fmt_str,
154
+ 'num_bytes': struct.calcsize(fmt_str),
155
+ 'variables': variables,
156
+ }
157
+
158
+ while idx < len(data) - 4:
159
+ if version == 1:
160
+ event_id, timestamp = struct.unpack('<HI', data[idx:idx+6])
161
+ idx += 6
162
+ elif version == 2:
163
+ event_id, timestamp = struct.unpack('<HQ', data[idx:idx+10])
164
+ timestamp /= 1000.0
165
+ idx += 10
166
+ event = event_by_id[event_id]
167
+ event_data = struct.unpack(event['fmt_str'], data[idx:idx+event['num_bytes']])
168
+ idx += event['num_bytes']
169
+ for var, value in zip(event['variables'], event_data):
170
+ result[event['name']][var].append(value)
171
+ result[event['name']]['timestamp'].append(timestamp)
172
+
173
+ for event_name, event_data in result.items():
174
+ for var_name, var_data in event_data.items():
175
+ result[event_name][var_name] = np.array(var_data)
176
+
177
+ return {k: v for k, v in result.items() if len(v['timestamp']) > 0} # Ensure that only non-empty timestamps are kept
178
+
179
+ def fix_time(log_data):
180
+ try:
181
+ timestamps = log_data["timestamp"]
182
+ if len(timestamps) == 0:
183
+ raise gr.Error("Timestamp data is empty.")
184
+ first_value = timestamps[0]
185
+ log_data["timestamp"] = [t - first_value for t in timestamps]
186
+ except KeyError:
187
+ raise gr.Error("Timestamp key not found in the log data.")
188
+ except Exception as e:
189
+ raise gr.Error(f"Failed to adjust timestamps: {e}")
190
+
191
+ def process_log_file(file):
192
+ try:
193
+ log_data = cfusdlog_decode(file)
194
+ log_data = log_data.get('fixedFrequency', {})
195
+ if not log_data:
196
+ raise gr.Warning(f"No 'fixedFrequency' data found in the log file")
197
+
198
+ fix_time(log_data)
199
+ parent_dir_name = os.path.basename(os.path.dirname(file.name))
200
+ log_data["true_label"] = [parent_dir_name] * len(log_data.get("timestamp", []))
201
+
202
+ df = pd.DataFrame(log_data)
203
+ return df
204
+ except Exception as e:
205
+ raise gr.Error(f"Failed to process log file: {e}")
206
+
207
+ def preprocess_single_data_point(single_data_point):
208
+ try:
209
+ if 'timestamp' in single_data_point.columns:
210
+ single_data_point.drop(columns=["timestamp"], inplace=True)
211
+
212
+ single_data_point.fillna(medians_dict, inplace=True)
213
+
214
+ state_x, state_y, state_z = single_data_point[['stateEstimate.x', 'stateEstimate.y', 'stateEstimate.z']].values.T
215
+ single_data_point['r'] = np.sqrt(state_x**2 + state_y**2 + state_z**2)
216
+ single_data_point['theta'] = np.arccos(np.clip(single_data_point['stateEstimate.z'] / single_data_point['r'], -1.0, 1.0)) # Clip to avoid invalid values
217
+ single_data_point['phi'] = np.arctan2(single_data_point['stateEstimate.y'], single_data_point['stateEstimate.x'])
218
+ single_data_point['position_magnitude'] = single_data_point['r']
219
+
220
+ velocity_x, velocity_y, velocity_z = single_data_point[['stateEstimate.vx', 'stateEstimate.vy', 'stateEstimate.vz']].values.T
221
+ single_data_point['velocity_magnitude'] = np.sqrt(velocity_x**2 + velocity_y**2 + velocity_z**2)
222
+ single_data_point['vx_cosine'] = np.divide(velocity_x, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_x), where=single_data_point['velocity_magnitude']!=0)
223
+ single_data_point['vy_cosine'] = np.divide(velocity_y, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_y), where=single_data_point['velocity_magnitude']!=0)
224
+ single_data_point['vz_cosine'] = np.divide(velocity_z, single_data_point['velocity_magnitude'], out=np.zeros_like(velocity_z), where=single_data_point['velocity_magnitude']!=0)
225
+
226
+ roll, yaw = single_data_point[['stateEstimate.roll', 'stateEstimate.yaw']].values.T
227
+ single_data_point['combined_orientation'] = roll + yaw
228
+ single_data_point['roll_sin'] = np.sin(np.radians(roll))
229
+ single_data_point['roll_cos'] = np.cos(np.radians(roll))
230
+ single_data_point['yaw_sin'] = np.sin(np.radians(yaw))
231
+ single_data_point['yaw_cos'] = np.cos(np.radians(yaw))
232
+
233
+ features_to_keep = list(top_features.intersection(single_data_point.columns))
234
+ return single_data_point[features_to_keep + ['true_label']]
235
+ except Exception as e:
236
+ raise gr.Error(f"Failed to preprocess single data point: {e}")
237
+
238
+ def predict(file_path):
239
+ try:
240
+ with open(file_path, 'rb') as file:
241
+ log_df = process_log_file(file)
242
+ if log_df is not None:
243
+ single_data_point = log_df.sample(1)
244
+ preprocessed_data_point = preprocess_single_data_point(single_data_point)
245
+ if preprocessed_data_point is not None:
246
+ X = preprocessed_data_point.drop(columns=['true_label'])
247
+ y = preprocessed_data_point['true_label']
248
+
249
+ X_ordered = X[saved_feature_names]
250
+ X_scaled = scaler.transform(X_ordered)
251
+ X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(device)
252
+
253
+ with torch.no_grad():
254
+ logits = model(X_tensor)
255
+ probabilities = F.softmax(logits, dim=1)
256
+ confidence_scores, predicted_classes = torch.max(probabilities, dim=1)
257
+
258
+ predicted_labels = label_encoder.inverse_transform(predicted_classes.cpu().numpy())
259
+ confidence_scores = confidence_scores.cpu().numpy()
260
+
261
+ predicted_label_value = predicted_labels[0]
262
+ predicted_label_key = [k for k, v in LabelsMap.items() if v == predicted_label_value][0]
263
+ label_confidence_pairs = f"{predicted_label_key}: {predicted_label_value} (Confidence: {confidence_scores[0]:.4f})"
264
+
265
+ # Retrieve the corresponding image and comment using the key name
266
+ defect_info = defect_image_map.get(predicted_label_key, {"image": "images/Placeholder.png", "comment": "No information available."})
267
+ image_path = defect_info["image"]
268
+ comment = defect_info["comment"]
269
+
270
+ return image_path, f"{label_confidence_pairs}\n\nComment: {comment}"
271
+ else:
272
+ raise gr.Warning("Log file processing returned no data.")
273
+ except Exception as e:
274
+ raise gr.Error(f"Failed to process file: {e}")
275
+
276
+ return None, "Failed to process file"
277
+
278
+ # Gradio interface
279
+ with gr.Blocks() as demo:
280
+ gr.Markdown("## Fault Detection in Nano-Quadcopter")
281
+ gr.Markdown("This interface classifies faults in a nano-quadcopter using a deep neural network model.")
282
+
283
+ with gr.Row():
284
+ with gr.Column():
285
+ example_dropdown = gr.Dropdown(
286
+ choices=["Extra Weight", "Propeller Cut", "Tape on Propeller", "Normal Flight"],
287
+ label="Select Fault Type"
288
+ )
289
+ submit_btn = gr.Button("Classify")
290
+
291
+ with gr.Column():
292
+ image_output = gr.Image(type="filepath", label="Corresponding Image")
293
+ label_output = gr.Textbox(label="Predicted Label and Confidence Score")
294
+
295
+ def classify_example(example):
296
+ try:
297
+ file_path = get_log_file_path(example)
298
+ if file_path:
299
+ file_path = file_path
300
+ image_path, label_and_comment = predict(file_path)
301
+ return image_path, label_and_comment
302
+ else:
303
+ raise gr.Error("No matching log file found.")
304
+ except KeyError as e:
305
+ raise gr.Error(f"Error: {e}")
306
+
307
+ submit_btn.click(
308
+ fn=classify_example,
309
+ inputs=[example_dropdown],
310
+ outputs=[image_output, label_output],
311
+ )
312
+
313
+ # Launch the app
314
+ if __name__ == "__main__":
315
+ demo.launch(share=True, debug=True)