tgd1115's picture
Upload app.py
c31d853 verified
raw
history blame
20.7 kB
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import numpy as np
from sklearn.preprocessing import StandardScaler
from dataclasses import dataclass
from datetime import datetime
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader, TensorDataset
from path_config import MODEL_DIR
from pipeline import Transformer
# NOTE(review): the class declares no annotated fields, so ``@dataclass()``
# contributes no generated ``__init__`` here (the hand-written one below is
# used). The decorator is kept only so existing behaviour is unchanged.
@dataclass()
class NYCTaxiAnomalyDetector:
    """Flag unusual taxi-traffic readings via Transformer reconstruction error.

    The detector scales one numeric column, slices it into fixed-length
    sliding windows, runs every window through the ``Transformer`` model and
    marks the windows whose mean absolute reconstruction error exceeds a
    percentile threshold.  Each window's result is attributed to the row at
    the window's last position.
    """

    def __init__(self, data):
        """
        :param data: DataFrame holding at least a "date" column; it is copied
            so later date parsing never alters the caller's frame
        """
        self.data = data.copy()
        self.scaler = StandardScaler()
        # The model is loaded lazily on the first detect_anomalies() call.
        self.model = None
        self.TRANSFORMER_S_MODEL_PATH = os.path.join(
            MODEL_DIR, "transformer_model_small.pth"
        )

    def create_sequences(self, data, seq_length=24):
        """
        Create sliding-window sequences for the transformer model
        :param data: Array-like of values; reshaped to a single feature column
        :param seq_length: Number of consecutive values per window
        :return: Array of shape (n_windows, seq_length, 1); n_windows is 0
            when fewer than ``seq_length`` values are supplied
        """
        values = np.asarray(data).reshape(-1, 1)
        n_windows = len(values) - seq_length + 1
        if n_windows <= 0:
            # Keep the 3-D shape so downstream tensor code sees a valid,
            # merely empty, batch instead of a shape-(0,) array.
            return np.empty((0, seq_length, 1), dtype=values.dtype)
        return np.stack([values[i : i + seq_length] for i in range(n_windows)])

    def filter_by_date_range(self, start_date, end_date):
        """
        Filter data by specified date range (both bounds inclusive)
        :param start_date: Start date of the range
        :param end_date: End date of the range
        :return: Filtered DataFrame (an independent copy)
        """
        # Ensure date column is datetime
        if not pd.api.types.is_datetime64_any_dtype(self.data["date"]):
            self.data["date"] = pd.to_datetime(self.data["date"])

        in_range = (self.data["date"] >= start_date) & (
            self.data["date"] <= end_date
        )
        # Explicit copy: callers may modify the result without pandas
        # chained-assignment warnings and without touching self.data.
        return self.data[in_range].copy()

    def preprocess_data(self, data, column, seq_length=24):
        """
        Preprocess data for anomaly detection
        :param data: Filtered DataFrame (not modified)
        :param column: Column to detect anomalies in
        :param seq_length: Window length used for the sequences
        :return: Tuple of (sequences, index); ``index`` holds the label of the
            last row of each window, so both have the same length
        """
        # Coerce into a new Series instead of assigning back into ``data``,
        # which is usually a slice of a larger frame.
        numeric = pd.to_numeric(data[column], errors="coerce")
        clean_data = numeric.dropna()

        if len(clean_data) < seq_length:
            # Not enough points for a single window; skip scaling because
            # StandardScaler rejects empty input.
            return np.empty((0, seq_length, 1)), clean_data.index[:0]

        scaled_data = self.scaler.fit_transform(clean_data.values.reshape(-1, 1))
        sequences = self.create_sequences(scaled_data, seq_length)
        # Window i ends at position i + seq_length - 1.
        return sequences, clean_data.index[seq_length - 1 :]

    def detect_anomalies(self, data, column, contamination=0.05):
        """
        Detect anomalies using the Transformer's reconstruction error
        :param data: Filtered DataFrame
        :param column: Column to detect anomalies in
        :param contamination: Expected proportion of outliers; converted to
            the percentile (1 - contamination) * 100 used as error threshold
        :return: DataFrame with columns "date", <column>, "is_anomaly",
            "reconstruction_error" and "prediction"; empty when ``data`` has
            too few usable rows to form one window
        """
        # Preprocess data
        sequences, original_index = self.preprocess_data(data, column)

        if len(sequences) == 0:
            reconstruction_errors = np.array([], dtype=float)
            predictions = np.array([], dtype=float)
            anomalies = np.array([], dtype=int)
        else:
            if self.model is None:
                self.model = Transformer()
                self.model.load_state_dict(
                    torch.load(
                        self.TRANSFORMER_S_MODEL_PATH,
                        weights_only=True,
                        # Inference runs on CPU tensors; remap so a checkpoint
                        # saved on a GPU still loads on a CPU-only host.
                        map_location="cpu",
                    )
                )
                self.model.eval()

            # Create DataLoader (order must be preserved to match the index)
            dataset = TensorDataset(torch.as_tensor(sequences, dtype=torch.float32))
            test_loader = DataLoader(dataset, batch_size=32, shuffle=False)

            # Calculate threshold percentile from contamination
            threshold_percentile = (1 - contamination) * 100

            reconstruction_errors, predictions, anomalies, _ = (
                self.detect_anomalies_batch(
                    self.model, test_loader, threshold_percentile=threshold_percentile
                )
            )

        # Create results DataFrame
        return pd.DataFrame(
            {
                "date": data.loc[original_index, "date"],
                column: pd.to_numeric(data.loc[original_index, column], errors="coerce"),
                "is_anomaly": anomalies,
                "reconstruction_error": reconstruction_errors,
                "prediction": predictions,
            }
        )

    def detect_anomalies_batch(self, model, test_loader, threshold_percentile=99.7):
        """
        Detect anomalies in batches
        :param model: Callable mapping a batch tensor to its reconstruction
        :param test_loader: Iterable of batches whose first element is the
            input tensor (indexed as [batch, timestep, feature])
        :param threshold_percentile: Percentile of the reconstruction errors
            above which a sequence is flagged
        :return: Tuple of (reconstruction_errors, predictions, anomalies,
            optimal_threshold); the arrays are empty and the threshold is NaN
            when the loader yields no batches
        """
        error_chunks = []
        prediction_chunks = []
        with torch.no_grad():
            for batch in test_loader:
                x = batch[0]  # TensorDataset wraps each batch in a list
                pred = model(x)
                # Mean absolute error over sequence length and features
                errors = torch.mean(torch.abs(pred - x), dim=(1, 2))
                error_chunks.append(errors.cpu().numpy())
                # Keep the reconstruction of the last timestep only
                prediction_chunks.append(pred[:, -1, 0].cpu().numpy())

        if not error_chunks:
            # np.percentile is undefined on empty input; report "nothing found".
            return (
                np.array([], dtype=float),
                np.array([], dtype=float),
                np.array([], dtype=int),
                float("nan"),
            )

        reconstruction_errors = np.concatenate(error_chunks)
        predictions = np.concatenate(prediction_chunks)
        optimal_threshold = np.percentile(reconstruction_errors, threshold_percentile)
        anomalies = (reconstruction_errors > optimal_threshold).astype(int)
        return reconstruction_errors, predictions, anomalies, optimal_threshold
