import argparse
import asyncio
import collections
import contextlib
import json
import logging
import os
import shutil
import sys
import urllib.parse
import uuid
from datetime import datetime
from pathlib import Path

import gradio as gr
import websockets

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("chat-node")

# Name of this node. Assigned in main() from --node-name (or a random id) and
# replaced at runtime by NodeNameMiddleware when a request carries ?node_name=.
NODE_NAME = None

# Active WebSocket connections: {room_id: {client_id: websocket}}
active_connections = {}

# In-memory cache of message history: {room_id: [message dict, ...]}
chat_history = {}

# Maximum number of messages kept per room, in memory and on disk.
MAX_HISTORY_MESSAGES = 500

# Seconds the queue processor waits between polls of message_queue.
QUEUE_POLL_INTERVAL = 0.1

# Directory to store persistent chat history.
# NOTE: each room's history is written as JSON even though the file uses a
# ".md" extension; the extension is kept so existing history files still load.
HISTORY_DIR = "chat_history"

# Create history directory if it doesn't exist
os.makedirs(HISTORY_DIR, exist_ok=True)

# README.md is never listed as a room and never deleted by clear_all_history().
README_PATH = os.path.join(HISTORY_DIR, "README.md")
if not os.path.exists(README_PATH):
    with open(README_PATH, "w", encoding="utf-8") as f:
        f.write("# Chat History\n\nThis directory contains persistent chat history files.\n")


def get_node_name():
    """Parse the command line for this node's name and HTTP port.

    Returns:
        A ``(node_name, port)`` tuple. ``node_name`` falls back to a random
        ``node-xxxxxxxx`` identifier when ``--node-name`` is omitted or empty;
        ``port`` is the Gradio HTTP port (default 7860).
    """
    parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
    parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    args = parser.parse_args()

    node_name = args.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return node_name, args.port


def get_room_history_file(room_id):
    """Get the filename for a room's history."""
    return os.path.join(HISTORY_DIR, f"{room_id}.md")


def load_room_history(room_id):
    """Return the cached message list for ``room_id``, loading it from disk on first use.

    The returned list is the live cache entry: appending to it updates the
    in-memory history directly. A history file that cannot be read, is not
    valid JSON, or does not hold a JSON list is logged and leaves the room
    with an empty history.
    """
    if room_id not in chat_history:
        chat_history[room_id] = []

        history_file = get_room_history_file(room_id)
        if os.path.exists(history_file):
            try:
                with open(history_file, 'r', encoding='utf-8') as f:
                    history_json = f.read()
                if history_json.strip():
                    loaded_history = json.loads(history_json)
                    if isinstance(loaded_history, list):
                        chat_history[room_id] = loaded_history
                        logger.info("Loaded %d messages from history for room %s",
                                    len(loaded_history), room_id)
                    else:
                        logger.error("History file for room %s does not contain a list; ignoring it",
                                     room_id)
            except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
                logger.error("Error loading history for room %s: %s", room_id, e)

    return chat_history[room_id]


def save_room_history(room_id):
    """Persist the cached history of ``room_id`` to its history file.

    Does nothing when the room has no cached messages. The file is written to a
    temporary name and then swapped in with ``os.replace`` so a crash mid-write
    cannot leave a truncated history file behind. Errors are logged, not raised.
    """
    history = chat_history.get(room_id)
    if not history:
        return

    history_file = get_room_history_file(room_id)
    tmp_file = history_file + ".tmp"  # not ".md", so never listed as a room
    try:
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(history, f)
        os.replace(tmp_file, history_file)
        logger.info("Saved %d messages to history for room %s", len(history), room_id)
    except (OSError, TypeError, ValueError) as e:  # TypeError/ValueError: unserializable data
        logger.error("Error saving history for room %s: %s", room_id, e)


def get_all_history_files():
    """Get a list of all chat history files, sorted by modification time (newest first).

    Returns:
        A list of ``(room_id, file_path, mod_time)`` tuples; README.md is excluded.
    """
    history_files = []
    for file_name in os.listdir(HISTORY_DIR):
        if not file_name.endswith(".md") or file_name == "README.md":
            continue
        file_path = os.path.join(HISTORY_DIR, file_name)
        try:
            mod_time = os.path.getmtime(file_path)
        except OSError:
            # The file vanished between listdir() and getmtime() (e.g. a concurrent clear).
            continue
        history_files.append((file_name[:-3], file_path, mod_time))  # strip ".md"

    history_files.sort(key=lambda entry: entry[2], reverse=True)
    return history_files


