taarhissian committed on
Commit
5f57049
·
verified ·
1 Parent(s): 82f696d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -2,56 +2,85 @@ import gradio as gr
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.ensemble import IsolationForest
5
- from sklearn.svm import OneClassSVM
 
 
6
  import matplotlib.pyplot as plt
7
  from PIL import Image
 
8
 
9
- # Initialize models
10
- models = {
11
- "Isolation Forest": IsolationForest(contamination=0.1),
12
- "One-Class SVM": OneClassSVM(nu=0.1)
13
- }
14
 
15
- def detect_anomalies(data_input, file_input, algorithm):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  try:
17
  # Load data from file if provided
18
- if file_input is not None:
19
- file_ext = file_input.name.split('.')[-1]
20
- if file_ext in ['xls', 'xlsx']:
21
- data = pd.read_excel(file_input.name).values
22
- elif file_ext == 'csv':
23
- data = pd.read_csv(file_input.name).values
24
  else:
25
- return {"Error": f"Unsupported file type: {file_ext}"}, None
26
  else:
27
  # Use manual input if no file provided
28
- data = np.array([float(x) for x in data_input.split(",")]).reshape(-1, 1)
29
 
30
- # Select the chosen model
31
- model = models[algorithm]
32
 
33
  # Fit model and predict anomalies
34
  model.fit(data)
35
  predictions = model.predict(data)
36
 
37
  # Extract anomalies
38
- anomalies = data[predictions == -1].flatten()
39
 
40
  # Create plot
41
  plt.figure(figsize=(8, 6))
42
- plt.plot(data, 'bo', label="Data Points")
43
- plt.plot(np.where(predictions == -1)[0], anomalies, 'ro', label="Anomalies")
44
- plt.title(f"Anomaly Detection using {algorithm}")
45
  plt.xlabel("Data Index")
46
  plt.ylabel("Data Value")
47
  plt.legend()
48
  plt.tight_layout()
49
 
50
- # Save plot to image
51
- plt.savefig("temp_plot.png")
52
- plot_image = Image.open("temp_plot.png")
 
 
 
 
 
 
 
53
 
54
- return {"Anomalies": anomalies.tolist(), "Total Anomalies": len(anomalies)}, plot_image
55
 
56
  except Exception as e:
57
  return {"Error": str(e)}, None
@@ -60,16 +89,15 @@ def detect_anomalies(data_input, file_input, algorithm):
60
  iface = gr.Interface(
61
  fn=detect_anomalies,
62
  inputs=[
63
- gr.Textbox(label="Enter numbers separated by commas"),
64
- gr.File(label="Upload Data File"),
65
- gr.Radio(choices=["Isolation Forest", "One-Class SVM"], label="Select Algorithm")
66
  ],
67
  outputs=[
68
  "json",
69
  gr.Image(type="pil", label="Anomaly Plot")
70
  ],
71
- title="Anomaly Detection Application",
72
- description="Upload a dataset or enter numbers to detect anomalies using the selected algorithm."
73
  )
74
 
75
  # Launch the interface
 
2
  import numpy as np
3
  import pandas as pd
4
  from sklearn.ensemble import IsolationForest
5
+ from sklearn.preprocessing import LabelEncoder
6
+ import matplotlib
7
+ matplotlib.use('Agg') # Use 'Agg' backend for compatibility
8
  import matplotlib.pyplot as plt
9
  from PIL import Image
10
+ import os
11
 
12
# Module-level Isolation Forest used by detect_anomalies. The fixed seed keeps
# results reproducible between runs; contamination is the share of samples the
# model is told to expect as outliers.
model = IsolationForest(random_state=42, contamination=0.1)
 
 
 
14
 