@dataclass()
class AIContextGenerator:
    """Look up hand-curated real-world events that explain a traffic anomaly.

    Events are stored in ``predefined_anomalies``, keyed by calendar date.
    Each value is a list of dicts with the keys "type", "description" and
    "reference" (either ``None`` or a dict with "text" and "url").
    """

    # Deliberately left without a type annotation: the mapping is a plain
    # class attribute shared by every instance, not a dataclass field.
    predefined_anomalies = {
        datetime(2014, 7, 4).date(): [
            {
                "type": "USA's 238th Independence Day",
                "description": "The Fourth of July in 2014 marked the USA's 238th Independence Day, celebrated with widespread fireworks as The Macy's 4th of July Fireworks returned to the East River, attracting thousands of locals and tourists to the waterfront for one of the city's largest annual displays. The surge in activity led to increased taxi ridership as residents and tourists traveled to and from celebrations.",
                "reference": {
                    "text": "USA's 238th Independence Day 2014",
                    "url": "https://www.theguardian.com/world/2014/jul/05/fourth-july-independence-day-america-pictures",
                },
            }
        ],
        datetime(2014, 7, 6).date(): [
            {
                "type": "Long Holiday for USA's 238th Independence Day",
                "description": "There was a long holiday this weekend due to the Independence Day celebrations, which explains why taxi ridership saw a sustained increase as the people took the holiday to their advantage.",
                "reference": {
                    "text": "Long Holiday for USA's 238th Independence Day",
                    "url": "https://www.cbsnews.com/pictures/independence-day-2014/2/",
                },
            }
        ],
        datetime(2014, 9, 1).date(): [
            {
                "type": "Labor Day",
                "description": "Labor Day 2014 marked the unofficial end of summer, with many New Yorkers and tourists enjoying the long weekend by doing outdoor activities, shopping and even traveling. This holiday lead to increased taxi ridership as people attended parades, visited parks and took advantage of end-of-summer sales across the city.",
                "reference": {
                    "text": "Labor Day 2014",
                    "url": "https://www.nycclc.org/gallery/2014-nyc-labor-day-parade-video",
                },
            }
        ],
        datetime(2014, 11, 2).date(): [
            {
                "type": "New York City (NYC) Marathon 2014",
                "description": "The 2014 NYC Marathon, the largest in history with 50,869 starters and 50,564 finishers, significantly increased taxi ridership due to the influx of participants and spectators.",
                "reference": {
                    "text": "NYC Marathon 2014",
                    "url": "https://en.wikipedia.org/wiki/2014_New_York_City_Marathon",
                },
            }
        ],
        datetime(2014, 11, 27).date(): [
            {
                "type": "Thanksgiving Day",
                "description": "Thanksgiving Day 2014 featured the iconic Macys Thanksgiving Day Parade, with giant balloons, floats, and performances drawing millions of spectators to the streets of New York City. The event, along with holiday travel and family gatherings, led to a significant increase in taxi ridership throughout the day.",
                "reference": {
                    "text": "Thanksgiving Day 2014",
                    "url": "https://www.cbsnews.com/pictures/macys-thanksgiving-day-parade-2014/",
                },
            }
        ],
        datetime(2014, 11, 28).date(): [
            {
                "type": "Post-Thanksgiving Day",
                "description": "The day after Thanksgiving in 2014 saw a surge in activity across New York City as residents and tourists participated in Black Friday shopping, visited family, and enjoyed extended holiday festivities. This led to increased taxi ridership as people traveled to retail hubs, restaurants, and other destinations throughout the city.",
                "reference": {
                    "text": "Thanksgiving Day 2014",
                    "url": "https://www.cbsnews.com/pictures/macys-thanksgiving-day-parade-2014/#:~:text=Macy's%20Parade&text=The%20Wicked%20Witch%20of%20the,million%20people%20were%20in%20attendance.",
                },
            }
        ],
        datetime(2014, 12, 24).date(): [
            {
                "type": "Christmas Eve",
                "description": "Christmas Eve 2014 in New York City was marked by festive celebrations, last-minute shopping, and gatherings with family and friends. The holiday spirit led to increased taxi ridership as people traveled to stores, restaurants, and holiday events across the city.",
                "reference": {
                    "text": "Christmas Eve 2014",
                    # NOTE(review): this URL points at a New Year's Eve article,
                    # not a Christmas Eve one - confirm the intended source.
                    "url": "https://archive.nytimes.com/cityroom.blogs.nytimes.com/2014/12/31/live-video-new-years-eve-in-times-square/",
                },
            }
        ],
        datetime(2014, 12, 25).date(): [
            {
                "type": "Christmas Day",
                "description": "Christmas Day 2014 in New York City was a time of festive gatherings, family celebrations, and holiday cheer. With many restaurants, attractions, and public spaces open, taxi ridership saw an increase as residents and tourists traveled to visit loved ones, attend holiday events, and enjoy the city's festive atmosphere.",
                "reference": {
                    "text": "Christmas Day 2014",
                    "url": "https://storiesmysuitcasecouldtell.com/2014/12/24/a-merry-christmas-in-new-york-city/",
                },
            }
        ],
        datetime(2014, 12, 26).date(): [
            {
                "type": "Boxing Day",
                "description": "The day after Christmas in 2014 which is the Boxing Day saw continued holiday activity in New York City, with residents and tourists taking advantage of post-holiday sales, returning gifts, and enjoying extended celebrations. This led to increased taxi ridership as people traveled to shopping centers, restaurants, and entertainment venues across the city.",
                "reference": None,
            }
        ],
        datetime(2014, 12, 31).date(): [
            {
                "type": "New Year's Eve",
                "description": "New Year’s Eve 2014 in New York City was marked by the iconic Times Square Ball Drop, drawing over a million spectators to the area and millions more watching worldwide. The celebrations, along with parties and events across the city, led to a significant increase in taxi ridership as residents and tourists traveled to and from festivities.",
                "reference": {
                    "text": "New Year's Eve",
                    "url": "https://abcnews.go.com/US/2014-years-eve-times-square-numbers/story?id=27929342",
                },
            }
        ],
        datetime(2015, 1, 1).date(): [
            {
                "type": "New Year's Day",
                "description": "New Year’s Day 2015 in New York City saw a continuation of celebrations from the previous night, with many residents and tourists recovering from festivities or attending brunches, parades, and family gatherings. The increased activity led to higher taxi ridership as people traveled across the city to celebrate the start of the new year.",
                "reference": {
                    "text": "New Year's Day 2015",
                    "url": "https://www.cbsnews.com/newyork/pictures/new-years-2015-in-times-square/",
                },
            }
        ],
        datetime(2015, 1, 26).date(): [
            {
                "type": "Winter Snow Juno",
                "description": "Winter Storm Juno, a historic blizzard, hit New York City on January 26-27, 2015, bringing heavy snowfall, strong winds, and widespread disruptions. The storm led to a sharp decline in taxi ridership as travel became hazardous, with many roads closed and public transportation services suspended.",
                "reference": {
                    "text": "Winter Snow Juno 2015",
                    "url": "https://www.rms.com/blog/2015/01/27/winter-storm-juno-three-facts-about-snowmageddon-2015#:~:text=There%20were%20predictions%20that%20Winter,hunkered%20down%20in%20their%20homes.",
                },
            }
        ],
        datetime(2015, 1, 27).date(): [
            {
                "type": "2nd Day of Winter Snow Juno",
                "description": "2nd day of Winter Storm Juno, a historic blizzard, hit New York City on January 26-27, 2015, bringing heavy snowfall, strong winds, and widespread disruptions. The storm led to a sharp decline in taxi ridership as travel became hazardous, with many roads closed and public transportation services suspended.",
                "reference": {
                    "text": "Winter Snow Juno 2015",
                    "url": "https://www.rms.com/blog/2015/01/27/winter-storm-juno-three-facts-about-snowmageddon-2015#:~:text=There%20were%20predictions%20that%20Winter,hunkered%20down%20in%20their%20homes.",
                },
            }
        ],
    }

    def generate_context(self, anomaly_date):
        """
        Generate potential context for the anomaly if predefined
        :param anomaly_date: Date of the anomaly; a ``datetime.date``, a
            ``datetime.datetime`` or a ``pd.Timestamp``
        :return: List of contextual insights if available, else None
        """
        # pd.Timestamp subclasses datetime, so this single check reduces both
        # to the plain calendar date used as dictionary key.  A datetime that
        # is not reduced would hash differently and never match.
        if isinstance(anomaly_date, datetime):
            anomaly_date = anomaly_date.date()
        return self.predefined_anomalies.get(anomaly_date)
