import ipaddress
import json
import logging
import os
from collections import Counter
from datetime import datetime, timedelta
from typing import List, Dict, Optional

from fastapi import FastAPI, HTTPException, Query, Body, Request
from pydantic import BaseModel, Field, validator, root_validator
from user_agents import parse

app = FastAPI()

# Visitor entries, held in process memory only; one entry per IP address key.
user_data: Dict[str, dict] = dict()

# --- Data Models ---
class UserEntry(BaseModel):
    """A single visitor record.

    Every field except ``ip_address`` falls back to a default when it is
    missing or explicitly ``None`` (see ``set_default_values``).
    """

    ip_address: str
    device_type: str = "N/A"
    # default_factory: "now" is evaluated per instance, not once at import time.
    timestamp: datetime = Field(default_factory=datetime.now)
    browser: str = "N/A"
    OS: str = "N/A"

    @validator("ip_address")
    def validate_ip_address(cls, value):
        """Accept a literal IPv4 or IPv6 address string and return it unchanged.

        ``ipaddress`` is stricter than splitting on dots and calling ``int()``,
        which also accepted strings such as ``"1_0.2.3.4"`` or ``"-0.1.2.3"``
        and rejected every IPv6 client.

        Raises:
            ValueError: if ``value`` is not a valid IP address.
        """
        try:
            ipaddress.ip_address(value)
        except ValueError as exc:
            raise ValueError("Invalid IP address format") from exc
        return value

    @root_validator(pre=True)
    def set_default_values(cls, values):
        """Fill missing or ``None`` optional fields with their defaults before validation."""
        defaults = {
            "device_type": "N/A",
            "browser": "N/A",
            "OS": "N/A",
            "timestamp": datetime.now()
        }
        for key, default_value in defaults.items():
            if key not in values or values[key] is None:
                values[key] = default_value
        return values

def clean_old_data():
    """Drop every ``user_data`` entry whose timestamp is more than 7 days in the past.

    Mutates the module-level ``user_data`` dict in place; returns nothing.
    """
    threshold = datetime.now() - timedelta(days=7)
    # Collect first: deleting while iterating the dict would raise RuntimeError.
    stale_keys = []
    for addr, record in user_data.items():
        if record["timestamp"] < threshold:
            stale_keys.append(addr)
    for addr in stale_keys:
        user_data.pop(addr)

# --- API Endpoints ---
@app.post("/auto_entry/", response_model=UserEntry, status_code=201)
async def create_auto_user_entry(
    request: Request
):
    """Record a visit for the calling client, replacing any earlier entry for its IP.

    Nothing is read from the request body:
    - The IP is the first entry of ``X-Forwarded-For`` when that header is
      non-empty; otherwise it is the socket peer address.
    - Device type, browser and OS are derived from the ``User-Agent`` header.
    - The timestamp is the current server time.

    Returns:
        The stored ``UserEntry``.

    Raises:
        HTTPException(400): if no client IP is available or it fails validation.
        HTTPException(500): on any other unexpected error.
    """
    try:
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it; entries can be spoofed otherwise.
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            ip_address = forwarded_for.split(",")[0].strip()
        elif request.client is not None:
            ip_address = request.client.host
        else:
            # request.client can be None, e.g. for some ASGI transports.
            raise ValueError("Unable to determine client IP address")

        user_agent_parsed = parse(request.headers.get("User-Agent", "N/A"))

        if user_agent_parsed.is_mobile:
            device_type = "Mobile"
        elif user_agent_parsed.is_tablet:
            device_type = "Tablet"
        else:
            device_type = "Desktop"

        entry_data = UserEntry(
            ip_address=ip_address,
            device_type=device_type,
            timestamp=datetime.now(),
            browser=user_agent_parsed.browser.family,
            OS=user_agent_parsed.os.family,
        )

        user_data[ip_address] = entry_data.dict()
        return entry_data

    except ValueError as ve:
        # pydantic validation errors subclass ValueError, so bad IPs land here.
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")

def _count_breakdown(entries: List[dict], field: str) -> Dict[str, dict]:
    """Tally ``entry[field]`` over ``entries`` as counts and as percentages of ``len(entries)``.

    Returns ``{"counts": {value: n}, "percentages": {value: pct}}``. Both maps
    are empty when ``entries`` is empty, so no division by zero can happen.
    """
    counts = dict(Counter(entry[field] for entry in entries))
    total = len(entries)
    percentages = {value: (count / total) * 100 for value, count in counts.items()}
    return {"counts": counts, "percentages": percentages}


