# Hugging Face Spaces page header captured with this source: "Spaces: Sleeping".
import argparse
import asyncio
import errno
import json
import logging
import os
import socket
import sys
import time
import urllib.parse
import uuid
import zlib
from datetime import datetime
from pathlib import Path

import gradio as gr
import websockets
# Send INFO-and-above records to stdout with a timestamped format.
# (logging.basicConfig does nothing if the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stdout_handler])

# Logger shared by every function in this module.
logger = logging.getLogger("chat-node")
# ---- Shared in-process state (module globals used by the functions below) ----
# room_id -> {client_id: websocket} for currently open WebSocket connections.
active_connections = {}
# room_id -> list of message dicts: in-memory cache of the persisted history.
chat_history = {}
# history file path -> last observed modification time (os.path.getmtime value).
file_modification_times = {}
# room_id -> set of connected WebSocket client ids (drives the sector map).
sector_users = {}
# Running asyncio loop, recorded by main(); None until then.
main_event_loop = None
# Messages queued by the Gradio UI, drained by process_message_queue().
message_queue = []

# Size of the 2D sector grid that rooms are mapped onto.
GRID_WIDTH, GRID_HEIGHT = 10, 10
# Persistent chat history lives in this directory (relative to the working directory).
HISTORY_DIR = "chat_history"
Path(HISTORY_DIR).mkdir(parents=True, exist_ok=True)

# Explanatory README kept next to the history files. Only *.jsonl files are
# treated as history, so this file is never listed as a room or deleted.
README_PATH = os.path.join(HISTORY_DIR, "README.md")
_readme = Path(README_PATH)
if not _readme.exists():
    _readme.write_text("# Chat History\n\nThis directory contains persistent chat history files.\n")
def get_node_name(argv=None):
    """Parse this chat node's command-line options.

    Only the command line is read here. A ``?node_name=`` URL parameter is
    applied later by the HTTP middleware set up in ``main``.

    Args:
        argv: Optional list of argument strings to parse. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        A ``(node_name, port, ws_port)`` tuple. ``node_name`` is a random
        ``node-xxxxxxxx`` identifier when ``--node-name`` is missing or empty.
    """
    parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
    parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    parser.add_argument('--ws-port', type=int, default=8765, help='Port to run the WebSocket server on')
    args = parser.parse_args(argv)

    # A missing or empty name gets a short random identifier.
    node_name = args.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return node_name, args.port, args.ws_port
