# Spaces: Sleeping
# (Page-status text that precedes the source; it is not part of the program.)
import gradio as gr | |
import asyncio | |
import websockets | |
import json | |
import uuid | |
import argparse | |
import urllib.parse | |
from datetime import datetime | |
import logging | |
import sys | |
# Logging: INFO and above is written to stdout with a timestamped format.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("chat-node")

# Shared in-memory state.
# room_id -> {client_id -> websocket} for every currently connected client.
active_connections = {}
# room_id -> list of message dicts previously sent in that room.
chat_history = {}
def get_node_name():
    """Read the node name and UI port from the command line.

    Recognised options are ``--node-name`` (string, optional) and ``--port``
    (int, default 7860).

    Returns:
        A ``(node_name, port)`` tuple. When no name is supplied, the name is
        ``"node-"`` followed by eight random hexadecimal characters.
    """
    cli = argparse.ArgumentParser(description='Start a chat node with a specific name')
    cli.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    cli.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    options = cli.parse_args()

    # A missing (or empty) name falls back to a randomly generated one.
    chosen_name = options.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return chosen_name, options.port
def _room_id_from_path(path):
    """Return the first segment of a request path, or ``"default"`` if empty."""
    # "".split("/") yields [""], so checking the list for emptiness never
    # triggers the fallback; the segment itself has to be tested.
    first_segment = (path or "").strip("/").split("/")[0]
    return first_segment or "default"


