import gradio as gr
import matplotlib.pyplot as plt
from matplotlib_venn import venn3, venn3_circles
from io import BytesIO
from PIL import Image

def validate_inputs(A, B, C, AB, AC, BC, ABC, U):
    errors = []
    if A < AB + AC - ABC:
        errors.append("A no puede ser menor que la suma de AB y AC menos ABC.")
    if B < AB + BC - ABC:
        errors.append("B no puede ser menor que la suma de AB y BC menos ABC.")
    if C < AC + BC - ABC:
        errors.append("C no puede ser menor que la suma de AC y BC menos ABC.")
    if U < A + B + C - AB - AC - BC + ABC:
        errors.append("El conjunto universal U es menor que la suma total de los conjuntos y sus intersecciones.")
    return errors

def calculate_probabilities(A, B, C, AB, AC, BC, ABC, U):
    total = U

    P_A = A / total
    P_B = B / total
    P_C = C / total
    P_AB = AB / total
    P_AC = AC / total
    P_BC = BC / total
    P_ABC = ABC / total

    P_A_given_B = P_AB / P_B if P_B > 0 else 0
    P_B_given_A = P_AB / P_A if P_A > 0 else 0
    P_A_given_C = P_AC / P_C if P_C > 0 else 0
    P_C_given_A = P_AC / P_A if P_A > 0 else 0
    P_B_given_C = P_BC / P_C if P_C > 0 else 0
    P_C_given_B = P_BC / P_B if P_B > 0 else 0

    P_B_given_A_bayes = (P_A_given_B * P_B) / P_A if P_A > 0 else 0

    formatted_probs = {
        "P(A)": f"{P_A:.2%} ({A}/{total})",
        "P(B)": f"{P_B:.2%} ({B}/{total})",
        "P(C)": f"{P_C:.2%} ({C}/{total})",
        "P(A ∩ B)": f"{P_AB:.2%} ({AB}/{total})",
        "P(A ∩ C)": f"{P_AC:.2%} ({AC}/{total})",
        "P(B ∩ C)": f"{P_BC:.2%} ({BC}/{total})",
        "P(A ∩ B ∩ C)": f"{P_ABC:.2%} ({ABC}/{total})",
        "P(A | B)": f"{P_A_given_B:.2%}",
        "P(B | A)": f"{P_B_given_A:.2%}",
        "P(A | C)": f"{P_A_given_C:.2%}",
        "P(C | A)": f"{P_C_given_A:.2%}",
        "P(B | C)": f"{P_B_given_C:.2%}",
        "P(C | B)": f"{P_C_given_B:.2%}",
        "P(B | A) (Bayes)": f"{P_B_given_A_bayes:.2%}",
    }
    
    return formatted_probs

def plot_venn(A, B, C, AB, AC, BC, ABC, U):
    plt.figure(figsize=(10, 10))
    
    subsets = {
        '100': A - AB - AC + ABC,
        '010': B - AB - BC + ABC,
        '001': C - AC - BC + ABC,
        '110': AB - ABC,
        '101': AC - ABC,
        '011': BC - ABC,
        '111': ABC
    }
    venn = venn3(subsets=subsets, set_labels=('A', 'B', 'C'))
    venn_circles = venn3_circles(subsets=subsets, linewidth=1.0)
    
    plt.title(f"Diagrama de Venn con U = {U}")
    
    buf = BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    img = Image.open(buf)
    return img

def suggest_intersections(A, B, C, AB, AC, BC, ABC, U):
    max_AB = min(A, B, U - (A + B + C - AB - AC - BC + ABC))
    max_AC = min(A, C, U - (A + B + C - AB - AC - BC + ABC))
    max_BC = min(B, C, U - (A + B + C - AB - AC - BC + ABC))
    max_ABC = min(max_AB, max_AC, max_BC)
    
    min_AB = max(0, A + B - U + C)
    min_AC = max(0, A + C - U + B)
    min_BC = max(0, B + C - U + A)
    min_ABC = max(0, A + B + C - U)
    
    max_A = U - (B + C - BC)
    max_B = U - (A + C - AC)
    max_C = U - (A + B - AB)
    min_A = max(AB + AC - ABC, 0)
    min_B = max(AB + BC - ABC, 0)
    min_C = max(AC + BC - ABC, 0)

    suggestions = {
        "Máximo valor sugerido para A": max_A,
        "Mínimo valor sugerido para A": min_A,
        "Máximo valor sugerido para B": max_B,
        "Mínimo valor sugerido para B": min_B,
        "Máximo valor sugerido para C": max_C,
        "Mínimo valor sugerido para C": min_C,
        "Máximo valor sugerido para A ∩ B": max_AB,
        "Mínimo valor sugerido para A ∩ B": min_AB,
        "Máximo valor sugerido para A ∩ C": max_AC,
        "Mínimo valor sugerido para A ∩ C": min_AC,
        "Máximo valor sugerido para B ∩ C": max_BC,
        "Mínimo valor sugerido para B ∩ C": min_BC,
        "Máximo valor sugerido para A ∩ B ∩ C": max_ABC,
        "Mínimo valor sugerido para A ∩ B ∩ C": min_ABC,
    }
    return suggestions

def main(A, B, C, AB, AC, BC, ABC, U):
    errors = validate_inputs(A, B, C, AB, AC, BC, ABC, U)
    if errors:
        return None, {"error": "\n".join(errors), "sugerencias": suggest_intersections(A, B, C, AB, AC, BC, ABC, U)}
    
    venn_diagram = plot_venn(A, B, C, AB, AC, BC, ABC, U)
    probabilities = calculate_probabilities(A, B, C, AB, AC, BC, ABC, U)
    return venn_diagram, {"probabilidades": probabilities, "sugerencias": suggest_intersections(A, B, C, AB, AC, BC, ABC, U)}

iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(label="Conjunto Universal (U)", value=0),
        gr.Number(label="Cantidad en A"),
        gr.Number(label="Cantidad en B"),
        gr.Number(label="Cantidad en C"),
        gr.Number(label="Cantidad en A ∩ B"),
        gr.Number(label="Cantidad en A ∩ C"),
        gr.Number(label="Cantidad en B ∩ C"),
        gr.Number(label="Cantidad en A ∩ B ∩ C")
    ],
    outputs=[
        gr.Image(type="pil", label="Diagrama de Venn"),
        gr.JSON(label="Resultados y Sugerencias")
    ],
    live=True
)

iface.launch()