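# Page banner captured together with this file ("Spaces: Sleeping"): Hugging Face
# Space status text, not part of the program.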
import json
import logging
import os
from dataclasses import asdict, dataclass, field
from queue import Empty, Queue
from typing import Any, Callable, Dict, List, Optional, Tuple

import argilla as rg
import gradio as gr
from argilla.webhooks import webhook_listener
# Set up logging: INFO level, with timestamp, logger name and level in each line.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger shared by the services and the webhook handler below.
logger = logging.getLogger(__name__)
# ============================================================================
# DATA MODELS - Clear definition of data structures
# ============================================================================
@dataclass
class CountryData:
    """Data model for country information and annotation progress.

    Declared as a dataclass so instances can be built with keyword arguments
    and serialized with ``dataclasses.asdict``.
    """

    name: str  # Display name of the country
    target: int  # Number of annotations aimed for
    count: int = 0  # Annotations recorded so far
    percent: int = 0  # Completion percentage derived from count/target, capped at 100

    def update_progress(self, new_count: Optional[int] = None) -> "CountryData":
        """Update the progress percentage based on count/target.

        Args:
            new_count: When given, replaces ``count`` before the percentage
                is recomputed.

        Returns:
            This instance, so calls can be chained.
        """
        if new_count is not None:
            self.count = new_count
        if self.target > 0:
            self.percent = min(100, int((self.count / self.target) * 100))
        else:
            # A non-positive target cannot be divided by; report no progress
            # instead of raising ZeroDivisionError.
            self.percent = 0
        return self
@dataclass
class Event:
    """Data model for events in the system.

    Declared as a dataclass so events can be built with keyword arguments and
    converted to plain dictionaries with ``dataclasses.asdict``.
    """

    event_type: str  # Kind of event, e.g. "progress_update" or "error"
    timestamp: str = ""  # Stringified time of the event, when known
    country: str = ""  # Country name the event refers to, if any
    count: int = 0  # Annotation count carried by progress events
    percent: int = 0  # Completion percentage carried by progress events
    error: str = ""  # Error message carried by error events
@dataclass
class ApplicationState:
    """Central state management for the application.

    Holds the per-country progress records and a FIFO queue of event
    dictionaries that the UI drains one at a time.
    """

    # Country records keyed by country code.
    countries: Dict[str, "CountryData"] = field(default_factory=dict)
    # Pending events, stored as plain dictionaries (see ``add_event``).
    events: Queue = field(default_factory=Queue)

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a serializable dictionary for the UI."""
        return {code: asdict(data) for code, data in self.countries.items()}

    def to_json(self) -> str:
        """Convert state to JSON for the UI."""
        return json.dumps(self.to_dict())

    def add_event(self, event: "Event") -> None:
        """Add an event to the queue, stored as a plain dictionary."""
        self.events.put(asdict(event))

    def get_next_event(self) -> Dict[str, Any]:
        """Get the next event from the queue.

        Returns:
            The oldest queued event, or an empty dict when nothing is pending.
        """
        # get_nowait() never blocks; checking empty() and then calling get()
        # could hang if another consumer drained the queue in between.
        try:
            return self.events.get_nowait()
        except Empty:
            return {}

    def update_country_progress(self, country_code: str, count: Optional[int] = None) -> bool:
        """Update a country's annotation progress.

        Args:
            country_code: Key into ``countries``.
            count: New annotation count; when omitted only the percentage is
                recomputed from the current count.

        Returns:
            True when the country exists and was updated, False otherwise.
        """
        country = self.countries.get(country_code)
        if country is None:
            return False
        country.update_progress(count)
        # Create and add a progress update event
        self.add_event(Event(
            event_type="progress_update",
            country=country.name,
            count=country.count,
            percent=country.percent,
        ))
        return True

    def increment_country_progress(self, country_code: str) -> bool:
        """Increment a country's annotation count by 1.

        Returns:
            True when the country exists and was updated, False otherwise.
        """
        country = self.countries.get(country_code)
        if country is None:
            return False
        return self.update_country_progress(country_code, country.count + 1)

    def get_stats(self) -> Tuple[int, float, int]:
        """Calculate overall statistics.

        Returns:
            A tuple of (total annotation count, average completion percentage,
            number of countries at 50% or more).
        """
        records = list(self.countries.values())
        total = sum(data.count for data in records)
        percentages = [data.percent for data in records]
        avg = sum(percentages) / len(percentages) if percentages else 0
        countries_50_plus = sum(1 for p in percentages if p >= 50)
        return total, avg, countries_50_plus
# ============================================================================
# CONFIGURATION - Separated from business logic
# ============================================================================
class Config:
    """Configuration for the application."""

    # Country mapping (ISO code to name and target)
    COUNTRY_MAPPING = {
        "MX": {"name": "Mexico", "target": 1000},
        "AR": {"name": "Argentina", "target": 800},
        "CO": {"name": "Colombia", "target": 700},
        "CL": {"name": "Chile", "target": 600},
        "PE": {"name": "Peru", "target": 600},
        "ES": {"name": "Spain", "target": 1200},
        "BR": {"name": "Brazil", "target": 1000},
        "VE": {"name": "Venezuela", "target": 500},
        "EC": {"name": "Ecuador", "target": 400},
        "BO": {"name": "Bolivia", "target": 300},
        "PY": {"name": "Paraguay", "target": 300},
        "UY": {"name": "Uruguay", "target": 300},
        "CR": {"name": "Costa Rica", "target": 250},
        "PA": {"name": "Panama", "target": 250},
        "DO": {"name": "Dominican Republic", "target": 300},
        "GT": {"name": "Guatemala", "target": 250},
        "HN": {"name": "Honduras", "target": 200},
        "SV": {"name": "El Salvador", "target": 200},
        "NI": {"name": "Nicaragua", "target": 200},
        "CU": {"name": "Cuba", "target": 300},
    }

    # Called on the class itself (``Config.create_country_data()``), so it has
    # to be a classmethod for ``cls`` to be bound.
    @classmethod
    def create_country_data(cls) -> Dict[str, "CountryData"]:
        """Create CountryData objects from the mapping.

        Returns:
            A dict keyed by country code, one ``CountryData`` per entry of
            ``COUNTRY_MAPPING`` with its name and target filled in.
        """
        return {
            code: CountryData(name=data["name"], target=data["target"])
            for code, data in cls.COUNTRY_MAPPING.items()
        }
# ============================================================================
# SERVICES - Business logic separated from presentation and data access
# ============================================================================
class ArgillaService:
    """Service for interacting with Argilla."""

    def __init__(self, api_url: Optional[str] = None, api_key: Optional[str] = None):
        """Initialize the Argilla service.

        Args:
            api_url: Argilla server URL; falls back to the ``ARGILLA_API_URL``
                environment variable when omitted.
            api_key: Argilla API key; falls back to the ``ARGILLA_API_KEY``
                environment variable when omitted.
        """
        self.api_url = api_url or os.getenv("ARGILLA_API_URL")
        self.api_key = api_key or os.getenv("ARGILLA_API_KEY")
        self.client = rg.Argilla(
            api_url=self.api_url,
            api_key=self.api_key,
        )
        self.server = rg.get_webhook_server()

    def get_server(self):
        """Get the Argilla webhook server."""
        return self.server

    def get_client_base_url(self) -> str:
        """Get the base URL of the Argilla client.

        Returns:
            The client's base URL as a string, or ``"Not connected"`` when the
            client exposes no HTTP client.
        """
        http_client = getattr(self.client, "http_client", None)
        if http_client is None:
            return "Not connected"
        # base_url may be a URL object rather than a plain string; convert so
        # the declared return type holds.
        return str(http_client.base_url)
class CountryMappingService:
    """Service for mapping between dataset names and country codes."""

    # The method takes no ``self`` but is called on an instance, so it must be
    # a staticmethod to keep ``dataset_name`` as the first argument.
    @staticmethod
    def find_country_code_from_dataset(
        dataset_name: str,
        mapping: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """
        Try to extract a country code from a dataset name by matching
        country names in the dataset name.

        Args:
            dataset_name: Dataset name to inspect; matching is case-insensitive.
            mapping: Country code -> ``{"name": ...}`` table to search.
                Defaults to ``Config.COUNTRY_MAPPING``.

        Returns:
            The code of the first country (in mapping order) whose name occurs
            in the dataset name, or None when there is no match or the name
            is empty.
        """
        if not dataset_name:
            return None
        if mapping is None:
            mapping = Config.COUNTRY_MAPPING
        dataset_name_lower = dataset_name.lower()
        for code, data in mapping.items():
            if data["name"].lower() in dataset_name_lower:
                return code
        return None
# ============================================================================
# UI COMPONENTS - Presentation layer separated from business logic
# ============================================================================
class MapVisualization:
    """Component for D3.js map visualization."""

    @staticmethod
    def create_map_html() -> str:
        """Create the initial HTML container for the map.

        Returns:
            Markup holding the ``#map-container`` placeholder (with a loading
            message) and the ``#tooltip`` element the D3 script writes into.
        """
        return """
        <div id="map-container" style="width:100%; height:600px; position:relative; background-color:#111;">
        <div style="display:flex; justify-content:center; align-items:center; height:100%; color:white; font-family:sans-serif;">
        Loading map visualization...
        </div>
        </div>
        <div id="tooltip" style="position:absolute; background-color:rgba(0,0,0,0.8); border-radius:5px; padding:8px; color:white; font-size:12px; pointer-events:none; opacity:0; transition:opacity 0.3s;"></div>
        """

    @staticmethod
    def create_d3_script(progress_data: str) -> str:
        """Create the D3.js script for rendering the map.

        Args:
            progress_data: JSON text mapping country codes to objects with
                ``name``, ``count``, ``target`` and ``percent``. An already
                parsed value (dict/list) is serialized with ``json.dumps``;
                None or an empty string becomes an empty object.

        Returns:
            JavaScript source of an ``async () => {...}`` function expression
            that loads D3, draws the map into ``#map-container`` and installs
            a window resize handler.
        """
        if progress_data is None or progress_data == "":
            progress_data = "{}"
        elif not isinstance(progress_data, str):
            # Interpolating a Python repr would not be reliable JavaScript
            # (None/True/False); emit real JSON instead.
            progress_data = json.dumps(progress_data)
        # NOTE(review): the CDN URLs in the captured source were mangled to
        # "[email protected]"; they are restored here pinned to the major versions
        # d3@7 and topojson-client@3 - confirm the exact versions intended.
        return f"""
        async () => {{
            // Load a script tag and wait until it has finished (or failed) loading
            const loadScript = (src) => new Promise((resolve, reject) => {{
                const el = document.createElement("script");
                el.src = src;
                el.onload = resolve;
                el.onerror = reject;
                document.head.appendChild(el);
            }});
            // Load D3.js modules (only once per page)
            if (!globalThis.d3) {{
                await loadScript("https://cdn.jsdelivr.net/npm/d3@7/dist/d3.min.js");
            }}
            console.log("D3 loaded successfully");
            // Load topojson
            if (!globalThis.topojson) {{
                await loadScript("https://cdn.jsdelivr.net/npm/topojson-client@3/dist/topojson-client.min.js");
            }}
            console.log("TopoJSON loaded successfully");
            // The progress data passed from Python
            const progressData = {progress_data};
            // Secondary index by country name, used when a map feature
            // carries no iso_a2 property
            const progressByName = {{}};
            for (const entry of Object.values(progressData)) {{
                if (entry && entry.name) {{
                    progressByName[entry.name] = entry;
                }}
            }}
            // NOTE(review): the GeoJSON below may expose only a name property
            // and no iso_a2 - confirm; the name fallback covers that case.
            const findProgress = (d) => {{
                const props = d.properties || {{}};
                return progressData[props.iso_a2] || progressByName[props.name];
            }};
            // Set up the SVG container
            const mapContainer = document.getElementById('map-container');
            mapContainer.innerHTML = ''; // Clear loading message
            const width = mapContainer.clientWidth;
            const height = 600;
            const svg = d3.select("#map-container")
                .append("svg")
                .attr("width", width)
                .attr("height", height)
                .attr("viewBox", `0 0 ${{width}} ${{height}}`)
                .style("background-color", "#111");
            // Define color scale
            const colorScale = d3.scaleLinear()
                .domain([0, 100])
                .range(["#4a1942", "#f32b7b"]);
            // Set up projection focused on Latin America and Spain
            const projection = d3.geoMercator()
                .center([-60, 0])
                .scale(width / 5)
                .translate([width / 2, height / 2]);
            const path = d3.geoPath().projection(projection);
            // Tooltip setup
            const tooltip = d3.select("#tooltip");
            // Load the world GeoJSON data
            const response = await fetch("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson");
            const data = await response.json();
            // Draw the map
            svg.selectAll("path")
                .data(data.features)
                .enter()
                .append("path")
                .attr("d", path)
                .attr("stroke", "#f32b7b")
                .attr("stroke-width", 1)
                .attr("fill", d => {{
                    const entry = findProgress(d);
                    if (entry) {{
                        return colorScale(entry.percent);
                    }}
                    return "#2d3748"; // Default gray for non-tracked countries
                }})
                .on("mouseover", function(event, d) {{
                    const entry = findProgress(d);
                    d3.select(this)
                        .attr("stroke", "#4a1942")
                        .attr("stroke-width", 2);
                    if (entry) {{
                        tooltip.style("opacity", 1)
                            .style("left", (event.pageX + 15) + "px")
                            .style("top", (event.pageY + 15) + "px")
                            .html(`
                                <strong>${{entry.name}}</strong><br/>
                                Documents: ${{entry.count.toLocaleString()}}/${{entry.target.toLocaleString()}}<br/>
                                Completion: ${{entry.percent}}%
                            `);
                    }}
                }})
                .on("mousemove", function(event) {{
                    tooltip.style("left", (event.pageX + 15) + "px")
                        .style("top", (event.pageY + 15) + "px");
                }})
                .on("mouseout", function() {{
                    d3.select(this)
                        .attr("stroke", "#f32b7b")
                        .attr("stroke-width", 1);
                    tooltip.style("opacity", 0);
                }});
            // Add legend
            const legendWidth = Math.min(width - 40, 200);
            const legendHeight = 15;
            const legendX = width - legendWidth - 20;
            const legend = svg.append("g")
                .attr("class", "legend")
                .attr("transform", `translate(${{legendX}}, 30)`);
            // Create gradient for legend
            const defs = svg.append("defs");
            const gradient = defs.append("linearGradient")
                .attr("id", "dataGradient")
                .attr("x1", "0%")
                .attr("y1", "0%")
                .attr("x2", "100%")
                .attr("y2", "0%");
            gradient.append("stop")
                .attr("offset", "0%")
                .attr("stop-color", "#4a1942");
            gradient.append("stop")
                .attr("offset", "100%")
                .attr("stop-color", "#f32b7b");
            // Add legend title
            legend.append("text")
                .attr("x", legendWidth / 2)
                .attr("y", -10)
                .attr("text-anchor", "middle")
                .attr("font-size", "12px")
                .attr("fill", "#f1f5f9")
                .text("Annotation Progress");
            // Add legend rectangle
            legend.append("rect")
                .attr("width", legendWidth)
                .attr("height", legendHeight)
                .attr("rx", 2)
                .attr("ry", 2)
                .style("fill", "url(#dataGradient)");
            // Add legend labels
            legend.append("text")
                .attr("x", 0)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "start")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("0%");
            legend.append("text")
                .attr("x", legendWidth / 2)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "middle")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("50%");
            legend.append("text")
                .attr("x", legendWidth)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "end")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("100%");
            // Handle window resize; drop the handler of a previous render
            // first so handlers do not pile up on every refresh
            if (globalThis.resizeMap) {{
                window.removeEventListener('resize', globalThis.resizeMap);
            }}
            globalThis.resizeMap = () => {{
                const width = mapContainer.clientWidth;
                // Update SVG dimensions (this map only, not other SVGs on the page)
                svg.attr("width", width)
                    .attr("viewBox", `0 0 ${{width}} ${{height}}`);
                // Update projection
                projection.scale(width / 5)
                    .translate([width / 2, height / 2]);
                // Update paths of this map only
                svg.selectAll("path").attr("d", path);
                // Update legend position
                const legendWidth = Math.min(width - 40, 200);
                const legendX = width - legendWidth - 20;
                legend.attr("transform", `translate(${{legendX}}, 30)`);
            }};
            window.addEventListener('resize', globalThis.resizeMap);
        }}
        """
# ============================================================================
# APPLICATION FACTORY - Creates and configures the application
# ============================================================================
class ApplicationFactory:
    """Factory for creating the application components.

    All methods are invoked on the class itself, so they are declared as
    class/static methods.
    """

    @classmethod
    def create_app_state(cls) -> "ApplicationState":
        """Create and initialize the application state.

        Returns:
            A state object holding every configured country, pre-filled with
            sample progress for a few of them.
        """
        state = ApplicationState(countries=Config.create_country_data())
        # Initialize with some sample data: fraction of the target per country.
        sample_fractions = {
            "MX": 0.3,
            "AR": 0.3,
            "CO": 0.3,
            "ES": 0.3,
            "BR": 0.5,
            "CL": 0.7,
        }
        for code, fraction in sample_fractions.items():
            state.update_country_progress(code, int(state.countries[code].target * fraction))
        return state

    @classmethod
    def create_argilla_service(cls) -> "ArgillaService":
        """Create the Argilla service."""
        return ArgillaService()

    @staticmethod
    def cleanup_existing_webhooks(argilla_client) -> None:
        """Clean up existing webhooks to avoid warnings.

        Best effort: any failure is logged as a warning and otherwise ignored.

        Args:
            argilla_client: Client object exposing ``webhooks.list()`` and
                ``webhooks.delete(id)``.
        """
        try:
            # Get existing webhooks
            existing_webhooks = argilla_client.webhooks.list()
            # Look for our webhook
            for webhook in existing_webhooks:
                # Treat a missing or None url as "no match".
                url = getattr(webhook, 'url', '') or ''
                if "handle_response_created" in url:
                    logger.info(f"Removing existing webhook: {webhook.id}")
                    argilla_client.webhooks.delete(webhook.id)
                    break
        except Exception as e:
            logger.warning(f"Could not clean up webhooks: {e}")

    @classmethod
    def create_webhook_handler(cls, app_state: "ApplicationState") -> Callable:
        """Create the webhook handler function.

        Args:
            app_state: State object that receives events and progress updates.

        Returns:
            The registered ``handle_response_created`` listener.
        """
        country_service = CountryMappingService()

        # Define the webhook handler.
        # NOTE(review): the decorator was missing from the captured source even
        # though ``webhook_listener`` is imported and the handler is described
        # as being registered; "response.created" is inferred from the handler
        # name - confirm the event list.
        @webhook_listener(events=["response.created"])
        async def handle_response_created(response, type, timestamp):
            # The parameter names (including ``type``) are kept exactly as in
            # the original handler signature.
            try:
                # Log the event
                logger.info(f"Received webhook event: {type} at {timestamp}")
                # Add basic event to the queue
                app_state.add_event(Event(
                    event_type=type,
                    timestamp=str(timestamp)
                ))
                # Extract dataset name
                record = response.record
                dataset_name = record.dataset.name
                logger.info(f"Processing response for dataset: {dataset_name}")
                # Find country code from dataset name
                country_code = country_service.find_country_code_from_dataset(dataset_name)
                # Update country progress if found
                if country_code:
                    success = app_state.increment_country_progress(country_code)
                    if success:
                        country_data = app_state.countries[country_code]
                        logger.info(
                            f"Updated progress for {country_data.name}: "
                            f"{country_data.count}/{country_data.target} ({country_data.percent}%)"
                        )
            except Exception as e:
                # Boundary of the webhook call: log with traceback and surface
                # the failure to the UI as an error event instead of raising.
                logger.error(f"Error in webhook handler: {e}", exc_info=True)
                app_state.add_event(Event(
                    event_type="error",
                    error=str(e)
                ))

        return handle_response_created

    @classmethod
    def create_ui(cls, argilla_service: "ArgillaService", app_state: "ApplicationState"):
        """Create the Gradio UI.

        Args:
            argilla_service: Used to display the connected server URL.
            app_state: Source of map data, statistics and events.

        Returns:
            The configured ``gr.Blocks`` application.
        """
        # Create and configure the Gradio interface
        demo = gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple"))
        with demo:
            argilla_server = argilla_service.get_client_base_url()
            # Built line by line so the Markdown carries no source indentation.
            header_markdown = (
                "\n"
                "# Latin America & Spain Annotation Progress Map\n"
                f"### Connected to Argilla server: {argilla_server}\n"
                "This dashboard visualizes annotation progress across Latin America and Spain.\n"
            )
            with gr.Row():
                gr.Markdown(header_markdown)
            with gr.Row():
                with gr.Column(scale=2):
                    # Map visualization - empty at first
                    map_html = gr.HTML(MapVisualization.create_map_html(), label="Annotation Progress Map")
                    # Hidden element to store map data
                    map_data = gr.JSON(value=app_state.to_json(), visible=False)
                with gr.Column(scale=1):
                    # Overall statistics
                    total_docs, avg_completion, countries_over_50 = app_state.get_stats()
                    total_docs_ui = gr.Number(value=total_docs, label="Total Documents", interactive=False)
                    avg_completion_ui = gr.Number(value=avg_completion, label="Average Completion (%)", interactive=False)
                    countries_over_50_ui = gr.Number(value=countries_over_50, label="Countries Over 50%", interactive=False)
                    # Country details
                    country_selector = gr.Dropdown(
                        choices=[f"{data.name} ({code})" for code, data in app_state.countries.items()],
                        label="Select Country"
                    )
                    country_progress = gr.JSON(label="Country Progress", value={})
                    # Latest event drained from the queue; the timer below
                    # writes to it (it was referenced but never defined).
                    events_json = gr.JSON(label="Latest Event", value={})
                    # Refresh button
                    refresh_btn = gr.Button("Refresh Map")

            # UI interaction functions
            def update_map():
                return app_state.to_json()

            def update_country_details(country_selection):
                if not country_selection:
                    return {}
                # Extract the country code from the selection (format: "Country Name (CODE)")
                code = country_selection.split("(")[-1].replace(")", "").strip()
                if code in app_state.countries:
                    return asdict(app_state.countries[code])
                return {}

            def update_events():
                event = app_state.get_next_event()
                total, average, over_50 = app_state.get_stats()
                if event.get("event_type") == "progress_update":
                    # This will indirectly trigger a map refresh through the change event
                    map_update = app_state.to_json()
                else:
                    # Leave the stored map data untouched; returning None
                    # would overwrite it with an empty value on every tick.
                    map_update = gr.update()
                return event, map_update, total, average, over_50

            # Set up event handlers
            refresh_btn.click(
                fn=update_map,
                inputs=None,
                outputs=map_data
            )
            country_selector.change(
                fn=update_country_details,
                inputs=[country_selector],
                outputs=[country_progress]
            )

            # Alternative approach to load JavaScript without using _js parameter
            # Create a hidden HTML component to hold our script
            # NOTE(review): browsers generally do not execute <script> elements
            # inserted via innerHTML - confirm that this markup actually runs
            # in the deployed Gradio version.
            js_holder = gr.HTML("", visible=False)

            # When map_data is updated, create a script tag with our D3 code
            def create_script_tag(data):
                script_content = MapVisualization.create_d3_script(data)
                html = f"""
                <div id="js-executor">
                <script>
                (async () => {{
                const scriptFn = {script_content};
                await scriptFn();
                }})();
                </script>
                </div>
                """
                return html

            map_data.change(
                fn=create_script_tag,
                inputs=map_data,
                outputs=js_holder
            )
            # Use timer to check for new events and update stats
            gr.Timer(1, active=True).tick(
                update_events,
                outputs=[events_json, map_data, total_docs_ui, avg_completion_ui, countries_over_50_ui]
            )
            # Initialize D3 on page load using an initial script tag
            # NOTE(review): DOMContentLoaded may already have fired by the time
            # this component is attached - confirm the initial render happens.
            initial_map_script = gr.HTML(
                f"""
                <div id="initial-js-executor">
                <script>
                document.addEventListener('DOMContentLoaded', async () => {{
                const scriptFn = {MapVisualization.create_d3_script(app_state.to_json())};
                await scriptFn();
                }});
                </script>
                </div>
                """,
                visible=False
            )
        return demo
# ============================================================================
# MAIN APPLICATION - Entry point and initialization
# ============================================================================
def create_application():
    """Create and configure the complete application.

    Builds the state and the Argilla service, removes a stale webhook, creates
    the webhook handler and the Gradio UI, and mounts the UI on the webhook
    server at ``/``.

    Returns:
        The server object obtained from the Argilla service, with the Gradio
        app mounted on it.
    """
    # Create application components
    app_state = ApplicationFactory.create_app_state()
    argilla_service = ApplicationFactory.create_argilla_service()
    # Clean up existing webhooks
    ApplicationFactory.cleanup_existing_webhooks(argilla_service.client)
    # Create the webhook handler; its return value is not needed here.
    ApplicationFactory.create_webhook_handler(app_state)
    # Create the UI
    demo = ApplicationFactory.create_ui(argilla_service, app_state)
    # Mount the Gradio app to the FastAPI server
    server = argilla_service.get_server()
    gr.mount_gradio_app(server, demo, path="/")
    return server
# Application entry point
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Create the application
    server = create_application()
    # Start the server on all interfaces, port 7860
    uvicorn.run(server, host="0.0.0.0", port=7860)