taarhissian's picture
Update app.py
5f57049 verified
raw
history blame
3.27 kB
import io
import os

import gradio as gr
import matplotlib
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.ensemble import IsolationForest
from sklearn.preprocessing import LabelEncoder

matplotlib.use('Agg')  # Use 'Agg' backend for compatibility
import matplotlib.pyplot as plt
# Module-level Isolation Forest detector. `contamination=0.1` tells the model to
# expect roughly 10% of rows to be outliers; the fixed `random_state` makes the
# fitted forest (and therefore the flagged rows) reproducible between runs.
model = IsolationForest(
    random_state=42,
    contamination=0.1,
)
def preprocess_data(data):
    """
    Preprocess a DataFrame so it can be fed to the Isolation Forest.

    - Strip surrounding whitespace from text cells and treat empty cells as missing.
    - Convert every column whose present values all parse as numbers to a numeric
      dtype. Manual textbox input arrives as strings, so without this step values
      such as "100" would be label-encoded as categories instead of used as numbers.
    - Label-encode the remaining (genuinely non-numeric) columns.
    - Fill missing values with the column mean, or 0 for a column with no values.

    Parameters:
        data: pandas DataFrame to preprocess. The caller's frame is not modified.

    Returns:
        A tuple (processed DataFrame, dict mapping each label-encoded column name
        to its fitted LabelEncoder).
    """
    data = data.copy()
    label_encoders = {}
    for column in data.columns:
        series = data[column]
        if series.dtype == object:
            series = series.map(lambda v: v.strip() if isinstance(v, str) else v)
            # Empty cells (e.g. from "1,,3" or ragged rows) count as missing.
            series = series.mask(series == "")
        try:
            numeric = pd.to_numeric(series, errors='coerce')
        except (TypeError, ValueError):
            # Cells holding containers (e.g. nested JSON) cannot be coerced at all.
            numeric = None
        # `coerce` only turns unparseable values into NaN, so equal non-missing
        # counts mean every present value is a number.
        if numeric is not None and numeric.notna().sum() == series.notna().sum():
            data[column] = numeric
        else:
            le = LabelEncoder()
            # astype(str) gives missing values their own "nan" category; mixing NaN
            # floats with strings would make LabelEncoder fail while sorting classes.
            data[column] = le.fit_transform(series.astype(str))
            label_encoders[column] = le
    # Fill last so no NaN reaches the model; an all-missing column has a NaN mean,
    # hence the fallback to 0.
    data = data.fillna(data.mean(numeric_only=True)).fillna(0)
    return data, label_encoders
def _load_data(data_input, file_input):
    """
    Build the raw DataFrame from the uploaded file, or from the textbox if no file.

    Raises ValueError for an unsupported file extension or when no data was given.
    """
    if file_input:
        # Older Gradio passes a tempfile wrapper exposing `.name`; newer Gradio
        # passes the file path as a plain string. Accept both.
        path = getattr(file_input, "name", file_input)
        file_extension = os.path.splitext(path)[-1].lower()
        if file_extension == ".csv":
            return pd.read_csv(path)
        if file_extension == ".json":
            return pd.read_json(path)
        raise ValueError(f"Unsupported file type: {file_extension}")
    # Manual input: one row per line, comma-separated cells. Blank lines (such as a
    # trailing newline) are skipped so they don't become ragged rows of empty cells.
    rows = [line.split(",") for line in (data_input or "").splitlines() if line.strip()]
    if not rows:
        raise ValueError("No data provided: enter values or upload a CSV/JSON file.")
    return pd.DataFrame(rows)


def _render_plot(data, anomalies):
    """Plot every value per row index, highlight anomalous rows, and return a PIL image."""
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        n_columns = data.shape[1]
        # Flatten all columns into one series of points so the legend gets exactly
        # one "Data Points" and one "Anomalies" entry however many columns exist.
        ax.plot(np.repeat(data.index.to_numpy(), n_columns),
                data.to_numpy(dtype=float).ravel(), 'o', label="Data Points")
        ax.plot(np.repeat(anomalies.index.to_numpy(), n_columns),
                anomalies.to_numpy(dtype=float).ravel(), 'ro', label="Anomalies")
        ax.set_title("Anomaly Detection Results")
        ax.set_xlabel("Data Index")
        ax.set_ylabel("Data Value")
        ax.legend()
        fig.tight_layout()
        # Render in memory: no shared temp file for concurrent calls to clobber.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even if plotting fails, so figures don't pile up.
        plt.close(fig)
    buffer.seek(0)
    plot_image = Image.open(buffer)
    # Image.open is lazy; decode the pixels now.
    plot_image.load()
    return plot_image


def detect_anomalies(data_input, file_input):
    """
    Detect anomalous rows with the module-level Isolation Forest and plot them.

    Data comes from the uploaded file when one is given (.csv or .json); otherwise
    it comes from the textbox (comma-separated values, one row per line). The model
    is refit on every call.

    Returns:
        A tuple (result dict, PIL image). On success the dict holds "Anomalies"
        (the flagged rows, as preprocessed) and "Total Anomalies". On any failure
        it is {"Error": message} and the image is None.
    """
    try:
        data = _load_data(data_input, file_input)
        # The label encoders are not needed downstream; only the numeric frame is modelled.
        data, _label_encoders = preprocess_data(data)
        model.fit(data)
        predictions = model.predict(data)
        # IsolationForest.predict marks outliers with -1.
        anomalies = data[predictions == -1]
        plot_image = _render_plot(data, anomalies)
        return {"Anomalies": anomalies.to_dict(orient='records'), "Total Anomalies": len(anomalies)}, plot_image
    except Exception as e:
        # UI boundary: report any failure to the user instead of raising a traceback.
        return {"Error": str(e)}, None
# Gradio Interface: the textbox and the file upload are passed to detect_anomalies,
# whose (dict, PIL image) result is shown as JSON and as an image.
iface = gr.Interface(
    fn=detect_anomalies,
    inputs=[
        gr.Textbox(label="Enter data (comma-separated values, one row per line)"),
        # Only offer the formats detect_anomalies can parse.
        gr.File(label="Upload Data File", file_types=[".csv", ".json"]),
    ],
    outputs=[
        "json",
        gr.Image(type="pil", label="Anomaly Plot"),
    ],
    title="Anomaly Detector",
    description="Enter a series of numbers or upload a data file to detect anomalies using Isolation Forest."
)

# Launch the interface only when run as a script (`python app.py`), so importing
# this module does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()