def _system_message(content, room_id=None, msg_type="system"):
    """Build a server-originated message dict (key order matches what clients already receive)."""
    message = {
        "type": msg_type,
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
    }
    if room_id is not None:
        message["room_id"] = room_id
    return message


def _record_message(data, room_id):
    """Stamp ``data`` with server metadata, append it to the room history and persist it.

    ``data`` is mutated in place. History is trimmed to the newest
    MAX_HISTORY_MESSAGES entries.
    """
    data["timestamp"] = datetime.now().isoformat()
    data["sender_node"] = NODE_NAME
    data["room_id"] = room_id

    # load_room_history() recreates the cache entry if clear_all_history() emptied it.
    history = load_room_history(room_id)
    history.append(data)
    if len(history) > MAX_HISTORY_MESSAGES:
        del history[:-MAX_HISTORY_MESSAGES]

    save_room_history(room_id)


async def clear_all_history():
    """Clear all chat history for all rooms and notify every connected room.

    Returns:
        A human-readable status string.
    """
    # Clear in place instead of rebinding the global, so code that already holds
    # or indexes the dict keeps working with the emptied cache.
    chat_history.clear()

    # Delete all history files except README.md
    for file_name in os.listdir(HISTORY_DIR):
        if file_name.endswith(".md") and file_name != "README.md":
            with contextlib.suppress(FileNotFoundError):
                os.remove(os.path.join(HISTORY_DIR, file_name))

    # Broadcast a fresh clear notice to each room (list(): rooms may change while we await).
    for room_id in list(active_connections):
        clear_msg = _system_message("🧹 All chat history has been cleared by a user", room_id)
        await broadcast_message(clear_msg, room_id)

    logger.info("All chat history cleared")
    return "All chat history cleared"


def _resolve_request_path(websocket, path):
    """Return the request path of a connection, without any query string.

    Older ``websockets`` releases pass ``path`` to the handler; newer ones call
    the handler with the connection only, exposing the path as
    ``connection.request.path`` (or ``websocket.path`` on the legacy protocol).
    """
    if path is None:
        request = getattr(websocket, "request", None)
        path = getattr(request, "path", None) or getattr(websocket, "path", None) or ""
    return path.partition("?")[0]


async def websocket_handler(websocket, path=None):
    """Serve one WebSocket client for its whole lifetime.

    The first segment of the request path names the room ("/" -> "default").
    The client receives a welcome message and the room history, then every
    JSON object it sends is stamped, stored and broadcast to the room. A
    ``{"type": "command", "command": "clear_history"}`` message clears all
    history instead.

    Args:
        websocket: The connection object supplied by ``websockets``.
        path: Request path; passed only by older ``websockets`` versions.
    """
    request_path = _resolve_request_path(websocket, path)
    room_id = request_path.strip('/').split('/')[0] or "default"

    # Register the new client
    client_id = str(uuid.uuid4())
    active_connections.setdefault(room_id, {})[client_id] = websocket

    try:
        room_history = load_room_history(room_id)

        welcome_msg = _system_message(
            f"Welcome to room '{room_id}'! Connected from node '{NODE_NAME}'", room_id
        )
        await websocket.send(json.dumps(welcome_msg))

        # Snapshot the history: the live list may grow or be trimmed while we await.
        for msg in list(room_history):
            await websocket.send(json.dumps(msg))

        await broadcast_message(_system_message("User joined the room", room_id), room_id)
        logger.info("New client %s connected to room %s", client_id, room_id)

        # Handle messages from this client
        async for message in websocket:
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                error_msg = _system_message("Invalid JSON format", room_id, msg_type="error")
                await websocket.send(json.dumps(error_msg))
                continue

            if not isinstance(data, dict):
                error_msg = _system_message("Messages must be JSON objects", room_id, msg_type="error")
                await websocket.send(json.dumps(error_msg))
                continue

            if data.get("type") == "command" and data.get("command") == "clear_history":
                await clear_all_history()
                continue

            _record_message(data, room_id)
            await broadcast_message(data, room_id)

    except websockets.exceptions.ConnectionClosed:
        logger.info("Client %s disconnected from room %s", client_id, room_id)
    finally:
        # Look the room up again: other handlers may have changed it while we awaited.
        room_clients = active_connections.get(room_id)
        if room_clients is not None:
            room_clients.pop(client_id, None)
            await broadcast_message(_system_message("User left the room", room_id), room_id)
            # Clean up empty rooms (history is kept); re-check after the await above.
            if not active_connections.get(room_id):
                active_connections.pop(room_id, None)


