Spaces:
Runtime error
Runtime error
| import os | |
| from queue import Queue | |
| import json | |
| import gradio as gr | |
| import argilla as rg | |
| from argilla.webhooks import webhook_listener | |
| from dataclasses import dataclass, field, asdict | |
| from typing import Dict, List, Optional, Tuple, Any, Callable | |
| import logging | |
# Configure process-wide logging once, then create this module's logger
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
| # ============================================================================ | |
| # DATA MODELS - Clear definition of data structures | |
| # ============================================================================ | |
@dataclass
class CountryData:
    """Annotation target and progress for a single country.

    Declared as a dataclass so instances can be built with keyword
    arguments and serialized with ``dataclasses.asdict``.
    """

    name: str  # human-readable country name
    target: int  # number of annotations that counts as 100%
    count: int = 0  # annotations recorded so far
    percent: int = 0  # completion as a whole percentage, kept within 0..100

    def update_progress(self, new_count: Optional[int] = None):
        """Recompute ``percent`` from ``count`` and ``target``.

        Args:
            new_count: When given, replaces ``count`` before recomputing.

        Returns:
            This instance, so calls can be chained.
        """
        if new_count is not None:
            self.count = new_count
        if self.target > 0:
            # Integer arithmetic avoids float truncation (29/100 -> 28%).
            self.percent = max(0, min(100, self.count * 100 // self.target))
        else:
            # No meaningful target: report no progress instead of dividing by zero.
            self.percent = 0
        return self
@dataclass
class Event:
    """A single event record placed on the application's event queue.

    Declared as a dataclass so events can be built with keyword arguments
    and converted to plain dictionaries with ``dataclasses.asdict``.
    """

    event_type: str  # e.g. "progress_update" or "error"
    timestamp: str = ""  # event time as text; empty when not provided
    country: str = ""  # country name for progress events
    count: int = 0  # annotation count for progress events
    percent: int = 0  # completion percentage for progress events
    error: str = ""  # error message for error events
@dataclass
class ApplicationState:
    """Central state for the application: per-country progress plus an event queue.

    Declared as a dataclass so the ``field(default_factory=...)`` defaults
    below produce a fresh dict and queue for every instance.
    """

    # Country code -> progress record
    countries: "Dict[str, CountryData]" = field(default_factory=dict)
    # Pending events, stored as plain dictionaries
    events: Queue = field(default_factory=Queue)

    def to_dict(self) -> Dict[str, Any]:
        """Return the country data as a dictionary of plain dictionaries."""
        return {code: asdict(data) for code, data in self.countries.items()}

    def to_json(self) -> str:
        """Return the country data serialized as a JSON string."""
        return json.dumps(self.to_dict())

    def add_event(self, event: "Event") -> None:
        """Convert ``event`` to a dictionary and append it to the queue."""
        self.events.put(asdict(event))

    def get_next_event(self) -> Dict[str, Any]:
        """Pop the oldest queued event, or return ``{}`` when none is queued.

        Uses a non-blocking get: checking ``empty()`` and then calling a
        blocking ``get()`` can hang if another consumer drains the queue
        between the two calls.
        """
        from queue import Empty

        try:
            return self.events.get_nowait()
        except Empty:
            return {}

    def update_country_progress(self, country_code: str, count: Optional[int] = None) -> bool:
        """Recompute a country's progress and queue a ``progress_update`` event.

        Args:
            country_code: Key into ``countries``.
            count: New annotation count; the current count is kept when None.

        Returns:
            True if the country exists and was updated, False otherwise.
        """
        country = self.countries.get(country_code)
        if country is None:
            return False

        country.update_progress(count)
        self.add_event(Event(
            event_type="progress_update",
            country=country.name,
            count=country.count,
            percent=country.percent,
        ))
        return True

    def increment_country_progress(self, country_code: str) -> bool:
        """Increase a country's annotation count by one.

        Returns:
            True if the country exists and was updated, False otherwise.
        """
        country = self.countries.get(country_code)
        if country is None:
            return False
        return self.update_country_progress(country_code, country.count + 1)

    def get_stats(self) -> Tuple[int, float, int]:
        """Compute overall statistics.

        Returns:
            A tuple of (sum of all counts, mean completion percentage,
            number of countries at 50% or more). The mean is 0 when there
            are no countries.
        """
        records = list(self.countries.values())
        total = sum(data.count for data in records)
        percentages = [data.percent for data in records]
        avg = sum(percentages) / len(percentages) if percentages else 0
        countries_50_plus = sum(1 for p in percentages if p >= 50)
        return total, avg, countries_50_plus
| # ============================================================================ | |
| # CONFIGURATION - Separated from business logic | |
| # ============================================================================ | |
class Config:
    """Static configuration for the application."""

    # Country mapping (ISO code to name and annotation target)
    COUNTRY_MAPPING = {
        "MX": {"name": "Mexico", "target": 1000},
        "AR": {"name": "Argentina", "target": 800},
        "CO": {"name": "Colombia", "target": 700},
        "CL": {"name": "Chile", "target": 600},
        "PE": {"name": "Peru", "target": 600},
        "ES": {"name": "Spain", "target": 1200},
        "BR": {"name": "Brazil", "target": 1000},
        "VE": {"name": "Venezuela", "target": 500},
        "EC": {"name": "Ecuador", "target": 400},
        "BO": {"name": "Bolivia", "target": 300},
        "PY": {"name": "Paraguay", "target": 300},
        "UY": {"name": "Uruguay", "target": 300},
        "CR": {"name": "Costa Rica", "target": 250},
        "PA": {"name": "Panama", "target": 250},
        "DO": {"name": "Dominican Republic", "target": 300},
        "GT": {"name": "Guatemala", "target": 250},
        "HN": {"name": "Honduras", "target": 200},
        "SV": {"name": "El Salvador", "target": 200},
        "NI": {"name": "Nicaragua", "target": 200},
        "CU": {"name": "Cuba", "target": 300},
    }

    @classmethod
    def create_country_data(cls) -> "Dict[str, CountryData]":
        """Build one ``CountryData`` per entry in ``COUNTRY_MAPPING``.

        Must be a classmethod: it is invoked on the class itself
        (``Config.create_country_data()``), with no instance.

        Returns:
            A dictionary keyed by country code.
        """
        return {
            code: CountryData(name=entry["name"], target=entry["target"])
            for code, entry in cls.COUNTRY_MAPPING.items()
        }
| # ============================================================================ | |
| # SERVICES - Business logic separated from presentation and data access | |
| # ============================================================================ | |
class ArgillaService:
    """Service wrapping the Argilla client and its webhook server."""

    def __init__(self, api_url: Optional[str] = None, api_key: Optional[str] = None):
        """Create the Argilla client and obtain the webhook server.

        Args:
            api_url: Argilla API URL; falls back to the ``ARGILLA_API_URL``
                environment variable when not given.
            api_key: Argilla API key; falls back to the ``ARGILLA_API_KEY``
                environment variable when not given.
        """
        self.api_url = api_url or os.getenv("ARGILLA_API_URL")
        self.api_key = api_key or os.getenv("ARGILLA_API_KEY")
        self.client = rg.Argilla(
            api_url=self.api_url,
            api_key=self.api_key,
        )
        self.server = rg.get_webhook_server()

    def get_server(self):
        """Return the webhook server obtained at construction time."""
        return self.server

    def get_client_base_url(self) -> str:
        """Return the client's base URL as text, or ``"Not connected"``.

        The value is passed through ``str()`` so the declared return type
        holds even if the underlying HTTP client exposes its base URL as a
        URL object rather than a plain string.
        """
        http_client = getattr(self.client, "http_client", None)
        if http_client is None:
            return "Not connected"
        return str(http_client.base_url)
class CountryMappingService:
    """Service for mapping between dataset names and country codes."""

    @staticmethod
    def find_country_code_from_dataset(
        dataset_name: str,
        mapping: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Optional[str]:
        """Find the country whose name appears in a dataset name.

        The match is a case-insensitive substring test, checked in the
        mapping's iteration order; the first match wins.

        Declared static because the function takes no ``self`` yet is
        called both on the class and on instances.

        Args:
            dataset_name: Name of the dataset to inspect.
            mapping: Country code -> dict with a ``"name"`` entry. Defaults
                to ``Config.COUNTRY_MAPPING``.

        Returns:
            The matching country code, or None when nothing matches or the
            dataset name is empty.
        """
        if not dataset_name:
            return None
        if mapping is None:
            mapping = Config.COUNTRY_MAPPING

        haystack = dataset_name.lower()
        for code, entry in mapping.items():
            if entry["name"].lower() in haystack:
                return code
        return None
| # ============================================================================ | |
| # UI COMPONENTS - Presentation layer separated from business logic | |
| # ============================================================================ | |
class MapVisualization:
    """Builds the HTML container and the D3.js script for the progress map.

    Both builders are static: they take no instance state and are called
    directly on the class.
    """

    # Token in the script template that is replaced by the progress payload.
    _PROGRESS_PLACEHOLDER = "__PROGRESS_DATA__"

    # Kept as a plain string (not an f-string) so the JavaScript braces and
    # `${...}` template expressions need no escaping.
    _D3_SCRIPT_TEMPLATE = """
        async () => {
            // Load an external script and wait until it has finished loading
            const loadScript = (src) => new Promise((resolve, reject) => {
                const el = document.createElement("script");
                el.src = src;
                el.onload = resolve;
                el.onerror = () => reject(new Error("Failed to load " + src));
                document.head.appendChild(el);
            });

            // Load D3.js (skipped when an earlier render already loaded it)
            if (!globalThis.d3) {
                await loadScript("https://cdn.jsdelivr.net/npm/d3@7");
            }
            console.log("D3 loaded successfully");

            // Load topojson
            if (!globalThis.topojson) {
                await loadScript("https://cdn.jsdelivr.net/npm/topojson@3");
            }
            console.log("TopoJSON loaded successfully");

            // The progress data passed from Python
            const progressData = __PROGRESS_DATA__;

            // Index tracked countries by lower-cased name for the fallback below
            const progressByName = {};
            for (const entry of Object.values(progressData)) {
                if (entry && entry.name) {
                    progressByName[String(entry.name).toLowerCase()] = entry;
                }
            }

            // Resolve the progress entry for a map feature.
            // NOTE(review): confirm which property the GeoJSON source uses for
            // country codes; when iso_a2 is absent we fall back to the name.
            const entryFor = (d) => {
                const props = d.properties || {};
                return progressData[props.iso_a2]
                    || progressByName[String(props.name || "").toLowerCase()];
            };

            // Set up the SVG container
            const mapContainer = document.getElementById('map-container');
            mapContainer.innerHTML = ''; // Clear loading message or previous render
            const width = mapContainer.clientWidth;
            const height = 600;
            const svg = d3.select("#map-container")
                .append("svg")
                .attr("width", width)
                .attr("height", height)
                .attr("viewBox", `0 0 ${width} ${height}`)
                .style("background-color", "#111");

            // Define color scale
            const colorScale = d3.scaleLinear()
                .domain([0, 100])
                .range(["#4a1942", "#f32b7b"]);

            // Set up projection focused on Latin America and Spain
            const projection = d3.geoMercator()
                .center([-60, 0])
                .scale(width / 5)
                .translate([width / 2, height / 2]);
            const path = d3.geoPath().projection(projection);

            // Tooltip setup
            const tooltip = d3.select("#tooltip");

            // Load the world GeoJSON data
            const response = await fetch("https://raw.githubusercontent.com/holtzy/D3-graph-gallery/master/DATA/world.geojson");
            if (!response.ok) {
                throw new Error("Failed to load map data: HTTP " + response.status);
            }
            const data = await response.json();

            // Draw the map
            svg.selectAll("path")
                .data(data.features)
                .enter()
                .append("path")
                .attr("d", path)
                .attr("stroke", "#f32b7b")
                .attr("stroke-width", 1)
                .attr("fill", d => {
                    const entry = entryFor(d);
                    if (entry) {
                        return colorScale(entry.percent);
                    }
                    return "#2d3748"; // Default gray for non-tracked countries
                })
                .on("mouseover", function(event, d) {
                    const entry = entryFor(d);
                    d3.select(this)
                        .attr("stroke", "#4a1942")
                        .attr("stroke-width", 2);
                    if (entry) {
                        tooltip.style("opacity", 1)
                            .style("left", (event.pageX + 15) + "px")
                            .style("top", (event.pageY + 15) + "px")
                            .html(`
                                <strong>${entry.name}</strong><br/>
                                Documents: ${entry.count.toLocaleString()}/${entry.target.toLocaleString()}<br/>
                                Completion: ${entry.percent}%
                            `);
                    }
                })
                .on("mousemove", function(event) {
                    tooltip.style("left", (event.pageX + 15) + "px")
                        .style("top", (event.pageY + 15) + "px");
                })
                .on("mouseout", function() {
                    d3.select(this)
                        .attr("stroke", "#f32b7b")
                        .attr("stroke-width", 1);
                    tooltip.style("opacity", 0);
                });

            // Add legend
            const legendWidth = Math.min(width - 40, 200);
            const legendHeight = 15;
            const legendX = width - legendWidth - 20;
            const legend = svg.append("g")
                .attr("class", "legend")
                .attr("transform", `translate(${legendX}, 30)`);

            // Create gradient for legend
            const defs = svg.append("defs");
            const gradient = defs.append("linearGradient")
                .attr("id", "dataGradient")
                .attr("x1", "0%")
                .attr("y1", "0%")
                .attr("x2", "100%")
                .attr("y2", "0%");
            gradient.append("stop")
                .attr("offset", "0%")
                .attr("stop-color", "#4a1942");
            gradient.append("stop")
                .attr("offset", "100%")
                .attr("stop-color", "#f32b7b");

            // Add legend title
            legend.append("text")
                .attr("x", legendWidth / 2)
                .attr("y", -10)
                .attr("text-anchor", "middle")
                .attr("font-size", "12px")
                .attr("fill", "#f1f5f9")
                .text("Annotation Progress");

            // Add legend rectangle
            legend.append("rect")
                .attr("width", legendWidth)
                .attr("height", legendHeight)
                .attr("rx", 2)
                .attr("ry", 2)
                .style("fill", "url(#dataGradient)");

            // Add legend labels
            legend.append("text")
                .attr("x", 0)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "start")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("0%");
            legend.append("text")
                .attr("x", legendWidth / 2)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "middle")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("50%");
            legend.append("text")
                .attr("x", legendWidth)
                .attr("y", legendHeight + 15)
                .attr("text-anchor", "end")
                .attr("font-size", "10px")
                .attr("fill", "#94a3b8")
                .text("100%");

            // Handle window resize. Drop the listener left by a previous
            // render so handlers do not pile up across refreshes.
            if (globalThis.resizeMap) {
                window.removeEventListener('resize', globalThis.resizeMap);
            }
            globalThis.resizeMap = () => {
                const width = mapContainer.clientWidth;
                // Update SVG dimensions (scoped to this map, not the first
                // <svg> found anywhere in the document)
                svg.attr("width", width)
                    .attr("viewBox", `0 0 ${width} ${height}`);
                // Update projection
                projection.scale(width / 5)
                    .translate([width / 2, height / 2]);
                // Update only this map's paths
                svg.selectAll("path").attr("d", path);
                // Update legend position
                const legendWidth = Math.min(width - 40, 200);
                const legendX = width - legendWidth - 20;
                legend.attr("transform", `translate(${legendX}, 30)`);
            };
            window.addEventListener('resize', globalThis.resizeMap);
        }
        """

    @staticmethod
    def create_map_html() -> str:
        """Return the initial HTML: a map container with a loading message and a tooltip element."""
        return """
        <div id="map-container" style="width:100%; height:600px; position:relative; background-color:#111;">
            <div style="display:flex; justify-content:center; align-items:center; height:100%; color:white; font-family:sans-serif;">
                Loading map visualization...
            </div>
        </div>
        <div id="tooltip" style="position:absolute; background-color:rgba(0,0,0,0.8); border-radius:5px; padding:8px; color:white; font-size:12px; pointer-events:none; opacity:0; transition:opacity 0.3s;"></div>
        """

    @staticmethod
    def create_d3_script(progress_data: Any) -> str:
        """Return the source of an async JavaScript function that renders the map.

        Args:
            progress_data: Progress payload to embed in the script. A string
                is inserted as-is (it is expected to already be JSON or a
                JavaScript expression); any other value is serialized with
                ``json.dumps`` so Python reprs such as ``True``/``None``
                never reach the browser.

        Returns:
            JavaScript source of the form ``async () => { ... }``.
        """
        if isinstance(progress_data, str):
            payload = progress_data
        else:
            payload = json.dumps(progress_data)
        # The script may be embedded inside a <script> element; escape "</" so
        # data containing "</script>" cannot terminate that element early.
        payload = payload.replace("</", "<\\/")
        return MapVisualization._D3_SCRIPT_TEMPLATE.replace(
            MapVisualization._PROGRESS_PLACEHOLDER, payload
        )
| # ============================================================================ | |
| # APPLICATION FACTORY - Creates and configures the application | |
| # ============================================================================ | |
class ApplicationFactory:
    """Factory for creating the application components.

    All methods are invoked on the class itself, so they are declared as
    class or static methods.
    """

    @classmethod
    def create_app_state(cls) -> "ApplicationState":
        """Create the application state and seed it with sample progress.

        Returns:
            An ``ApplicationState`` holding one entry per configured country,
            with sample counts applied to a handful of them.
        """
        state = ApplicationState(countries=Config.create_country_data())

        # Initialize with some sample data (country code -> fraction of target)
        sample_fractions = {"MX": 0.3, "AR": 0.3, "CO": 0.3, "ES": 0.3, "BR": 0.5, "CL": 0.7}
        for code, fraction in sample_fractions.items():
            state.update_country_progress(code, int(state.countries[code].target * fraction))
        return state

    @classmethod
    def create_argilla_service(cls) -> "ArgillaService":
        """Create the Argilla service with its default (environment) settings."""
        return ArgillaService()

    @staticmethod
    def cleanup_existing_webhooks(argilla_client):
        """Best-effort removal of a previously registered handler webhook.

        Deletes the first listed webhook whose URL contains
        ``handle_response_created``. Any failure is logged as a warning and
        otherwise ignored.

        Args:
            argilla_client: Client object exposing ``webhooks.list()`` and
                ``webhooks.delete(id)``.
        """
        try:
            # Get existing webhooks
            existing_webhooks = argilla_client.webhooks.list()
            # Look for our webhook
            for webhook in existing_webhooks:
                if "handle_response_created" in getattr(webhook, 'url', ''):
                    logger.info(f"Removing existing webhook: {webhook.id}")
                    argilla_client.webhooks.delete(webhook.id)
                    break
        except Exception as e:
            # Deliberately broad: cleanup must never prevent the app from starting.
            logger.warning(f"Could not clean up webhooks: {e}")

    @classmethod
    def create_webhook_handler(cls, app_state: "ApplicationState") -> Callable:
        """Create and register the handler for ``response.created`` webhook events.

        Args:
            app_state: State object that receives events and progress updates.

        Returns:
            The registered handler.
        """
        # The parameter names (including ``type``) are part of the handler's
        # interface and are therefore left unchanged.
        # NOTE(review): confirm against the Argilla webhook docs that handlers
        # are invoked with these keyword names.
        @webhook_listener(events=["response.created"])
        async def handle_response_created(response, type, timestamp):
            try:
                # Log the event
                logger.info(f"Received webhook event: {type} at {timestamp}")
                # Add basic event to the queue
                app_state.add_event(Event(
                    event_type=type,
                    timestamp=str(timestamp)
                ))
                # Extract dataset name
                record = response.record
                dataset_name = record.dataset.name
                logger.info(f"Processing response for dataset: {dataset_name}")
                # Find country code from dataset name (called on the class, as
                # the lookup takes no instance argument)
                country_code = CountryMappingService.find_country_code_from_dataset(dataset_name)
                # Update country progress if found
                if country_code:
                    success = app_state.increment_country_progress(country_code)
                    if success:
                        country_data = app_state.countries[country_code]
                        logger.info(
                            f"Updated progress for {country_data.name}: "
                            f"{country_data.count}/{country_data.target} ({country_data.percent}%)"
                        )
            except Exception as e:
                # Surface handler failures in the UI event stream instead of
                # letting them propagate out of the webhook endpoint.
                logger.error(f"Error in webhook handler: {e}", exc_info=True)
                app_state.add_event(Event(
                    event_type="error",
                    error=str(e)
                ))

        return handle_response_created

    @classmethod
    def create_ui(cls, argilla_service: "ArgillaService", app_state: "ApplicationState"):
        """Create the Gradio UI.

        Args:
            argilla_service: Used to display the connected server URL.
            app_state: Source of map data, statistics and queued events.

        Returns:
            The configured ``gr.Blocks`` application.
        """
        # Create and configure the Gradio interface
        demo = gr.Blocks(theme=gr.themes.Soft(primary_hue="pink", secondary_hue="purple"))
        with demo:
            argilla_server = argilla_service.get_client_base_url()
            with gr.Row():
                gr.Markdown(f"""
                # Latin America & Spain Annotation Progress Map
                ### Connected to Argilla server: {argilla_server}
                This dashboard visualizes annotation progress across Latin America and Spain.
                """)
            with gr.Row():
                with gr.Column(scale=2):
                    # Map visualization - empty at first
                    map_html = gr.HTML(MapVisualization.create_map_html(), label="Annotation Progress Map")
                    # Hidden element to store map data
                    map_data = gr.JSON(value=app_state.to_json(), visible=False)
                with gr.Column(scale=1):
                    # Overall statistics
                    with gr.Group():
                        gr.Markdown("### Statistics")
                        total_docs, avg_completion, countries_over_50 = app_state.get_stats()
                        total_docs_ui = gr.Number(value=total_docs, label="Total Documents", interactive=False)
                        avg_completion_ui = gr.Number(value=avg_completion, label="Average Completion (%)", interactive=False)
                        countries_over_50_ui = gr.Number(value=countries_over_50, label="Countries Over 50%", interactive=False)
                    # Country details
                    with gr.Group():
                        gr.Markdown("### Country Details")
                        country_selector = gr.Dropdown(
                            choices=[f"{data.name} ({code})" for code, data in app_state.countries.items()],
                            label="Select Country"
                        )
                        country_progress = gr.JSON(label="Country Progress", value={})
                    # Most recent event popped from the queue by the timer below
                    events_json = gr.JSON(label="Latest Event", value={})
                    # Refresh button
                    refresh_btn = gr.Button("Refresh Map")

            # UI interaction functions
            def update_map():
                return app_state.to_json()

            def update_country_details(country_selection):
                if not country_selection:
                    return {}
                # Extract the country code from the selection (format: "Country Name (CODE)")
                code = country_selection.split("(")[-1].replace(")", "").strip()
                if code in app_state.countries:
                    return asdict(app_state.countries[code])
                return {}

            def update_events():
                event = app_state.get_next_event()
                total, average, over_half = app_state.get_stats()
                if event.get("event_type") == "progress_update":
                    # New map data triggers the map_data change handler
                    map_update = app_state.to_json()
                else:
                    # No-op update: returning None would overwrite the stored
                    # map data and fire the change handler with no payload.
                    map_update = gr.update()
                return event, map_update, total, average, over_half

            # Set up event handlers
            refresh_btn.click(
                fn=update_map,
                inputs=None,
                outputs=map_data
            )
            country_selector.change(
                fn=update_country_details,
                inputs=[country_selector],
                outputs=[country_progress]
            )

            # Hidden HTML component that receives a script tag with the D3 code
            # NOTE(review): browsers generally do not execute <script> elements
            # inserted as HTML after page load; confirm this path actually
            # re-renders the map, otherwise move it to the listener's `js`
            # argument as well.
            js_holder = gr.HTML("", visible=False)

            # When map_data is updated, create a script tag with our D3 code
            def create_script_tag(data):
                if data is None:
                    return gr.update()
                # The change event may deliver parsed JSON rather than text;
                # serialize it so the script never embeds a Python repr.
                payload = data if isinstance(data, str) else json.dumps(data)
                script_content = MapVisualization.create_d3_script(payload)
                html = f"""
                <div id="js-executor">
                <script>
                (async () => {{
                const scriptFn = {script_content};
                await scriptFn();
                }})();
                </script>
                </div>
                """
                return html

            map_data.change(
                fn=create_script_tag,
                inputs=map_data,
                outputs=js_holder
            )

            # Use timer to check for new events and update stats
            gr.Timer(1, active=True).tick(
                update_events,
                outputs=[events_json, map_data, total_docs_ui, avg_completion_ui, countries_over_50_ui]
            )

            # Render the map when the page loads. This runs through the load
            # event's `js` hook rather than a DOMContentLoaded listener inside
            # an HTML component: by the time components are rendered that
            # event has already fired, so such a listener never runs.
            demo.load(
                fn=None,
                inputs=None,
                outputs=None,
                js=MapVisualization.create_d3_script(app_state.to_json()),
            )
        return demo
| # ============================================================================ | |
| # MAIN APPLICATION - Entry point and initialization | |
| # ============================================================================ | |
def create_application():
    """Build every component and return the server with the UI mounted at "/".

    The steps run in a fixed order: state, Argilla service, webhook
    cleanup, webhook handler, UI, and finally mounting the UI on the
    service's server.
    """
    state = ApplicationFactory.create_app_state()
    service = ApplicationFactory.create_argilla_service()

    # Remove a stale webhook before a new handler is created
    ApplicationFactory.cleanup_existing_webhooks(service.client)
    ApplicationFactory.create_webhook_handler(state)

    ui = ApplicationFactory.create_ui(service, state)

    # Attach the Gradio app to the server provided by the Argilla service
    server = service.get_server()
    gr.mount_gradio_app(server, ui, path="/")
    return server
# Script entry point: build the application and serve it with uvicorn
if __name__ == "__main__":
    import uvicorn

    # Assemble all components
    server = create_application()

    # Serve on all interfaces, port 7860
    uvicorn.run(server, port=7860, host="0.0.0.0")