def _system_event(msg_type, content, room_id):
    """Build a message dict sent on behalf of the server itself."""
    return {
        "type": msg_type,
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


async def websocket_handler(websocket, path=None):
    """Serve one WebSocket client for the lifetime of its connection.

    The room is taken from the first segment of the request path
    (``"default"`` when the path is empty). The client is registered in
    ``active_connections``, receives a welcome message followed by the stored
    history of the room, and a join notice is broadcast to the room. Every
    JSON object received afterwards is stamped with ``timestamp``,
    ``sender_node`` and ``room_id``, appended to the room history (capped at
    the 100 most recent messages) and broadcast. On exit the client is
    unregistered, a leave notice is broadcast and an empty room is dropped
    from ``active_connections``; its history is kept.

    Args:
        websocket: The connection object supplied by the ``websockets``
            server.
        path: Request path of the connection. Optional so the handler also
            works when it is invoked with the connection only; the path is
            then read from the connection object.
    """
    if path is None:
        # Look the path up on the connection itself when it was not passed in.
        request = getattr(websocket, "request", None)
        path = getattr(request, "path", None) or getattr(websocket, "path", "")

    room_id = _room_id_from_path(path)
    client_id = str(uuid.uuid4())
    max_history = 100

    try:
        # setdefault keeps history that already exists for the room (for
        # example after the last client left and the room entry was dropped).
        active_connections.setdefault(room_id, {})[client_id] = websocket
        chat_history.setdefault(room_id, [])

        welcome_msg = _system_event(
            "system",
            f"Welcome to room '{room_id}'! Connected from node '{NODE_NAME}'",
            room_id,
        )
        await websocket.send(json.dumps(welcome_msg))

        # Replay from a snapshot: other clients may append while we await.
        for past_msg in list(chat_history[room_id]):
            await websocket.send(json.dumps(past_msg))

        await broadcast_message(
            _system_event("system", "User joined the room", room_id), room_id
        )
        logger.info(f"New client {client_id} connected to room {room_id}")

        async for message in websocket:
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                error_msg = _system_event("error", "Invalid JSON format", room_id)
                await websocket.send(json.dumps(error_msg))
                continue

            # Valid JSON that is not an object (a list, a number, ...) cannot
            # carry the metadata keys; reject it instead of crashing.
            if not isinstance(data, dict):
                error_msg = _system_event(
                    "error", "Message must be a JSON object", room_id
                )
                await websocket.send(json.dumps(error_msg))
                continue

            data["timestamp"] = datetime.now().isoformat()
            data["sender_node"] = NODE_NAME
            data["room_id"] = room_id

            history = chat_history[room_id]
            history.append(data)
            if len(history) > max_history:
                del history[:-max_history]

            await broadcast_message(data, room_id)
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client {client_id} disconnected from room {room_id}")
    finally:
        room = active_connections.get(room_id)
        if room is not None and room.pop(client_id, None) is not None:
            await broadcast_message(
                _system_event("system", "User left the room", room_id), room_id
            )
        # Re-check after the await above: the room may have been removed or
        # repopulated in the meantime. Only the connection map is cleaned up;
        # chat_history is deliberately left in place.
        if room_id in active_connections and not active_connections[room_id]:
            del active_connections[room_id]
async def broadcast_message(message, room_id):
    """Send *message*, JSON-encoded, to every client registered in a room.

    Clients whose connection turns out to be closed are removed from the
    room. Nothing happens when the room is unknown or has no clients.

    Args:
        message: JSON-serialisable object to send.
        room_id: Key of the room in ``active_connections``.
    """
    room = active_connections.get(room_id)
    if not room:
        return

    # Serialise once instead of once per recipient.
    payload = json.dumps(message)
    stale_clients = []

    # Iterate over a snapshot: each await lets other coroutines add or remove
    # clients, and mutating a dict during iteration raises RuntimeError.
    for client_id, websocket in list(room.items()):
        if room.get(client_id) is not websocket:
            # The client left the room while earlier sends were in flight.
            continue
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            stale_clients.append(client_id)

    for client_id in stale_clients:
        # pop() tolerates a client that was already removed elsewhere.
        room.pop(client_id, None)
async def start_websocket_server(host='0.0.0.0', port=8765):
    """Start serving ``websocket_handler`` and return the server object.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.

    Returns:
        The object returned by ``websockets.serve``.
    """
    ws_server = await websockets.serve(websocket_handler, host, port)
    endpoint = f"ws://{host}:{port}"
    logger.info(f"WebSocket server started on {endpoint}")
    return ws_server
# Module-level state shared between the Gradio callbacks and the asyncio side.
# Messages submitted through the UI wait here until the queue processor
# broadcasts them.
message_queue = []
# Placeholder for the running event loop; main() assigns it at start-up.
main_event_loop = None
def send_message(message, username, room_id):
    """Queue a chat message typed into the Gradio interface.

    The message is appended to ``message_queue`` and broadcast later by the
    queue processor.

    Args:
        message: Text entered by the user.
        username: Display name of the sender.
        room_id: Room the message is addressed to.

    Returns:
        ``"<username>: <message>"`` for display, or ``None`` when the message
        is missing or contains only whitespace (nothing is queued then).
    """
    # Guard against None as well as blank text before calling .strip().
    if not message or not message.strip():
        return None

    msg_data = {
        "type": "chat",
        "content": message,
        "username": username,
        "room_id": room_id,
        # Every other message this node emits carries a timestamp; stamp UI
        # messages too so clients can treat all messages uniformly.
        "timestamp": datetime.now().isoformat(),
    }
    # The list is only mutated, never rebound, so no `global` is required.
    message_queue.append(msg_data)

    return f"{username}: {message}"
def join_room(room_id, chat_history_output):
    """Join a chat room and return its status line and formatted history.

    The room id is stripped and percent-encoded; an entry for it is created
    in ``chat_history`` if none exists yet.

    Args:
        room_id: Room name typed by the user.
        chat_history_output: Current content of the history box; returned
            unchanged when the room id is invalid.

    Returns:
        A ``(status, history_text)`` tuple. ``status`` is
        ``"Joined room: <id>"`` on success. ``history_text`` has one line per
        stored ``chat`` or ``system`` message; other message types are
        skipped.
    """
    if not room_id or not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    # Sanitize the room ID.
    room_id = urllib.parse.quote(room_id.strip())

    room_messages = chat_history.setdefault(room_id, [])

    lines = []
    for msg in room_messages:
        msg_type = msg.get("type")
        if msg_type == "chat":
            lines.append(f"{msg.get('username', 'Anonymous')}: {msg.get('content', '')}")
        elif msg_type == "system":
            lines.append(f"System: {msg.get('content', '')}")

    # The history is shown in a text box, so hand back text rather than a
    # list (which would be displayed as its Python repr).
    return f"Joined room: {room_id}", "\n".join(lines)
def create_gradio_interface():
    """Build the Gradio Blocks UI for this chat node and return it.

    The page has a room-id box with a "Join Room" button, a chat-history box,
    username and message inputs with a "Send" button, and a box showing the
    current room. Joining calls ``join_room``; sending (button click or Enter
    in the message box) calls ``send_message`` for the joined room.

    Returns:
        The ``gr.Blocks`` instance; it is not launched here.
    """
    # Status text written by join_room; the room id follows this prefix.
    joined_prefix = "Joined room: "

    def append_line(history, line):
        """Return *history* with *line* added as a new last line."""
        return f"{history}\n{line}" if history else line

    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        with gr.Row():
            room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID")
            join_button = gr.Button("Join Room")

        chat_history_output = gr.Textbox(label="Chat History", lines=15, max_lines=15)

        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            message_input = gr.Textbox(label="Message", placeholder="Type your message here")
            send_button = gr.Button("Send")

        # Current room display
        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        # Event handlers
        join_button.click(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        def send_and_clear(message, username, room_id, history):
            """Send *message* and return ``(message_box_text, history_text)``."""
            history = history or ""

            if not (room_id or "").startswith(joined_prefix):
                # Keep the user's draft in the input box and show the hint in
                # the history box instead of swapping the two.
                return message, append_line(history, "Please join a room first")

            # Slice the prefix off; str.replace would also alter a room id
            # that happens to contain the prefix text.
            actual_room_id = room_id[len(joined_prefix):].strip()
            formatted_msg = send_message(message, username, actual_room_id)

            if formatted_msg is None:
                # Nothing was sent (blank message): leave both boxes as they are.
                return message, history

            # Clear the input and append to the existing history rather than
            # replacing it with the latest line.
            return "", append_line(history, formatted_msg)

        send_inputs = [message_input, username_input, current_room_display, chat_history_output]
        send_outputs = [message_input, chat_history_output]

        send_button.click(send_and_clear, inputs=send_inputs, outputs=send_outputs)

        # Enter key to send message
        message_input.submit(send_and_clear, inputs=send_inputs, outputs=send_outputs)

    return interface
async def process_message_queue():
    """Forever drain ``message_queue`` and broadcast each queued message.

    Each message is broadcast to the room named by its ``"room_id"`` key.
    The queue is polled every 0.1 seconds; all messages waiting at that point
    are sent in order. A failure while broadcasting one message is logged and
    does not stop the loop. The coroutine only ends when it is cancelled.
    """
    while True:
        # Drain everything that is waiting rather than one message per tick.
        while message_queue:
            msg_data = message_queue.pop(0)
            try:
                await broadcast_message(msg_data, msg_data["room_id"])
            except Exception:
                # This runs as a background task: an escaping exception would
                # end it silently and no further UI message would be sent.
                logger.exception("Failed to broadcast queued message: %r", msg_data)

        # Sleep to avoid busy-waiting
        await asyncio.sleep(0.1)
async def main():
    """Start the WebSocket server, the queue processor and the Gradio web UI.

    Reads the node name and port from the command line, stores the running
    event loop in ``main_event_loop``, then serves the Gradio app with
    uvicorn on ``0.0.0.0:<port>`` until the server stops. On the way out the
    queue-processor task is cancelled and the WebSocket server is closed.
    """
    global NODE_NAME, main_event_loop

    NODE_NAME, port = get_node_name()

    # Store the main event loop for later use
    main_event_loop = asyncio.get_running_loop()

    # Start WebSocket server (kept under its own name so the handle is not
    # lost when the HTTP server object is created below).
    ws_server = await start_websocket_server()

    # Start message queue processor. Keep a reference to the task: the event
    # loop only holds weak references, so an unreferenced task may be
    # garbage-collected before it finishes.
    logger.info("Starting message queue processor")
    queue_task = asyncio.create_task(process_message_queue())

    try:
        # Create the Gradio interface
        interface = create_gradio_interface()

        # Custom middleware to extract node name from URL query parameters
        from starlette.middleware.base import BaseHTTPMiddleware

        class NodeNameMiddleware(BaseHTTPMiddleware):
            # NOTE(review): any request carrying ?node_name=... renames the
            # node for every user; confirm this is intended.
            async def dispatch(self, request, call_next):
                global NODE_NAME
                query_params = dict(request.query_params)
                if "node_name" in query_params:
                    NODE_NAME = query_params["node_name"]
                    logger.info(f"Node name set to {NODE_NAME} from URL parameter")

                response = await call_next(request)
                return response

        # Apply middleware
        app = gr.routes.App.create_app(interface)
        app.add_middleware(NodeNameMiddleware)

        # NOTE(review): the interface is mounted on an app that was itself
        # created from the interface; verify against the installed Gradio
        # version whether both steps are needed.
        gr.routes.mount_gradio_app(app, interface, path="/")

        # Run the FastAPI app with uvicorn
        import uvicorn
        config = uvicorn.Config(app, host="0.0.0.0", port=port)
        http_server = uvicorn.Server(config)

        logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")

        # Run the server and keep it running
        await http_server.serve()
    finally:
        # Stop the background work started above.
        queue_task.cancel()
        await asyncio.gather(queue_task, return_exceptions=True)
        ws_server.close()
        await ws_server.wait_closed()
# Script entry point: run the node only when this file is executed directly.
if __name__ == "__main__":
    app_coroutine = main()
    asyncio.run(app_coroutine)