15
def preprocess_data(data):
    """Return an all-numeric copy of ``data`` suitable for model fitting.

    Processing order:
      1. Object columns whose non-missing cells all parse as numbers are
         converted to numeric, so text such as "100" keeps its magnitude
         instead of being turned into a category rank.
      2. Remaining object columns are label-encoded (cells are stringified
         first, so missing cells become their own category).
      3. Every column is coerced to numeric; unparseable cells become NaN.
      4. NaNs are filled with the column mean. This runs last so that gaps
         created by the earlier steps are filled as well.

    Args:
        data: pandas DataFrame with the raw input. It is not modified.

    Returns:
        A ``(processed, label_encoders)`` tuple, where ``processed`` is the
        converted DataFrame and ``label_encoders`` maps each label-encoded
        column name to the fitted ``LabelEncoder``.
    """
    def _clean(value):
        # Trim text cells and treat blank ones as missing.
        if isinstance(value, str):
            return value.strip() or None
        return value

    data = data.copy()
    label_encoders = {}

    for column in data.select_dtypes(include=['object']).columns:
        cleaned = data[column].map(_clean)
        as_numbers = pd.to_numeric(cleaned, errors='coerce')

        # Numeric text (e.g. rows typed into the textbox) must stay numeric.
        if as_numbers[cleaned.notna()].notna().all():
            data[column] = as_numbers
            continue

        # Genuinely categorical column: encode each distinct label.
        le = LabelEncoder()
        data[column] = le.fit_transform(data[column].astype(str))
        label_encoders[column] = le

    # Ensure all data is numeric
    data = data.apply(pd.to_numeric, errors='coerce')

    # Handle missing values by filling with the mean of the column
    data = data.fillna(data.mean(numeric_only=True))

    return data, label_encoders
36
+
37
def detect_anomalies(data_input, file_input):
    """Detect anomalous rows with the module-level model and plot the result.

    Args:
        data_input: Manual data from the textbox: comma-separated values,
            one row per line. Only used when no file is supplied.
        file_input: Uploaded file, either an object exposing ``.name`` or a
            plain path string; falsy when nothing was uploaded. Only
            ``.csv`` and ``.json`` files are accepted.

    Returns:
        A ``(summary, image)`` tuple. On success ``summary`` holds the
        anomalous rows under ``"Anomalies"`` and their count under
        ``"Total Anomalies"``, and ``image`` is a PIL image of the plot.
        On any failure ``summary`` is ``{"Error": message}`` and ``image``
        is ``None``.
    """
    import io  # local import: only needed for the in-memory plot buffer

    try:
        # Load data from file if provided
        if file_input:
            # Accept both upload objects exposing ``.name`` and path strings.
            file_path = str(getattr(file_input, "name", file_input))
            file_extension = os.path.splitext(file_path)[-1].lower()
            if file_extension == ".csv":
                data = pd.read_csv(file_path)
            elif file_extension == ".json":
                data = pd.read_json(file_path)
            else:
                return {"Error": f"Unsupported file type: {file_extension}"}, None
        else:
            # Use manual input if no file provided
            if data_input is None or not data_input.strip():
                return {"Error": "No data provided: enter values or upload a file."}, None
            # Skip blank lines so they do not turn into empty rows.
            rows = [
                line.split(",")
                for line in data_input.splitlines()
                if line.strip()
            ]
            data = pd.DataFrame(rows)

        # Preprocess the data (the fitted encoders are not needed here)
        data, _ = preprocess_data(data)

        # Fit model and predict anomalies
        model.fit(data)
        predictions = model.predict(data)

        # Extract anomalies (the model marks outliers with -1)
        anomalies = data[predictions == -1]

        # Create plot; the figure is always closed, even if drawing fails
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.plot(data.index, data, 'o', label="Data Points")
            ax.plot(anomalies.index, anomalies, 'ro', label="Anomalies")
            ax.set_title("Anomaly Detection Results")
            ax.set_xlabel("Data Index")
            ax.set_ylabel("Data Value")

            # A multi-column frame draws one line per column; keep a single
            # legend entry per label so the legend is not repeated.
            handles, labels = ax.get_legend_handles_labels()
            unique_entries = dict(zip(labels, handles))
            ax.legend(unique_entries.values(), unique_entries.keys())
            fig.tight_layout()

            # Render into memory instead of a shared, fixed-name temp file.
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            plt.close(fig)

        # Convert plot to PIL Image for Gradio; load() forces the decode now
        # rather than lazily on first access.
        buffer.seek(0)
        plot_image = Image.open(buffer)
        plot_image.load()

        return {
            "Anomalies": anomalies.to_dict(orient='records'),
            "Total Anomalies": len(anomalies),
        }, plot_image

    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return {"Error": str(e)}, None
 
# Gradio front end: a textbox and a file upload feed detect_anomalies, whose
# JSON summary and PIL plot are shown as the two outputs.
iface = gr.Interface(
    title="Anomaly Detector",
    description="Enter a series of numbers or upload a data file to detect anomalies using Isolation Forest.",
    fn=detect_anomalies,
    inputs=[
        gr.Textbox(label="Enter data (comma-separated values, one row per line)"),
        gr.File(label="Upload Data File"),
    ],
    outputs=[
        "json",
        gr.Image(type="pil", label="Anomaly Plot"),
    ],
)
102
 
103
  # Launch the interface