async def broadcast_message(message, room_id):
    """Send ``message`` (JSON-encoded once) to every client in ``room_id``.

    Clients whose connection turns out to be closed are removed from the room.
    """
    room_clients = active_connections.get(room_id)
    if not room_clients:
        return

    payload = json.dumps(message)
    disconnected_clients = []
    # Iterate over a copy: joins/leaves can mutate the dict while we await send().
    for client_id, websocket in list(room_clients.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)

    # The client's own handler may already have removed it, hence pop(..., None).
    for client_id in disconnected_clients:
        room_clients.pop(client_id, None)


async def start_websocket_server(host='0.0.0.0', port=8765):
    """Start the WebSocket server and return the server object."""
    server = await websockets.serve(websocket_handler, host, port)
    logger.info("WebSocket server started on ws://%s:%s", host, port)
    return server


# Event loop running the servers; set in main().
main_event_loop = None

# Messages produced by Gradio callbacks, which may run in worker threads, waiting
# to be handled on the asyncio loop. deque append/popleft are thread-safe and O(1).
message_queue = collections.deque()


def send_message(message, username, room_id):
    """Queue a chat message from the Gradio interface for broadcast.

    Returns:
        ``"username: message"`` for display, or ``None`` for a blank message.
    """
    if not message.strip():
        return None

    msg_data = {
        "type": "chat",
        "content": message,
        "username": username,
        "room_id": room_id
    }
    message_queue.append(msg_data)

    return f"{username}: {message}"


def _format_history_entry(msg):
    """Render one stored message as a display line, or ``None`` for unknown types."""
    msg_type = msg.get("type")
    if msg_type == "chat":
        sender_node = f" [{msg.get('sender_node', 'unknown')}]" if "sender_node" in msg else ""
        time_str = ""
        if "timestamp" in msg:
            try:
                dt = datetime.fromisoformat(msg["timestamp"])
                time_str = f"[{dt.strftime('%H:%M:%S')}] "
            except (TypeError, ValueError):
                pass  # Unparseable timestamp: show the message without a time prefix.
        return f"{time_str}{msg.get('username', 'Anonymous')}{sender_node}: {msg.get('content', '')}"
    if msg_type == "system":
        return f"System: {msg.get('content', '')}"
    return None


def join_room(room_id, chat_history_output):
    """Join a chat room from the Gradio UI.

    Returns:
        ``(status_text, chat_history_text)``. On a blank room ID the current
        chat display is returned unchanged.
    """
    if not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    # Percent-encode everything, including "/", so the ID is a single safe file name.
    room_id = urllib.parse.quote(room_id.strip(), safe="")

    history = load_room_history(room_id)
    lines = [line for line in map(_format_history_entry, history) if line is not None]

    # The chat area is a Textbox, so return plain text rather than a list.
    return f"Joined room: {room_id}", "\n".join(lines)


def send_clear_command():
    """Queue a command to clear all chat history and return a status string."""
    msg_data = {
        "type": "command",
        "command": "clear_history",
        "username": "System"
    }
    message_queue.append(msg_data)
    return "🧹 Clearing all chat history..."


def list_available_rooms():
    """List all available chat rooms with their last activity time, as Markdown."""
    history_files = get_all_history_files()

    if not history_files:
        return "No chat rooms available yet. Create one by joining a room!"

    entries = "".join(
        f"- **{room_id}**: Last activity "
        f"{datetime.fromtimestamp(mod_time).strftime('%Y-%m-%d %H:%M:%S')}\n"
        for room_id, _file_path, mod_time in history_files
    )
    return "### Available Chat Rooms\n\n" + entries