def get_room_history_file(room_id):
    """Build a new timestamped history-file path for ``room_id``.

    The file name has the form ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` inside
    ``HISTORY_DIR``. Timestamps have one-second resolution, so two calls in
    the same second return the same path.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"{room_id}_{stamp}.jsonl"
    return os.path.join(HISTORY_DIR, file_name)
def get_all_room_history_files(room_id, history_dir=None):
    """Return every history file that belongs to exactly ``room_id``.

    File names must match ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl``, the pattern
    ``get_room_history_file`` writes. The room part is taken as everything
    before the two trailing timestamp fields. This means room ``"a"`` no
    longer picks up the files of room ``"a_b"``.

    Args:
        room_id: Room identifier as it appears in file names.
        history_dir: Directory to scan. Defaults to ``HISTORY_DIR``.

    Returns:
        Full paths sorted by modification time, newest first. Files that
        disappear during the scan (e.g. a concurrent clear) are skipped.
    """
    directory = HISTORY_DIR if history_dir is None else history_dir
    mtimes = {}
    for name in os.listdir(directory):
        if not name.endswith(".jsonl"):
            continue
        parts = name[:-len(".jsonl")].rsplit("_", 2)
        # Expect [room_id, YYYYmmdd, HHMMSS]; anything else is not this room's file.
        if len(parts) != 3 or parts[0] != room_id:
            continue
        if not (parts[1].isdigit() and parts[2].isdigit()):
            continue
        path = os.path.join(directory, name)
        try:
            mtimes[path] = os.path.getmtime(path)
        except FileNotFoundError:
            continue  # removed between listdir and stat
    return sorted(mtimes, key=mtimes.get, reverse=True)
def get_all_history_files(history_dir=None):
    """List every room that has history, newest activity first.

    Each ``<room_id>_<YYYYmmdd>_<HHMMSS>.jsonl`` file contributes to its room.
    The room ID is everything before the two trailing timestamp fields, so
    IDs containing ``_`` are kept intact. Files of any other shape are ignored.

    Args:
        history_dir: Directory to scan. Defaults to ``HISTORY_DIR``.

    Returns:
        A list of ``(room_id, newest_file_path, mod_time)`` tuples sorted by
        ``mod_time`` descending. Returns an empty list if the scan fails; the
        error is logged.
    """
    try:
        directory = HISTORY_DIR if history_dir is None else history_dir
        rooms = {}  # room_id -> (newest_file_path, mod_time)
        for name in os.listdir(directory):
            if not name.endswith(".jsonl"):
                continue
            parts = name[:-len(".jsonl")].rsplit("_", 2)
            if len(parts) != 3 or not (parts[1].isdigit() and parts[2].isdigit()):
                continue
            room_id = parts[0]
            file_path = os.path.join(directory, name)
            try:
                mod_time = os.path.getmtime(file_path)
            except FileNotFoundError:
                continue  # removed concurrently (e.g. by a history clear)
            # Keep only the newest file per room.
            if room_id not in rooms or mod_time > rooms[room_id][1]:
                rooms[room_id] = (file_path, mod_time)
        history_files = [(room_id, path, mtime) for room_id, (path, mtime) in rooms.items()]
        history_files.sort(key=lambda entry: entry[2], reverse=True)
        return history_files
    except Exception as e:
        logger.error(f"Error in get_all_history_files: {e}")
        return []  # Return empty list on error
def load_room_history(room_id):
    """Rebuild the in-memory history of ``room_id`` from all of its files.

    Reads every JSON-lines history file of the room. Blank lines are skipped,
    and lines that fail to parse are logged and skipped. The messages are
    sorted by their ``timestamp`` string (messages without one sort first)
    and the result is stored in ``chat_history[room_id]``.

    The function also records the mtime of files not seen before in
    ``file_modification_times`` and makes sure ``sector_users`` has an entry
    for the room.

    Returns:
        The list now stored in ``chat_history[room_id]``.
    """
    history_files = get_all_room_history_files(room_id)

    messages = []
    for history_file in history_files:
        # Remember the first mtime seen so check_for_new_messages can spot later writes.
        if history_file not in file_modification_times:
            try:
                file_modification_times[history_file] = os.path.getmtime(history_file)
            except FileNotFoundError:
                continue  # removed concurrently (e.g. by a history clear)
        try:
            with open(history_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # Skip empty lines
                    try:
                        messages.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.error(f"Error parsing JSON line in {history_file}")
        except Exception as e:
            logger.error(f"Error loading history from {history_file}: {e}")

    # ISO-8601 timestamp strings sort chronologically.
    messages.sort(key=lambda x: x.get("timestamp", ""))
    chat_history[room_id] = messages
    logger.info(f"Loaded {len(messages)} messages from {len(history_files)} files for room {room_id}")

    sector_users.setdefault(room_id, set())
    return chat_history[room_id]
def save_message_to_history(room_id, message):
    """Append ``message`` as one JSON line to the room's current history file.

    The newest existing file for the room is used unless it is larger than
    1 MiB, or has vanished. In that case, or when the room has no file yet,
    a new timestamped file is started.

    The file's mtime is then stored in ``file_modification_times``. This keeps
    ``check_for_new_messages`` from reporting this node's own write as new.
    Write failures are logged, not raised.
    """
    max_file_size = 1024 * 1024  # roll over to a new file after 1 MiB

    history_files = get_all_room_history_files(room_id)
    history_file = None
    if history_files:
        newest_file = history_files[0]
        try:
            if os.path.getsize(newest_file) <= max_file_size:
                history_file = newest_file
        except FileNotFoundError:
            pass  # removed concurrently (e.g. by a history clear); start a new file
    if history_file is None:
        history_file = get_room_history_file(room_id)

    try:
        # Append the message as a single line of JSON
        with open(history_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(message) + '\n')
        file_modification_times[history_file] = os.path.getmtime(history_file)
        logger.debug(f"Saved message to {history_file}")
    except Exception as e:
        logger.error(f"Error saving message to {history_file}: {e}")
def check_for_new_messages():
    """Detect history files created or modified since they were last seen.

    Scans ``HISTORY_DIR`` for ``*.jsonl`` files that are new or newer than the
    mtime recorded in ``file_modification_times``, and records their current
    mtime. It then reloads the in-memory history of every affected room that
    is already cached in ``chat_history``.

    Returns:
        The set of room IDs whose files changed, whether or not they are cached.
    """
    updated_rooms = set()
    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(".jsonl"):
            continue
        file_path = os.path.join(HISTORY_DIR, name)
        try:
            current_mtime = os.path.getmtime(file_path)
        except FileNotFoundError:
            continue  # removed between listdir and stat
        last_seen = file_modification_times.get(file_path)
        if last_seen is not None and current_mtime <= last_seen:
            continue
        # The room ID is everything before the trailing _YYYYmmdd_HHMMSS stamp,
        # so IDs that contain underscores are recovered intact.
        parts = name[:-len(".jsonl")].rsplit("_", 2)
        if len(parts) == 3 and parts[1].isdigit() and parts[2].isdigit():
            updated_rooms.add(parts[0])
        file_modification_times[file_path] = current_mtime

    # Reload history for updated rooms that are already cached.
    for room_id in updated_rooms:
        if room_id not in chat_history:
            continue
        previous_len = len(chat_history[room_id])
        load_room_history(room_id)  # replaces chat_history[room_id]
        added = len(chat_history[room_id]) - previous_len
        if added > 0:
            logger.info(f"Found {added} new messages for room {room_id}")
    return updated_rooms
def get_sector_coordinates(room_id, grid_width=None, grid_height=None):
    """Map a room ID onto a cell of the sector grid.

    IDs of the form ``"x,y"`` are parsed as explicit coordinates and clamped
    to the grid. The URL-encoded form that ``join_room`` produces, such as
    ``"3%2C4"``, is parsed the same way.

    Any other ID is hashed with CRC-32, so the same room lands on the same
    cell in every process and on every node. Python's built-in ``hash()`` of
    a ``str`` is randomized per interpreter run.

    Args:
        room_id: Room identifier.
        grid_width: Number of columns. Defaults to ``GRID_WIDTH``.
        grid_height: Number of rows. Defaults to ``GRID_HEIGHT``.

    Returns:
        An ``(x, y)`` tuple inside the grid.
    """
    width = GRID_WIDTH if grid_width is None else grid_width
    height = GRID_HEIGHT if grid_height is None else grid_height
    name = urllib.parse.unquote(room_id)

    if ',' in name:
        try:
            x, y = map(int, name.split(','))
        except ValueError:
            pass  # not exactly two integers: fall back to hashing
        else:
            return max(0, min(x, width - 1)), max(0, min(y, height - 1))

    digest = zlib.crc32(name.encode('utf-8'))
    return digest % width, (digest >> 8) % height
def generate_sector_map():
    """Render the sector grid as a Markdown code block.

    Each cell shows how many connected users are in rooms mapped to it. The
    counts are summed when several rooms share a cell: a digit for 1-9 users,
    ``+`` for ten or more, and blank for none. Rows and columns are labelled
    with the last digit of their index, and a legend follows the grid.
    """
    # (x, y) -> total users across every room mapped to that cell.
    counts = {}
    for room_id, users in sector_users.items():
        if users:  # Only show rooms with users
            cell = get_sector_coordinates(room_id)
            counts[cell] = counts.get(cell, 0) + len(users)

    axis = ''.join(str(i % 10) for i in range(GRID_WIDTH))
    # Two leading spaces line the column labels up with the "N|" row prefix.
    lines = ['  ' + axis]
    for y in range(GRID_HEIGHT):
        row_cells = []
        for x in range(GRID_WIDTH):
            n = counts.get((x, y), 0)
            row_cells.append(' ' if n == 0 else (str(n) if n < 10 else '+'))
        lines.append(f"{y % 10}|{''.join(row_cells)}|")
    lines.append('  ' + axis)
    map_str = '\n'.join(lines)

    return f"```\n{map_str}\n```\n\nLegend: Number indicates users in sector. '+' means 10+ users."
def list_available_rooms():
    """Build a Markdown overview of every room with persisted history.

    On success, returns a table of room ID, sector, connected users and last
    activity (newest first) followed by the sector map. Returns a hint when
    no history exists, or an error string if anything raises.
    """
    try:
        entries = get_all_history_files()
        if not entries:
            return "No chat rooms available yet. Create one by joining a room!"

        parts = [
            "### Available Chat Rooms (Sectors)\n\n",
            "| Room ID | Sector | Users | Last Activity |\n",
            "|---------|--------|-------|---------------|\n",
        ]
        for room_id, _path, mod_time in entries:
            sx, sy = get_sector_coordinates(room_id)
            users_here = len(sector_users.get(room_id, set()))
            stamp = datetime.fromtimestamp(mod_time).strftime("%Y-%m-%d %H:%M:%S")
            parts.append(f"| {room_id} | ({sx},{sy}) | {users_here} | {stamp} |\n")
        parts.append("\n\n### Sector Map\n\n" + generate_sector_map())
        return "".join(parts)
    except Exception as e:
        logger.error(f"Error in list_available_rooms: {e}")
        return f"Error listing rooms: {str(e)}"
async def broadcast_message(message, room_id):
    """Send ``message`` (JSON-encoded) to every client connected to ``room_id``.

    The loop runs over a snapshot of the room's connections. Each ``await``
    yields to the event loop, where other handlers can add or remove clients
    in the same dict, and iterating the live dict would then raise
    ``RuntimeError: dictionary changed size during iteration``.

    Clients whose connection turns out to be closed are removed from the
    room. A client that its own handler already removed is ignored.
    """
    room = active_connections.get(room_id)
    if not room:
        return

    payload = json.dumps(message)  # encode once for all recipients
    disconnected_clients = []
    for client_id, websocket in list(room.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)

    for client_id in disconnected_clients:
        room.pop(client_id, None)
async def start_websocket_server(host='0.0.0.0', port=8765, max_attempts=10):
    """Start the WebSocket server on the first free port from ``port`` upward.

    Args:
        host: Interface to bind.
        port: First port to try.
        max_attempts: How many consecutive ports to try (at least one).

    Returns:
        The server object returned by ``websockets.serve``.

    Raises:
        OSError: if the last port tried is also in use, or on any other bind
            error.
    """
    # EADDRINUSE is 98 on Linux but differs elsewhere (e.g. 48 on macOS);
    # Windows may report WSAEADDRINUSE.
    in_use_codes = {errno.EADDRINUSE, getattr(errno, "WSAEADDRINUSE", errno.EADDRINUSE)}
    attempts = max(1, max_attempts)
    for offset in range(attempts):
        candidate = port + offset
        try:
            server = await websockets.serve(websocket_handler, host, candidate)
        except OSError as e:
            if e.errno not in in_use_codes or offset == attempts - 1:
                raise
            logger.warning(f"Port {candidate} already in use, trying port {candidate + 1}")
            continue
        logger.info(f"WebSocket server started on ws://{host}:{candidate}")
        return server
def send_message(message, username, room_id):
    """Queue a chat message typed in the Gradio UI for the event loop.

    Whitespace-only messages are ignored and ``None`` is returned. Otherwise
    a ``"chat"`` dict is appended to ``message_queue``, which
    ``process_message_queue`` drains, and the display line
    ``"<username>: <message>"`` is returned.
    """
    global message_queue
    if message.strip():
        message_queue.append({
            "type": "chat",
            "content": message,
            "username": username,
            "room_id": room_id,
        })
        return f"{username}: {message}"
    return None
def send_clear_command():
    """Queue a request to wipe every room's chat history.

    The command dict is appended to ``message_queue``. The event-loop queue
    processor performs the actual deletion. Returns a status string for the UI.
    """
    global message_queue
    clear_request = dict(type="command", command="clear_history", username="System")
    message_queue.append(clear_request)
    return "🧹 Clearing all chat history..."
async def clear_all_history():
    """Delete all persisted chat history and notify every connected room.

    Removes every ``*.jsonl`` file in ``HISTORY_DIR``; README.md and other
    files are kept. Empties the in-memory ``chat_history`` cache and drops the
    tracked mtimes of the deleted files. Then broadcasts a system notice to
    each room that has live connections.

    ``sector_users`` is deliberately left untouched. It records who is
    connected right now, not history, so wiping it would make connected users
    disappear from the sector map and user counts.

    Returns:
        A short status string.
    """
    # Clear in place so every holder of the cache dict sees the reset.
    chat_history.clear()

    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(".jsonl"):
            continue
        path = os.path.join(HISTORY_DIR, name)
        try:
            os.remove(path)
        except Exception as e:
            logger.error(f"Error removing file {name}: {e}")
        file_modification_times.pop(path, None)

    notice = {
        "type": "system",
        "content": "🧹 All chat history has been cleared by a user",
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
    }
    for room_id in list(active_connections):
        await broadcast_message({**notice, "room_id": room_id}, room_id)

    logger.info("All chat history cleared")
    return "All chat history cleared"
def join_room(room_id, chat_history_output):
    """Join a chat room from the Gradio UI and render its history.

    Args:
        room_id: Room ID as typed by the user. It is stripped and URL-quoted
            before use.
        chat_history_output: Current contents of the chat textbox. Returned
            unchanged when ``room_id`` is blank.

    Returns:
        ``(status, transcript)``. ``status`` is ``"Joined room: <id>"``.
        ``transcript`` is a single newline-joined string, because it is shown
        in a ``gr.Textbox``; a list would be rendered as its Python repr.
    """
    if not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    room_id = urllib.parse.quote(room_id.strip())
    history = load_room_history(room_id)
    x, y = get_sector_coordinates(room_id)

    lines = [
        f"You are now in Sector ({x},{y}) - Room ID: {room_id}",
        f"Sector Map:\n{generate_sector_map()}",
    ]
    for msg in history:
        kind = msg.get("type")
        if kind == "chat":
            sender_node = f" [{msg.get('sender_node', 'unknown')}]" if "sender_node" in msg else ""
            time_str = ""
            if "timestamp" in msg:
                try:
                    dt = datetime.fromisoformat(msg["timestamp"])
                    time_str = f"[{dt.strftime('%H:%M:%S')}] "
                except (TypeError, ValueError):
                    pass  # malformed timestamp: show the line without a time
            lines.append(f"{time_str}{msg.get('username', 'Anonymous')}{sender_node}: {msg.get('content', '')}")
        elif kind == "system":
            lines.append(f"System: {msg.get('content', '')}")

    return f"Joined room: {room_id}", "\n".join(lines)
def _system_message(content, room_id):
    """Return a ``type: system`` message for ``room_id`` stamped with the current time."""
    return {
        "type": "system",
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


def _error_message(content, room_id):
    """Return a ``type: error`` message for ``room_id`` stamped with the current time."""
    return {
        "type": "error",
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


def _room_id_from_path(path):
    """Return the first segment of a request path, ignoring any query string.

    An empty path (``""`` or ``"/"``) maps to ``"default"``.
    """
    route = (path or "").split("?", 1)[0].strip("/")
    return route.split("/")[0] or "default"


async def websocket_handler(websocket, path=None):
    """Serve one WebSocket client for the lifetime of its connection.

    The room comes from the first segment of the request path, or
    ``"default"`` when that is empty.

    On connect, the client is registered in ``active_connections`` and
    ``sector_users``. It then receives a welcome notice, the sector map and
    the room's stored history, and the room is told that a user joined.

    Each incoming JSON object is either a ``clear_history`` / ``show_map``
    command or a message. A message is stamped with time, node and room,
    cached (the newest 500 per room), persisted and broadcast. On disconnect
    the client is unregistered, and a leave notice is broadcast and persisted.

    Args:
        websocket: Connection object supplied by ``websockets``.
        path: Request path. Older ``websockets`` releases pass it as a second
            argument. When it is omitted it is read from ``websocket.path`` or
            ``websocket.request.path``.
    """
    client_id = str(uuid.uuid4())
    room_id = "default"  # keeps the finally-block valid if setup fails early
    registered = False
    try:
        if path is None:
            path = getattr(websocket, "path", None) or getattr(getattr(websocket, "request", None), "path", "")
        room_id = _room_id_from_path(path)

        # Register the client for broadcasts and for the sector map.
        active_connections.setdefault(room_id, {})[client_id] = websocket
        sector_users.setdefault(room_id, set()).add(client_id)
        registered = True

        x, y = get_sector_coordinates(room_id)
        room_history = load_room_history(room_id)

        await websocket.send(json.dumps(_system_message(
            f"Welcome to room '{room_id}' (Sector {x},{y})! Connected from node '{NODE_NAME}'", room_id)))
        await websocket.send(json.dumps(_system_message(f"Sector Map:\n{generate_sector_map()}", room_id)))
        # Snapshot: messages added while we send are broadcast to this (already registered) client.
        for msg in list(room_history):
            await websocket.send(json.dumps(msg))

        present = len(sector_users.get(room_id, ()))
        join_msg = _system_message(f"User joined the room (Sector {x},{y}) - {present} users now present", room_id)
        await broadcast_message(join_msg, room_id)
        save_message_to_history(room_id, join_msg)
        logger.info(f"New client {client_id} connected to room {room_id} (Sector {x},{y})")

        async for raw in websocket:
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send(json.dumps(_error_message("Invalid JSON format", room_id)))
                continue
            if not isinstance(data, dict):
                # e.g. a bare string or list; the .get() calls below need an object.
                await websocket.send(json.dumps(_error_message("Expected a JSON object", room_id)))
                continue

            if data.get("type") == "command":
                command = data.get("command")
                if command == "clear_history":
                    await clear_all_history()
                    continue
                if command == "show_map":
                    await websocket.send(json.dumps(_system_message(f"Sector Map:\n{generate_sector_map()}", room_id)))
                    continue

            data["timestamp"] = datetime.now().isoformat()
            data["sender_node"] = NODE_NAME
            data["room_id"] = room_id

            # setdefault: the cache entry is gone if history was cleared meanwhile.
            room_cache = chat_history.setdefault(room_id, [])
            room_cache.append(data)
            if len(room_cache) > 500:
                del room_cache[:-500]  # keep only the newest 500 messages in memory

            save_message_to_history(room_id, data)
            await broadcast_message(data, room_id)
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client {client_id} disconnected from room {room_id}")
    finally:
        room_clients = active_connections.get(room_id)
        if room_clients is not None:
            room_clients.pop(client_id, None)
        room_users = sector_users.get(room_id)
        if room_users is not None:
            room_users.discard(client_id)

        if registered:
            x, y = get_sector_coordinates(room_id)
            remaining = len(sector_users.get(room_id, ()))
            leave_msg = _system_message(f"User left the room (Sector {x},{y}) - {remaining} users remaining", room_id)
            await broadcast_message(leave_msg, room_id)
            save_message_to_history(room_id, leave_msg)

        # Drop the room's connection table once it is empty; its history stays on disk.
        if room_id in active_connections and not active_connections[room_id]:
            del active_connections[room_id]
async def process_message_queue():
    """Background task: drain UI-queued messages and relay history-file changes.

    Loops forever, sleeping one second between passes. Each pass:

    1. Processes every message currently queued by the Gradio UI. A
       ``clear_history`` command wipes all history. A message that carries a
       ``room_id`` gets a timestamp and sender node if missing, is cached (the
       newest 500 per room), persisted and broadcast.
    2. Calls ``check_for_new_messages``. For each updated room with connected
       clients, it re-broadcasts those of the last five cached messages that
       came from other nodes. This is a simplification, and clients may
       receive some of these more than once.

    An exception in one pass is logged and the loop continues. Without this,
    an unhandled error would end the task and stop all queued sends.
    """
    global message_queue
    while True:
        try:
            # Drain everything queued since the last pass, not just one message.
            while message_queue:
                msg_data = message_queue.pop(0)
                if msg_data.get("type") == "command" and msg_data.get("command") == "clear_history":
                    await clear_all_history()
                elif "room_id" in msg_data:
                    room_id = msg_data["room_id"]
                    msg_data.setdefault("timestamp", datetime.now().isoformat())
                    msg_data.setdefault("sender_node", NODE_NAME)

                    room_cache = chat_history.setdefault(room_id, [])
                    room_cache.append(msg_data)
                    if len(room_cache) > 500:
                        del room_cache[:-500]  # same in-memory cap as websocket_handler

                    save_message_to_history(room_id, msg_data)
                    await broadcast_message(msg_data, room_id)

            for room_id in check_for_new_messages():
                if room_id not in active_connections:
                    continue
                for msg in chat_history.get(room_id, [])[-5:]:
                    # Only relay messages from other nodes (ours were already broadcast).
                    if msg.get("sender_node") != NODE_NAME:
                        await broadcast_message(msg, room_id)
        except Exception:
            logger.exception("Error in message queue processor")

        await asyncio.sleep(1.0)  # Check every second
def create_gradio_interface():
    """Build the Gradio Blocks UI for this chat node.

    The UI can:
    - list rooms that have history;
    - join a room by ID or by grid coordinates;
    - queue messages for the joined room via ``send_message``;
    - request a global history wipe.

    The chat textbox holds a plain-string transcript. Sending appends the sent
    lines (or a warning) instead of replacing what is shown.

    Returns:
        The ``gr.Blocks`` interface. It is not launched here.
    """
    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        # Room list and management
        with gr.Row():
            with gr.Column(scale=3):
                room_list = gr.Markdown(value="Loading available rooms...")
                refresh_button = gr.Button("🔄 Refresh Room List")
            with gr.Column(scale=1):
                clear_button = gr.Button("🧹 Clear All Chat History", variant="stop")

        # Join room controls with 2D grid input
        with gr.Row():
            with gr.Column(scale=2):
                room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID or use x,y coordinates")
                join_button = gr.Button("Join Room")
            with gr.Column(scale=1):
                with gr.Row():
                    x_coord = gr.Number(label="X", value=0, minimum=0, maximum=GRID_WIDTH-1, step=1)
                    y_coord = gr.Number(label="Y", value=0, minimum=0, maximum=GRID_HEIGHT-1, step=1)
                grid_join_button = gr.Button("Join by Coordinates")

        # Chat area with multiline support
        chat_history_output = gr.Textbox(label="Chat History", lines=20, max_lines=20)

        # Message controls with multiline support
        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            with gr.Column(scale=3):
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here. Press Shift+Enter for new line, Enter to send.",
                    lines=3
                )
            with gr.Column(scale=1):
                send_button = gr.Button("Send")
                map_button = gr.Button("🗺️ Show Map")

        # Current room display
        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        # Event handlers
        refresh_button.click(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

        clear_button.click(
            send_clear_command,
            inputs=[],
            outputs=[room_list]
        )

        def join_by_coordinates(x, y):
            """Turn grid coordinates into an ``"x,y"`` room ID."""
            return f"{int(x)},{int(y)}"

        # Link grid coordinates to room ID
        grid_join_button.click(
            join_by_coordinates,
            inputs=[x_coord, y_coord],
            outputs=[room_id_input]
        ).then(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        join_button.click(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output]
        )

        def append_line(existing, text):
            """Append ``text`` to the transcript ``existing`` on a new line."""
            return f"{existing}\n{text}" if existing else text

        def send_and_clear(message, username, room_id, current_history=""):
            """Queue each non-blank line of ``message`` for the joined room.

            Returns ``(message_box_value, chat_textbox_value)``, matching the
            outputs ``[message_input, chat_history_output]``. On success the
            message box is cleared and the sent lines are appended to the
            transcript. Otherwise the typed text is kept in the message box.
            """
            existing = current_history or ""
            if not room_id.startswith("Joined room:"):
                # The warning belongs in the chat box; keep the user's draft intact.
                return message, append_line(existing, "Please join a room first")

            actual_room_id = room_id.replace("Joined room: ", "").strip()

            # Support for multi-line messages: each non-blank line is its own message.
            sent_lines = []
            for line in message.strip().split("\n"):
                if line.strip():
                    shown = send_message(line.strip(), username, actual_room_id)
                    if shown:
                        sent_lines.append(shown)

            if not sent_lines:
                return message, existing
            return "", append_line(existing, "\n".join(sent_lines))

        send_button.click(
            send_and_clear,
            inputs=[message_input, username_input, current_room_display, chat_history_output],
            outputs=[message_input, chat_history_output]
        )

        def show_sector_map(room_id):
            """Return the sector map, or a hint if no room has been joined."""
            if not room_id.startswith("Joined room:"):
                return "Please join a room first to view the map"
            return generate_sector_map()

        map_button.click(
            show_sector_map,
            inputs=[current_room_display],
            outputs=[chat_history_output]
        )

        # Handle Enter key for sending, Shift+Enter for new line
        def on_message_submit(message, username, room_id, current_history=""):
            """Enter-key handler; identical to pressing Send."""
            return send_and_clear(message, username, room_id, current_history)

        message_input.submit(
            on_message_submit,
            inputs=[message_input, username_input, current_room_display, chat_history_output],
            outputs=[message_input, chat_history_output]
        )

        # On load, populate room list
        interface.load(
            list_available_rooms,
            inputs=[],
            outputs=[room_list]
        )

    return interface
def _find_free_port(host, start_port, attempts):
    """Return the first port in ``range(start_port, start_port + attempts)`` that ``host`` can bind, else ``None``.

    This is only a probe. The port is released again immediately, so another
    process could still take it before the real server binds.
    """
    for candidate in range(start_port, start_port + attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            if os.name == "posix":
                # asyncio's create_server enables SO_REUSEADDR on Unix by default.
                # Doing the same here keeps ports in TIME_WAIT from looking busy.
                probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                probe.bind((host, candidate))
            except OSError:
                continue
        return candidate
    return None


async def main():
    """Start the chat node: WebSocket server, queue processor and web UI.

    Steps:
    1. Read options via ``get_node_name``.
    2. Start the WebSocket server, which tries ports upward from ``--ws-port``.
    3. Launch ``process_message_queue`` as a background task.
    4. Build the Gradio UI with a middleware that lets a ``?node_name=`` query
       parameter rename the node.
    5. Serve the UI with uvicorn on the first free port among up to five
       ports starting at ``--port``.

    On exit the queue task is cancelled and the WebSocket server is closed.

    NOTE(review): uvicorn's ``Server.startup`` appears to log bind failures and
    call ``sys.exit(1)`` rather than raise ``OSError``. That made the former
    catch-and-retry loop ineffective, so the free port is now found by probing
    before the server is created.
    """
    global NODE_NAME, main_event_loop
    NODE_NAME, port, ws_port = get_node_name()

    # Store the main event loop for later use
    main_event_loop = asyncio.get_running_loop()

    ws_server = None
    queue_task = None
    try:
        ws_server = await start_websocket_server(port=ws_port)

        # Keep a reference: the event loop holds only weak references to
        # tasks, so an unreferenced task can be garbage-collected mid-run.
        queue_task = asyncio.create_task(process_message_queue())
        logger.info("Starting message queue processor")

        interface = create_gradio_interface()

        # Custom middleware to extract node name from URL query parameters
        from starlette.middleware.base import BaseHTTPMiddleware

        class NodeNameMiddleware(BaseHTTPMiddleware):
            """Rename this node when a request carries ``?node_name=``.

            NOTE(review): this rebinds the process-wide NODE_NAME for every
            user, on any request that carries the parameter.
            """

            async def dispatch(self, request, call_next):
                global NODE_NAME
                query_params = dict(request.query_params)
                if "node_name" in query_params:
                    NODE_NAME = query_params["node_name"]
                    logger.info(f"Node name set to {NODE_NAME} from URL parameter")
                return await call_next(request)

        # Apply middleware
        app = gr.routes.App.create_app(interface)
        app.add_middleware(NodeNameMiddleware)
        gr.routes.mount_gradio_app(app, interface, path="/")

        # Run the FastAPI app with uvicorn
        import uvicorn

        host = "0.0.0.0"
        max_retries = 5
        http_port = _find_free_port(host, port, max_retries)
        if http_port is None:
            logger.error(f"Failed to start server after {max_retries} attempts")
            return
        if http_port != port:
            logger.warning(f"Port {port} already in use, using port {http_port}")

        http_server = uvicorn.Server(uvicorn.Config(app, host=host, port=http_port))
        logger.info(f"Starting Gradio interface on http://{host}:{http_port} with node name '{NODE_NAME}'")
        await http_server.serve()
    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise
    finally:
        if queue_task is not None:
            queue_task.cancel()
        if ws_server is not None:
            ws_server.close()
def _run():
    """Script entry point: run the async ``main`` until it finishes."""
    asyncio.run(main())


if __name__ == "__main__":
    _run()