def load_nyc_taxi_data(file_path="data/nyc_taxi_traffic_data.csv"):
    """
    Read the NYC Taxi CSV and return it as a chronologically ordered frame.
    :param file_path: Location of a CSV with "timestamp" and "value" columns
    :return: DataFrame whose "timestamp"/"value" columns are renamed to
        "date"/"daily_traffic", sorted by date with a fresh 0..n-1 index
    """
    raw = pd.read_csv(file_path)

    # Parse the timestamps before renaming so the sort below is chronological
    # rather than lexicographic.
    raw["timestamp"] = pd.to_datetime(raw["timestamp"])
    renamed = raw.rename(columns={"timestamp": "date", "value": "daily_traffic"})

    ordered = renamed.sort_values(by="date")
    return ordered.reset_index(drop=True)
def main():
    """
    Render the Streamlit dashboard for NYC taxi anomaly detection.

    Loads the taxi data, lets the user choose a date range and a sensitivity
    in the sidebar, runs the detector on the selected range, plots the series
    with the detected anomalies highlighted and lists the anomalies that fall
    on a date with a predefined real-world event.
    """
    st.set_page_config(
        page_title="NYC Taxi Traffic Anomaly Detection",
        page_icon="🚕",
        layout="wide",
        initial_sidebar_state="expanded",
    )
    st.title("🚕 NYC Taxi Traffic Anomaly Detection")

    # Load Data
    taxi_data = load_nyc_taxi_data()

    # Sidebar for Configuration
    st.sidebar.header("Anomaly Detection Settings")

    # Date Range Selection
    st.sidebar.subheader("Date Range")
    min_date = taxi_data["date"].min().date()
    max_date = taxi_data["date"].max().date()
    col1, col2 = st.sidebar.columns(2)
    with col1:
        start_date = st.date_input(
            "Start Date", min_value=min_date, max_value=max_date, value=min_date
        )
    with col2:
        end_date = st.date_input(
            "End Date", min_value=min_date, max_value=max_date, value=max_date
        )

    # Anomaly Sensitivity
    anomaly_threshold = st.sidebar.slider(
        "Anomaly Sensitivity",
        min_value=0.01,
        max_value=0.1,
        value=0.05,
        step=0.01,
        help="Lower values detect fewer but more extreme anomalies",
    )

    # The two pickers are independent, so an inverted range is possible.
    if start_date > end_date:
        st.error("Start Date must not be after End Date.")
        return

    # Instantiate Detector
    detector = NYCTaxiAnomalyDetector(taxi_data)

    # Filter Data by Date Range.  The pickers return calendar dates, which
    # convert to midnight; push the upper bound to the last instant of the
    # end date so that day's intraday readings pass the inclusive filter.
    range_start = pd.to_datetime(start_date)
    range_end = (
        pd.to_datetime(end_date) + pd.Timedelta(days=1) - pd.Timedelta(nanoseconds=1)
    )
    filtered_data = detector.filter_by_date_range(range_start, range_end)

    if filtered_data.empty:
        st.warning("No data available for the selected date range.")
        return

    # Detect Anomalies
    anomalies = detector.detect_anomalies(
        filtered_data, "daily_traffic", contamination=anomaly_threshold
    )

    # Get anomaly points for visualization
    anomaly_points = anomalies[anomalies["is_anomaly"] == 1]

    # Keep the anomalies whose calendar date has a predefined event
    known_event_dates = list(AIContextGenerator.predefined_anomalies)
    true_anomaly_points = anomaly_points[
        anomaly_points["date"].dt.date.isin(known_event_dates)
    ]

    # Visualization
    st.header("Daily Taxi Traffic Trend ✨")
    fig = px.line(
        filtered_data,
        x="date",
        y="daily_traffic",
        title=f"NYC Taxi Daily Traffic ({start_date} to {end_date})",
        labels={"daily_traffic": "Number of Taxi Rides"},
    )

    # Highlight Anomalies
    fig.add_trace(
        go.Scatter(
            x=anomaly_points["date"],
            y=anomaly_points["daily_traffic"],
            mode="markers",
            name="Anomalies",
            marker=dict(color="red", size=10, symbol="star"),
        )
    )
    st.plotly_chart(fig, use_container_width=True)

    # Anomaly Details
    st.header("Insights of Anomalies with Known Events 📈")

    # Calculate metrics using the anomalies DataFrame
    total_anomalies_detected = len(anomaly_points)
    true_anomalies = len(true_anomaly_points)
    false_anomalies = total_anomalies_detected - true_anomalies
    st.sidebar.subheader("Summary")
    st.sidebar.metric("Total Anomalies Detected:", total_anomalies_detected)
    st.sidebar.metric("Anomalies with Known Events:", true_anomalies)
    st.sidebar.metric("Unexplained Anomalies:", false_anomalies)

    if true_anomaly_points.empty:
        if anomaly_points.empty:
            st.info("No significant anomalies detected with current settings.")
        else:
            # Anomalies exist, they just do not coincide with a listed event.
            st.info("None of the detected anomalies match a known event.")
        return

    context_generator = AIContextGenerator()

    # Group by date and calculate min/max traffic
    grouped_anomalies = (
        true_anomaly_points.groupby(true_anomaly_points["date"].dt.date)
        .agg({"daily_traffic": ["min", "max"]})
        .reset_index()
    )
    # Flatten the multi-level columns
    grouped_anomalies.columns = ["date", "min_traffic", "max_traffic"]

    for _, anomaly in grouped_anomalies.iterrows():
        col1, col2 = st.columns(2)
        with col1:
            st.subheader(f"Anomaly on {anomaly['date']}")
            traffic_range = (
                f"{anomaly['min_traffic']:.0f}-{anomaly['max_traffic']:.0f}"
            )
            st.metric("Taxi Rides Range", traffic_range)
        with col2:
            contexts = context_generator.generate_context(anomaly["date"])
            if not contexts:
                st.write("No significant event available for this anomaly.")
                continue
            for context in contexts:
                st.subheader(f"Event: {context['type']}")
                reference = context["reference"]
                reference_text = reference["text"] if reference else "-"
                reference_url = (
                    reference["url"] if reference and reference["url"] else ""
                )
                # Render a markdown link only when a URL is available.
                url = (
                    f"[{reference_text}]({reference_url})"
                    if reference_url
                    else reference_text
                )
                st.markdown(
                    "\n"
                    f"- Description: {context['description']}\n"
                    f"- Reference: {url}\n"
                )
# Start the dashboard only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()