def create_gradio_interface():
    """Build the Gradio Blocks UI for this node and return it (not launched)."""
    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        # Room list and management
        with gr.Row():
            with gr.Column(scale=3):
                room_list = gr.Markdown(value="Loading available rooms...")
                refresh_button = gr.Button("🔄 Refresh Room List")
            with gr.Column(scale=1):
                clear_button = gr.Button("🧹 Clear All Chat History", variant="stop")

        # Join room controls
        with gr.Row():
            room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID")
            join_button = gr.Button("Join Room")

        # Chat area
        chat_history_output = gr.Textbox(label="Chat History", lines=20, max_lines=20)

        # Message controls
        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            message_input = gr.Textbox(label="Message", placeholder="Type your message here")
            send_button = gr.Button("Send")

        # Current room display
        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        # Event handlers
        refresh_button.click(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

        clear_button.click(
            send_clear_command,
            inputs=[],
            outputs=[room_list]
        )

        join_button.click(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        def send_and_clear(message, username, room_id, current_history):
            """Queue ``message`` for the joined room.

            Returns:
                ``(message_box_text, chat_history_text)`` matching the outputs below.
            """
            if not room_id.startswith("Joined room:"):
                # Keep the draft in the message box; show the warning in the chat area.
                return message, "Please join a room first"

            actual_room_id = room_id.replace("Joined room: ", "").strip()
            formatted_msg = send_message(message, username, actual_room_id)
            if not formatted_msg:
                # Blank message: nothing was sent, leave both boxes as they are.
                return message, current_history

            # Append to, rather than replace, what the chat area already shows.
            if current_history:
                return "", f"{current_history}\n{formatted_msg}"
            return "", formatted_msg

        send_inputs = [message_input, username_input, current_room_display, chat_history_output]
        send_outputs = [message_input, chat_history_output]

        send_button.click(send_and_clear, inputs=send_inputs, outputs=send_outputs)

        # Enter key to send message
        message_input.submit(send_and_clear, inputs=send_inputs, outputs=send_outputs)

        # On load, populate room list
        interface.load(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

    return interface


async def _dispatch_queued_message(msg_data):
    """Handle one message queued by the Gradio UI (a chat message or a command)."""
    if msg_data.get("type") == "command":
        if msg_data.get("command") == "clear_history":
            await clear_all_history()
        else:
            logger.warning("Ignoring unknown queued command: %r", msg_data.get("command"))
        return

    room_id = msg_data.get("room_id")
    if not room_id:
        logger.warning("Dropping queued message without a room_id: %r", msg_data)
        return

    # Treat UI messages like WebSocket ones: add metadata, persist, then broadcast.
    _record_message(msg_data, room_id)
    await broadcast_message(msg_data, room_id)


async def process_message_queue():
    """Forever drain ``message_queue``, dispatching each message on the event loop."""
    while True:
        # Drain everything queued since the last poll rather than one message per tick.
        while message_queue:
            msg_data = message_queue.popleft()
            try:
                await _dispatch_queued_message(msg_data)
            except Exception:
                # Task boundary: log and continue so one bad message cannot stop the queue.
                logger.exception("Failed to process queued message: %r", msg_data)

        # Sleep to avoid busy-waiting
        await asyncio.sleep(QUEUE_POLL_INTERVAL)


async def main():
    """Start the WebSocket server, the queue processor and the Gradio HTTP app."""
    global NODE_NAME, main_event_loop
    NODE_NAME, port = get_node_name()

    # Store the main event loop for later use
    main_event_loop = asyncio.get_running_loop()

    ws_server = await start_websocket_server()

    # Keep a reference: the event loop only holds weak references to tasks.
    logger.info("Starting message queue processor")
    queue_task = asyncio.create_task(process_message_queue())

    interface = create_gradio_interface()

    from starlette.middleware.base import BaseHTTPMiddleware
    import uvicorn

    class NodeNameMiddleware(BaseHTTPMiddleware):
        """Rename this node when any request carries a ``node_name`` query parameter."""

        async def dispatch(self, request, call_next):
            global NODE_NAME
            node_name = request.query_params.get("node_name")
            if node_name is not None:
                NODE_NAME = node_name
                logger.info("Node name set to %s from URL parameter", NODE_NAME)
            return await call_next(request)

    # Apply middleware
    app = gr.routes.App.create_app(interface)
    app.add_middleware(NodeNameMiddleware)

    # NOTE(review): this mounts `interface` a second time at "/" on an app that
    # already serves it. Depending on the Gradio version, this call may also be
    # what registers Gradio's startup/queue hooks, so confirm before removing it.
    gr.routes.mount_gradio_app(app, interface, path="/")

    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    http_server = uvicorn.Server(config)
    logger.info("Starting Gradio interface on http://0.0.0.0:%s with node name '%s'", port, NODE_NAME)

    try:
        await http_server.serve()
    finally:
        queue_task.cancel()
        ws_server.close()


if __name__ == "__main__":
    asyncio.run(main())