ouhenio's picture
Update app.py
977a1ff verified
raw
history blame
26.6 kB
import os
from queue import Queue
import json
import gradio as gr
import argilla as rg
from argilla.webhooks import webhook_listener
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Optional, Tuple, Any, Callable
import logging
# Configure root logging once for the whole app: INFO and above, with the
# timestamp, logger name and level prefixed to every message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
# ============================================================================
# DATA MODELS - Clear definition of data structures
# ============================================================================
@dataclass
class CountryData:
    """Annotation target and progress for a single country.

    All fields are plain ``str``/``int`` values so an instance serializes
    cleanly with ``dataclasses.asdict``.
    """

    name: str  # display name of the country
    target: int  # document count that corresponds to 100% completion
    count: int = 0  # documents annotated so far
    percent: int = 0  # completion percentage, kept within 0..100

    def update_progress(self, new_count: Optional[int] = None) -> "CountryData":
        """Recompute ``percent`` from ``count`` and ``target``.

        Args:
            new_count: If given, replaces ``count`` before recomputing.

        Returns:
            ``self``, so calls can be chained.

        Integer arithmetic is used so exact ratios are not lost to float
        rounding (``0.29 * 100`` is ``28.999…`` and would truncate to 28).
        A non-positive ``target`` yields 0% instead of raising
        ``ZeroDivisionError``, and the result is clamped to 0..100.
        """
        if new_count is not None:
            self.count = new_count
        if self.target <= 0:
            self.percent = 0
        else:
            ratio = int(self.count * 100 // self.target)
            self.percent = max(0, min(100, ratio))
        return self
@dataclass
class Event:
    """A single notification queued for the UI.

    Only ``event_type`` is required; the other fields are filled in
    depending on what kind of event is being reported.
    """

    # Kind of event, e.g. "progress_update", "error", or a webhook event type.
    event_type: str
    # When the event happened, as a string ("" when not supplied).
    timestamp: str = field(default="")
    # Display name of the country the event concerns.
    country: str = field(default="")
    # Annotated-document count after the update.
    count: int = field(default=0)
    # Completion percentage after the update.
    percent: int = field(default=0)
    # Error message carried by "error" events.
    error: str = field(default="")
@dataclass
class ApplicationState:
    """Central in-memory state shared by the webhook handler and the UI.

    ``countries`` maps country codes to their progress. ``events`` is a FIFO
    of plain-dict events (``asdict(Event)``) produced by state updates and
    the webhook handler, and consumed one at a time by ``get_next_event``.
    """

    countries: Dict[str, CountryData] = field(default_factory=dict)
    events: Queue = field(default_factory=Queue)

    def to_dict(self) -> Dict[str, Any]:
        """Return ``{code: asdict(CountryData)}`` for every tracked country."""
        return {code: asdict(data) for code, data in self.countries.items()}

    def to_json(self) -> str:
        """Return ``to_dict()`` serialized as a JSON string for the UI."""
        return json.dumps(self.to_dict())

    def add_event(self, event: Event):
        """Enqueue ``event`` as a plain dict."""
        self.events.put(asdict(event))

    def get_next_event(self) -> Dict[str, Any]:
        """Pop the oldest queued event without blocking; ``{}`` if none.

        ``get_nowait`` is used instead of ``empty()`` + ``get()``: with more
        than one consumer, another one can take the last item between the
        check and the ``get()``, which would then block indefinitely.
        """
        from queue import Empty

        try:
            return self.events.get_nowait()
        except Empty:
            return {}

    def update_country_progress(self, country_code: str, count: Optional[int] = None) -> bool:
        """Optionally set a country's count, recompute its percentage, and
        enqueue a ``progress_update`` event.

        Args:
            country_code: Key into ``countries``.
            count: New absolute count; ``None`` keeps the current count.

        Returns:
            ``True`` if the country is tracked, ``False`` otherwise (no event
            is queued in that case).
        """
        country = self.countries.get(country_code)
        if country is None:
            return False
        country.update_progress(count)
        self.add_event(Event(
            event_type="progress_update",
            country=country.name,
            count=country.count,
            percent=country.percent,
        ))
        return True

    def increment_country_progress(self, country_code: str) -> bool:
        """Add one annotated document to a country; see ``update_country_progress``."""
        country = self.countries.get(country_code)
        if country is None:
            return False
        return self.update_country_progress(country_code, country.count + 1)

    def get_stats(self) -> Tuple[int, float, int]:
        """Return ``(total documents, mean percent, countries at >= 50%)``.

        The mean is ``0.0`` when no countries are tracked.
        """
        tracked = list(self.countries.values())
        total = sum(data.count for data in tracked)
        percentages = [data.percent for data in tracked]
        avg = sum(percentages) / len(percentages) if percentages else 0.0
        countries_50_plus = sum(1 for p in percentages if p >= 50)
        return total, avg, countries_50_plus
# ============================================================================
# CONFIGURATION - Separated from business logic
# ============================================================================
class Config:
    """Static configuration: which countries are tracked and their targets."""

    # (ISO alpha-2 code, display name, target document count).
    # Row order is preserved in COUNTRY_MAPPING and is therefore the order in
    # which countries are iterated elsewhere (dropdown, name matching).
    _COUNTRY_TABLE = (
        ("MX", "Mexico", 1000),
        ("AR", "Argentina", 800),
        ("CO", "Colombia", 700),
        ("CL", "Chile", 600),
        ("PE", "Peru", 600),
        ("ES", "Spain", 1200),
        ("BR", "Brazil", 1000),
        ("VE", "Venezuela", 500),
        ("EC", "Ecuador", 400),
        ("BO", "Bolivia", 300),
        ("PY", "Paraguay", 300),
        ("UY", "Uruguay", 300),
        ("CR", "Costa Rica", 250),
        ("PA", "Panama", 250),
        ("DO", "Dominican Republic", 300),
        ("GT", "Guatemala", 250),
        ("HN", "Honduras", 200),
        ("SV", "El Salvador", 200),
        ("NI", "Nicaragua", 200),
        ("CU", "Cuba", 300),
    )

    # Country mapping: ISO code -> {"name": display name, "target": int}.
    COUNTRY_MAPPING = {
        code: {"name": name, "target": target}
        for code, name, target in _COUNTRY_TABLE
    }
    # The table is only a construction aid; do not expose it on the class.
    del _COUNTRY_TABLE

    @classmethod
    def create_country_data(cls) -> Dict[str, CountryData]:
        """Build a fresh, zero-progress CountryData for every configured country."""
        country_data: Dict[str, CountryData] = {}
        for code, info in cls.COUNTRY_MAPPING.items():
            country_data[code] = CountryData(name=info["name"], target=info["target"])
        return country_data
# ============================================================================
# SERVICES - Business logic separated from presentation and data access
# ============================================================================
class ArgillaService:
    """Service wrapping the Argilla client and its webhook server."""

    def __init__(self, api_url: Optional[str] = None, api_key: Optional[str] = None):
        """Connect to Argilla.

        Args:
            api_url: Server URL; falls back to the ``ARGILLA_API_URL`` env var.
            api_key: API key; falls back to the ``ARGILLA_API_KEY`` env var.
        """
        self.api_url = api_url or os.getenv("ARGILLA_API_URL")
        self.api_key = api_key or os.getenv("ARGILLA_API_KEY")
        self.client = rg.Argilla(
            api_url=self.api_url,
            api_key=self.api_key,
        )
        # Server object the Gradio UI is later mounted onto.
        self.server = rg.get_webhook_server()

    def get_server(self):
        """Return the Argilla webhook server."""
        return self.server

    def get_client_base_url(self) -> str:
        """Return the client's base URL as a string, or ``"Not connected"``.

        The URL object is converted with ``str()`` so the declared ``str``
        return type holds. A missing or ``None`` ``http_client`` is reported
        as not connected instead of raising ``AttributeError``.
        """
        http_client = getattr(self.client, "http_client", None)
        if http_client is None:
            return "Not connected"
        return str(http_client.base_url)
class CountryMappingService:
    """Service for mapping between dataset names and country codes."""

    @staticmethod
    def _normalize(text: str) -> str:
        """Casefold, strip accents, and collapse ``_``/``-``/``.``/whitespace runs to one space."""
        import re
        import unicodedata

        decomposed = unicodedata.normalize("NFKD", text)
        without_accents = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
        return re.sub(r"[\s_\-.]+", " ", without_accents.casefold()).strip()

    @staticmethod
    def find_country_code_from_dataset(
        dataset_name: str,
        mapping: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """Extract a country code by finding a country name inside a dataset name.

        Both names are normalized first, so "costa_rica-news" matches
        "Costa Rica" and "México" matches "Mexico". A country name only
        matches when it is not glued to other letters, which avoids false hits
        such as "peru" inside "superuser". Digits may touch it ("peru2024").

        Args:
            dataset_name: The Argilla dataset name.
            mapping: ``{code: {"name": ...}}`` to search. Defaults to
                ``Config.COUNTRY_MAPPING``. The first match in mapping order wins.

        Returns:
            The matching country code, or ``None`` if no country matches or
            ``dataset_name`` is empty.
        """
        import re

        if not dataset_name:
            return None
        if mapping is None:
            mapping = Config.COUNTRY_MAPPING
        haystack = CountryMappingService._normalize(dataset_name)
        for code, data in mapping.items():
            needle = CountryMappingService._normalize(data["name"])
            if needle and re.search(rf"(?<![a-z]){re.escape(needle)}(?![a-z])", haystack):
                return code
        return None
# ============================================================================
# UI COMPONENTS - Presentation layer separated from business logic
# ============================================================================
class MapVisualization:
    """Component for the D3.js choropleth map of annotation progress."""

    # Token in _D3_SCRIPT_TEMPLATE that create_d3_script() replaces with the
    # progress data (any JavaScript expression evaluating to the data object).
    _DATA_PLACEHOLDER = "__PROGRESS_DATA__"

    # Source of an async, zero-argument JavaScript arrow function that renders
    # the map into #map-container. It is a plain string rather than an
    # f-string, so JS braces and template literals need no escaping.
    _D3_SCRIPT_TEMPLATE = """
async () => {
    // Each render takes a ticket; if a newer render starts while this one is
    // still waiting on the network, this one stops instead of drawing stale data.
    globalThis.__progressMapRenderId = (globalThis.__progressMapRenderId || 0) + 1;
    const renderId = globalThis.__progressMapRenderId;

    // Load a script unless the global it defines already exists, so repeated
    // renders do not append duplicate <script> tags to the document head.
    const loadScript = (src, globalName) => new Promise((resolve, reject) => {
        if (globalThis[globalName]) {
            resolve();
            return;
        }
        const script = document.createElement("script");
        script.src = src;
        script.onload = resolve;
        script.onerror = reject;
        document.head.appendChild(script);
    });
    await loadScript("https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js", "d3");
    console.log("D3 loaded successfully");

    // The progress data passed from Python, keyed by ISO alpha-2 code
    const progressData = __PROGRESS_DATA__ || {};

    // Index tracked countries by display name as a fallback for GeoJSON
    // features that carry no `iso_a2` property.
    const byName = {};
    for (const entry of Object.values(progressData)) {
        if (entry && entry.name) {
            byName[entry.name] = entry;
        }
    }
    const lookup = d => {
        const props = d.properties || {};
        return progressData[props.iso_a2] || byName[props.name];
    };

    // Load the world GeoJSON before touching the DOM, so the loading message
    // stays visible until there is something to draw.
    const response = await fetch("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson");
    if (!response.ok) {
        console.error(`Failed to load world GeoJSON: HTTP ${response.status}`);
        return;
    }
    const data = await response.json();
    if (renderId !== globalThis.__progressMapRenderId) {
        return; // superseded by a newer render
    }

    // Set up the SVG container
    const mapContainer = document.getElementById("map-container");
    if (!mapContainer) {
        console.warn("#map-container not found; skipping map render");
        return;
    }
    mapContainer.innerHTML = ""; // Clear loading message or a previous map
    const width = mapContainer.clientWidth;
    const height = 600;
    const svg = d3.select(mapContainer)
        .append("svg")
        .attr("width", width)
        .attr("height", height)
        .attr("viewBox", `0 0 ${width} ${height}`)
        .style("background-color", "#111");

    // Define color scale
    const colorScale = d3.scaleLinear()
        .domain([0, 100])
        .range(["#4a1942", "#f32b7b"]);

    // Set up projection focused on Latin America and Spain
    const projection = d3.geoMercator()
        .center([-60, 0])
        .scale(width / 5)
        .translate([width / 2, height / 2]);
    const path = d3.geoPath().projection(projection);

    // Tooltip setup
    const tooltip = d3.select("#tooltip");

    // Country shapes live in their own group so later updates select exactly
    // these paths and never other SVG paths on the page.
    const countries = svg.append("g").attr("class", "countries");
    countries.selectAll("path")
        .data(data.features)
        .enter()
        .append("path")
        .attr("d", path)
        .attr("stroke", "#f32b7b")
        .attr("stroke-width", 1)
        .attr("fill", d => {
            const entry = lookup(d);
            return entry ? colorScale(entry.percent) : "#2d3748"; // gray for non-tracked countries
        })
        .on("mouseover", function(event, d) {
            const entry = lookup(d);
            d3.select(this)
                .attr("stroke", "#4a1942")
                .attr("stroke-width", 2);
            if (entry) {
                tooltip.style("opacity", 1)
                    .style("left", (event.pageX + 15) + "px")
                    .style("top", (event.pageY + 15) + "px")
                    .html(`
                        <strong>${entry.name}</strong><br/>
                        Documents: ${entry.count.toLocaleString()}/${entry.target.toLocaleString()}<br/>
                        Completion: ${entry.percent}%
                    `);
            }
        })
        .on("mousemove", function(event) {
            tooltip.style("left", (event.pageX + 15) + "px")
                .style("top", (event.pageY + 15) + "px");
        })
        .on("mouseout", function() {
            d3.select(this)
                .attr("stroke", "#f32b7b")
                .attr("stroke-width", 1);
            tooltip.style("opacity", 0);
        });

    // Add legend
    const legendWidth = Math.min(width - 40, 200);
    const legendHeight = 15;
    const legendX = width - legendWidth - 20;
    const legend = svg.append("g")
        .attr("class", "legend")
        .attr("transform", `translate(${legendX}, 30)`);

    // Create gradient for legend
    const defs = svg.append("defs");
    const gradient = defs.append("linearGradient")
        .attr("id", "dataGradient")
        .attr("x1", "0%")
        .attr("y1", "0%")
        .attr("x2", "100%")
        .attr("y2", "0%");
    gradient.append("stop")
        .attr("offset", "0%")
        .attr("stop-color", "#4a1942");
    gradient.append("stop")
        .attr("offset", "100%")
        .attr("stop-color", "#f32b7b");

    // Add legend title
    legend.append("text")
        .attr("x", legendWidth / 2)
        .attr("y", -10)
        .attr("text-anchor", "middle")
        .attr("font-size", "12px")
        .attr("fill", "#f1f5f9")
        .text("Annotation Progress");

    // Add legend rectangle
    legend.append("rect")
        .attr("width", legendWidth)
        .attr("height", legendHeight)
        .attr("rx", 2)
        .attr("ry", 2)
        .style("fill", "url(#dataGradient)");

    // Add legend labels
    legend.append("text")
        .attr("x", 0)
        .attr("y", legendHeight + 15)
        .attr("text-anchor", "start")
        .attr("font-size", "10px")
        .attr("fill", "#94a3b8")
        .text("0%");
    legend.append("text")
        .attr("x", legendWidth / 2)
        .attr("y", legendHeight + 15)
        .attr("text-anchor", "middle")
        .attr("font-size", "10px")
        .attr("fill", "#94a3b8")
        .text("50%");
    legend.append("text")
        .attr("x", legendWidth)
        .attr("y", legendHeight + 15)
        .attr("text-anchor", "end")
        .attr("font-size", "10px")
        .attr("fill", "#94a3b8")
        .text("100%");

    // Handle window resize. Drop the handler from any earlier render first so
    // listeners do not pile up, and only touch this map's own elements.
    if (globalThis.resizeMap) {
        window.removeEventListener("resize", globalThis.resizeMap);
    }
    globalThis.resizeMap = () => {
        const newWidth = mapContainer.clientWidth;
        svg.attr("width", newWidth)
            .attr("viewBox", `0 0 ${newWidth} ${height}`);
        projection.scale(newWidth / 5)
            .translate([newWidth / 2, height / 2]);
        countries.selectAll("path").attr("d", path);
        const newLegendWidth = Math.min(newWidth - 40, 200);
        legend.attr("transform", `translate(${newWidth - newLegendWidth - 20}, 30)`);
    };
    window.addEventListener("resize", globalThis.resizeMap);
}
"""

    @staticmethod
    def create_map_html() -> str:
        """Create the initial HTML container for the map."""
        return """
<div id="map-container" style="width:100%; height:600px; position:relative; background-color:#111;">
<div style="display:flex; justify-content:center; align-items:center; height:100%; color:white; font-family:sans-serif;">
Loading map visualization...
</div>
</div>
<div id="tooltip" style="position:absolute; background-color:rgba(0,0,0,0.8); border-radius:5px; padding:8px; color:white; font-size:12px; pointer-events:none; opacity:0; transition:opacity 0.3s;"></div>
"""

    @staticmethod
    def create_d3_script(progress_data: Any) -> str:
        """Create the JS source of an async arrow function that renders the map.

        Args:
            progress_data: Either a string, inserted verbatim as a JavaScript
                expression (e.g. the output of ``ApplicationState.to_json()``),
                or a JSON-serializable object such as ``ApplicationState.to_dict()``.
                Objects are serialized with ``json.dumps``, not ``repr``, so
                ``True``/``None`` become valid JS. ``</`` is escaped so the
                result is safe to embed inside a ``<script>`` element.

        Returns:
            JavaScript source of the form ``async () => { ... }``.
        """
        if isinstance(progress_data, str):
            # An empty expression would be a JS syntax error; render no data.
            data_js = progress_data.strip() or "null"
        else:
            data_js = json.dumps(progress_data).replace("</", "<\\/")
        return MapVisualization._D3_SCRIPT_TEMPLATE.replace(
            MapVisualization._DATA_PLACEHOLDER, data_js
        )
# ============================================================================
# APPLICATION FACTORY - Creates and configures the application
# ============================================================================
class ApplicationFactory:
    """Factory that builds and wires together the application components."""

    @classmethod
    def create_app_state(cls) -> ApplicationState:
        """Create the application state seeded with sample progress.

        The seeded counts (30% of target for MX, AR, CO and ES; 50% for BR;
        70% for CL) are hard-coded demo values, not data read from Argilla.
        """
        state = ApplicationState(countries=Config.create_country_data())
        for code in ["MX", "AR", "CO", "ES"]:
            sample_count = int(state.countries[code].target * 0.3)
            state.update_country_progress(code, sample_count)
        state.update_country_progress("BR", int(state.countries["BR"].target * 0.5))
        state.update_country_progress("CL", int(state.countries["CL"].target * 0.7))
        return state

    @classmethod
    def create_argilla_service(cls) -> ArgillaService:
        """Create the Argilla service using its environment-variable defaults."""
        return ArgillaService()

    @staticmethod
    def cleanup_existing_webhooks(argilla_client):
        """Delete every existing webhook whose URL mentions ``handle_response_created``.

        Best effort: failures are logged as warnings and never raised, so a
        cleanup problem does not prevent the app from starting. All matching
        webhooks are removed, not just the first, so stale duplicates from
        earlier runs do not linger.
        """
        try:
            existing_webhooks = argilla_client.webhooks.list()
        except Exception as e:
            logger.warning(f"Could not clean up webhooks: {e}")
            return
        for webhook in existing_webhooks:
            # `or ""` guards against a url attribute that is present but None.
            url = getattr(webhook, "url", None) or ""
            if "handle_response_created" not in url:
                continue
            try:
                logger.info(f"Removing existing webhook: {webhook.id}")
                argilla_client.webhooks.delete(webhook.id)
            except Exception as e:
                logger.warning(f"Could not clean up webhooks: {e}")

    @classmethod
    def create_webhook_handler(cls, app_state: ApplicationState) -> Callable:
        """Return the ``response.created`` handler, wrapped by ``webhook_listener``.

        For each response, the handler queues an event, derives the country
        from the dataset name, and increments that country's progress.
        Exceptions are logged and queued as ``error`` events, never re-raised.
        """
        country_service = CountryMappingService()

        # NOTE(review): parameter names (including `type`, which shadows the
        # builtin) are kept as-is; the listener appears to pass event fields by
        # name — confirm against the Argilla webhook docs before renaming.
        @webhook_listener(events=["response.created"])
        async def handle_response_created(response, type, timestamp):
            try:
                logger.info(f"Received webhook event: {type} at {timestamp}")
                app_state.add_event(Event(
                    event_type=type,
                    timestamp=str(timestamp)
                ))

                record = response.record
                dataset_name = record.dataset.name
                logger.info(f"Processing response for dataset: {dataset_name}")

                country_code = country_service.find_country_code_from_dataset(dataset_name)
                if not country_code:
                    logger.info(f"No tracked country matched dataset: {dataset_name}")
                    return

                if app_state.increment_country_progress(country_code):
                    country_data = app_state.countries[country_code]
                    logger.info(
                        f"Updated progress for {country_data.name}: "
                        f"{country_data.count}/{country_data.target} ({country_data.percent}%)"
                    )
            except Exception as e:
                logger.error(f"Error in webhook handler: {e}", exc_info=True)
                app_state.add_event(Event(
                    event_type="error",
                    error=str(e)
                ))

        return handle_response_created

    @classmethod
    def create_ui(cls, argilla_service: ArgillaService, app_state: ApplicationState):
        """Build the Gradio Blocks UI: map, summary stats, country details and event feed.

        Gradio's ``gr.HTML`` inserts its value via innerHTML, and browsers do
        not run ``<script>`` tags inserted that way. The D3 map is therefore
        rendered through Gradio's ``js`` hooks: once on page load and again
        whenever the hidden ``map_data`` JSON changes.
        """
        demo = gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple"))
        with demo:
            argilla_server = argilla_service.get_client_base_url()
            with gr.Row():
                gr.Markdown(
                    "# Latin America & Spain Annotation Progress Map\n"
                    f"### Connected to Argilla server: {argilla_server}\n"
                    "This dashboard visualizes annotation progress across Latin America and Spain."
                )

            with gr.Row():
                with gr.Column(scale=2):
                    # Container that the D3 script draws into.
                    gr.HTML(MapVisualization.create_map_html(), label="Annotation Progress Map")
                    # Hidden element holding the map data; changes trigger a re-render.
                    map_data = gr.JSON(value=app_state.to_json(), visible=False)

                with gr.Column(scale=1):
                    total_docs, avg_completion, countries_over_50 = app_state.get_stats()
                    total_docs_ui = gr.Number(value=total_docs, label="Total Documents", interactive=False)
                    avg_completion_ui = gr.Number(value=avg_completion, label="Average Completion (%)", interactive=False)
                    countries_over_50_ui = gr.Number(value=countries_over_50, label="Countries Over 50%", interactive=False)

                    country_selector = gr.Dropdown(
                        choices=[f"{data.name} ({code})" for code, data in app_state.countries.items()],
                        label="Select Country"
                    )
                    country_progress = gr.JSON(label="Country Progress", value={})
                    # Most recent event taken from the queue by the timer below.
                    events_json = gr.JSON(label="Latest Event", value={})

                    refresh_btn = gr.Button("Refresh Map")

            def update_map():
                return app_state.to_json()

            def update_country_details(country_selection):
                if not country_selection:
                    return {}
                # Selection format: "Country Name (CODE)"
                code = country_selection.split("(")[-1].replace(")", "").strip()
                country = app_state.countries.get(code)
                return asdict(country) if country is not None else {}

            def update_events():
                event = app_state.get_next_event()
                total, avg, over_50 = app_state.get_stats()
                if event.get("event_type") == "progress_update":
                    map_value = app_state.to_json()
                else:
                    # gr.update() with no arguments leaves map_data unchanged;
                    # returning None would wipe the map data instead.
                    map_value = gr.update()
                return event, map_value, total, avg, over_50

            # JS event handler: receives map_data's value (an object, or a
            # JSON string) and runs the D3 render function built around it.
            render_map_js = (
                "async (data) => {\n"
                "    const renderMap = "
                + MapVisualization.create_d3_script(
                    "(typeof data === 'string' ? JSON.parse(data) : data)"
                )
                + ";\n"
                "    await renderMap();\n"
                "}"
            )

            refresh_btn.click(
                fn=update_map,
                inputs=None,
                outputs=map_data
            )
            country_selector.change(
                fn=update_country_details,
                inputs=[country_selector],
                outputs=[country_progress]
            )
            map_data.change(fn=None, inputs=[map_data], outputs=None, js=render_map_js)
            demo.load(fn=None, inputs=[map_data], outputs=None, js=render_map_js)

            # Poll once per second for new events and refresh the stats.
            gr.Timer(1, active=True).tick(
                update_events,
                outputs=[events_json, map_data, total_docs_ui, avg_completion_ui, countries_over_50_ui]
            )

        return demo
# ============================================================================
# MAIN APPLICATION - Entry point and initialization
# ============================================================================
def create_application():
    """Assemble state, Argilla service, webhook handler and UI; return the server.

    The returned object is the Argilla webhook server with the Gradio UI
    mounted at "/", ready to be served by an ASGI server such as uvicorn.
    """
    state = ApplicationFactory.create_app_state()
    service = ApplicationFactory.create_argilla_service()

    # Remove stale webhooks before the handler below is created.
    ApplicationFactory.cleanup_existing_webhooks(service.client)
    handler = ApplicationFactory.create_webhook_handler(state)  # kept for its registration side effect

    ui = ApplicationFactory.create_ui(service, state)

    app_server = service.get_server()
    gr.mount_gradio_app(app_server, ui, path="/")
    return app_server
# Application entry point
if __name__ == "__main__":
    import uvicorn

    # Build the webhook server (with the Gradio UI mounted) and serve it on
    # all interfaces at port 7860.
    uvicorn.run(create_application(), host="0.0.0.0", port=7860)