@app.get("/analytics/")
async def get_user_analytics(period: str = Query(..., enum=["last_hour", "last_day", "last_7_day"])):
    """Summarise stored user entries whose timestamp falls inside ``period``.

    Expired entries are purged first via ``clean_old_data()``. The response
    contains:
    - the number of distinct ``ip_address`` values in the window;
    - per-field counts and percentages for ``device_type``, ``browser`` and ``OS``;
    - the total number of entries in the window.

    Raises:
        HTTPException(400): for an unknown ``period``.
        HTTPException(500): on any other unexpected error.
    """
    try:
        clean_old_data()  # Clean data before processing

        windows = {
            "last_hour": timedelta(hours=1),
            "last_day": timedelta(days=1),
            "last_7_day": timedelta(days=7),
        }
        if period not in windows:
            raise HTTPException(status_code=400, detail="Invalid period specified")
        cutoff = datetime.now() - windows[period]

        filtered_data = [entry for entry in user_data.values() if entry["timestamp"] >= cutoff]
        unique_users = len({entry["ip_address"] for entry in filtered_data})

        return {
            "total_unique_users": unique_users,
            "device_type_info": _count_breakdown(filtered_data, "device_type"),
            "browser_info": _count_breakdown(filtered_data, "browser"),
            "os_info": _count_breakdown(filtered_data, "OS"),
            "total_entries": len(filtered_data),
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating analytics: {e}")


@app.get("/export/", response_model=Dict[str, UserEntry])
async def export_user_data():
    """Return all stored user entries keyed by IP, after purging expired ones.

    Raises:
        HTTPException(500): if the purge step fails.
    """
    try:
        # Expired entries must never leave the service, so purge before returning.
        clean_old_data()
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error exporting data: {err}")
    return user_data

@app.post("/import/")
async def import_user_data(data: Dict[str, dict] = Body(...)):
    """Import user entries from a JSON object of the form ``{ip: entry}``.

    Each entry is validated through ``UserEntry``. An entry that fails
    validation is logged and skipped, so one bad entry does not abort the
    import.

    Timestamps come from the validated model, so a missing or ``None``
    timestamp becomes "now". Timezone-aware timestamps are converted to naive
    local time, because the rest of the module compares stored timestamps
    against naive ``datetime.now()``. Mixing the two raises TypeError.

    Returns:
        ``{"message": ...}`` reporting how many entries were stored.

    Raises:
        HTTPException(500): on an unexpected error outside per-entry validation.
    """
    logger = logging.getLogger(__name__)
    try:
        imported_count = 0
        for ip, entry_dict in data.items():
            try:
                entry = UserEntry(**entry_dict)
            except (ValueError, TypeError) as e:
                # ValueError covers pydantic validation errors; TypeError covers
                # keys that collide with BaseModel.__init__'s own parameters.
                logger.warning("Error importing entry for IP %s: %s", ip, e)
                continue

            record = entry.dict()
            timestamp = record["timestamp"]
            if timestamp.tzinfo is not None:
                record["timestamp"] = timestamp.astimezone().replace(tzinfo=None)
            user_data[ip] = record
            imported_count += 1
        return {"message": f"Successfully imported {imported_count} user entries."}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error importing data: {e}")

# In-memory storage for the poll feature.
# user_data is declared once near the top of the module. Re-binding it here
# would silently replace that dict, so it is intentionally not repeated.
poll_data: Dict[str, dict] = {}  # Key: poll name, Value: {"question", "options", "created_at"}
poll_responses: Dict[str, Dict[str, int]] = {}  # Key: poll name, Value: {IP: 1-based option number}

# --- Data Models ---
class PollCreate(BaseModel):
    """Request body for creating a poll.

    Answers to the poll are 1-based positions in ``options``.
    """

    poll_name: str
    question: str
    options: List[str]

    @validator("options")
    def validate_options(cls, value):
        """Reject an empty option list, since such a poll could never accept an answer.

        Raises:
            ValueError: if ``value`` is empty.
        """
        if not value:
            raise ValueError("A poll needs at least one option")
        return value

class PollEntry(BaseModel):
    """Request body for answering a poll."""
    poll_name: str  # name of a poll previously created via /poll/create/
    response: int  # 1-based option number; range-checked in create_poll_entry, not here

# --- API Endpoints ---
@app.post("/poll/create/", status_code=201)
async def create_poll(poll: PollCreate):
    """Register a new poll along with an empty response table for it.

    Raises:
        HTTPException(400): if a poll with the same name already exists.
    """
    name = poll.poll_name
    if name in poll_data:
        raise HTTPException(status_code=400, detail="Poll with this name already exists.")

    poll_data[name] = dict(
        question=poll.question,
        options=poll.options,
        created_at=datetime.now(),
    )
    poll_responses[name] = {}

    return {"message": "Poll created successfully.", "poll_name": name}

@app.post("/poll/entry/", status_code=201)
async def create_poll_entry(request: Request, poll_entry: PollEntry):
    """Record the calling client's answer to a poll, overwriting its earlier answer.

    The client is identified by the first entry of ``X-Forwarded-For`` when
    that header is non-empty, otherwise by the socket peer address.

    Raises:
        HTTPException(400): if no client IP is available or the answer is out of range.
        HTTPException(404): if the poll does not exist.
    """
    # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
    # proxy overwrites it, so one client could vote under many "IPs".
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        ip_address = forwarded_for.split(",")[0].strip()
    elif request.client is not None:
        ip_address = request.client.host
    else:
        raise HTTPException(status_code=400, detail="Unable to determine client IP address.")

    if poll_entry.poll_name not in poll_data:
        raise HTTPException(status_code=404, detail="Poll not found.")

    option_count = len(poll_data[poll_entry.poll_name]["options"])
    if not 1 <= poll_entry.response <= option_count:
        raise HTTPException(status_code=400, detail="Invalid response option.")

    poll_responses[poll_entry.poll_name][ip_address] = poll_entry.response
    return {"message": "Poll entry recorded successfully."}

@app.get("/poll/analytics/")
async def get_poll_analytics(poll_name: str = Query(..., description="Name of the poll")):
    """Summarise one poll: question, options, per-option tallies and total answers.

    Tallies are keyed by 1-based option number and include options with zero answers.

    Raises:
        HTTPException(404): if the poll does not exist.
    """
    if poll_name not in poll_data:
        raise HTTPException(status_code=404, detail="Poll not found.")

    poll = poll_data[poll_name]
    answers = poll_responses[poll_name]

    tallies = dict.fromkeys(range(1, len(poll["options"]) + 1), 0)
    for choice in answers.values():
        tallies[choice] += 1

    summary = {}
    summary["poll_name"] = poll_name
    summary["question"] = poll["question"]
    summary["options"] = poll["options"]
    summary["response_counts"] = tallies
    summary["total_responses"] = len(answers)
    return summary

# --- Background Task for Poll Deletion ---
async def scheduled_poll_cleanup(max_age_days: int = 7, interval_seconds: float = 60 * 60 * 24):
    """Forever delete polls older than ``max_age_days``, together with their responses.

    A sweep runs immediately and then every ``interval_seconds`` (default:
    daily), so a poll can outlive ``max_age_days`` by up to one interval.
    A failing sweep is logged and retried on the next tick. Without this, the
    exception would end the task with no one awaiting it to notice.
    """
    while True:
        try:
            now = datetime.now()
            expired = [
                name
                for name, poll in poll_data.items()
                if (now - poll["created_at"]).days >= max_age_days
            ]
            for name in expired:
                poll_data.pop(name, None)
                poll_responses.pop(name, None)
        except Exception:
            logging.getLogger(__name__).exception("Scheduled poll cleanup failed")
        await asyncio.sleep(interval_seconds)
    
# --- Background Task: periodic purge of expired user entries (started at startup) ---
async def scheduled_cleanup(interval_seconds: float = 60 * 60):
    """Forever call ``clean_old_data()`` every ``interval_seconds`` (default: hourly).

    A failing sweep is logged and retried on the next tick. Without this, the
    exception would end the task with no one awaiting it to notice.
    """
    while True:
        try:
            clean_old_data()
        except Exception:
            logging.getLogger(__name__).exception("Scheduled user-data cleanup failed")
        await asyncio.sleep(interval_seconds)

# asyncio is required by the background cleanup tasks above, which startup_event starts.
import asyncio
from fastapi import BackgroundTasks

@app.on_event("startup")
async def startup_event():
    """Start the periodic cleanup loops for user entries and for polls.

    The event loop holds only weak references to tasks, so a task whose
    result is discarded can be garbage-collected mid-run. The task objects
    are therefore kept on ``app.state`` for the lifetime of the app.
    """
    app.state.cleanup_tasks = [
        asyncio.create_task(scheduled_cleanup()),
        asyncio.create_task(scheduled_poll_cleanup()),
    ]

# --- Error Handling (Advanced - using exception handlers) ---
from fastapi import Request
from fastapi.responses import JSONResponse

@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"message": detail}`` with its status code.

    Headers attached to the exception (for example ``WWW-Authenticate``) are
    forwarded, as FastAPI's default handler does; previously they were dropped.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={"message": exc.detail},
        headers=getattr(exc, "headers", None),
    )

@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Fallback for any unhandled exception: respond 500 with the error text as ``message``."""
    body = {"message": f"An unexpected error occurred: {exc}"}
    return JSONResponse(content=body, status_code=500)

if __name__ == "__main__":
    import uvicorn

    # Recent uvicorn releases no longer accept a ``debug`` argument (passing it
    # raises TypeError); verbose logging is the closest supported equivalent.
    uvicorn.run(app, host="0.0.0.0", port=8083, log_level="debug")