# Spaces: Sleeping / Sleeping  (page residue captured with the source; not program code)
import json
import logging
import threading
import time
from datetime import datetime

import numpy as np
import requests
import streamlit as st
# Logging: configure the root logger at DEBUG level, then create the
# module-level logger used throughout this script.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Client Simulator Class | |
class ClientSimulator:
    """Simulated federated-learning client driven by a background thread.

    The worker registers ``client_id`` at ``{server_url}/register`` and then
    polls ``{server_url}/training_status``, storing each status as a snapshot
    dict (``round``, ``active_clients``, ``clients_ready``, ``timestamp``) in
    a history list capped at ``MAX_HISTORY`` entries.

    Attributes:
        server_url: Base URL of the server.
        client_id: Identifier sent on registration.
        is_running: True while the worker is meant to keep polling.
        thread: The worker ``threading.Thread`` (``None`` before ``start``).
        last_update: ``"Never"`` until a snapshot is recorded, then the time
            of the latest snapshot as ``YYYY-MM-DD HH:MM:SS``.
        last_error: Most recent error message, or ``None``.
    """

    # Maximum number of status snapshots kept in the history list.
    MAX_HISTORY = 50

    def __init__(self, server_url):
        """Create a stopped simulator for the server at *server_url*."""
        self.server_url = server_url
        self.client_id = f"web_client_{int(time.time())}"
        self.is_running = False
        self.thread = None
        self.last_update = "Never"
        self.last_error = None
        # List the worker appends snapshots to; start() rebinds it to the
        # list stored in st.session_state.
        self._history = []
        # Set by stop() so a worker waiting between polls wakes up at once.
        self._stop_event = threading.Event()

    def start(self):
        """Start the daemon worker thread.

        Call this from the Streamlit script thread. ``st.session_state`` is
        not usable from a plain background thread, so the shared history list
        is looked up here and the worker only ever touches the list object.
        """
        if 'training_history' not in st.session_state:
            st.session_state.training_history = []
        self._history = st.session_state.training_history
        self._stop_event.clear()
        self.is_running = True
        self.thread = threading.Thread(target=self._run_client, daemon=True)
        self.thread.start()
        logger.info(f"Client simulator started for {self.server_url}")

    def stop(self):
        """Ask the worker to stop; it exits after its current request."""
        self.is_running = False
        self._stop_event.set()
        logger.info("Client simulator stopped")

    def _record_snapshot(self, current_round, active_clients, clients_ready):
        """Append one status snapshot and trim the history to MAX_HISTORY."""
        now = datetime.now()
        self._history.append({
            'round': current_round,
            'active_clients': active_clients,
            'clients_ready': clients_ready,
            'timestamp': now,
        })
        # Trim in place: rebinding would detach the worker from the list
        # object that session state holds.
        del self._history[:-self.MAX_HISTORY]
        self.last_update = now.strftime("%Y-%m-%d %H:%M:%S")

    def _register(self):
        """Register with the server; return True only on HTTP 200."""
        logger.info(f"Attempting to register client {self.client_id} with server {self.server_url}")
        client_info = {
            'dataset_size': 100,
            'model_params': 10000,
            'capabilities': ['training', 'inference'],
        }
        resp = requests.post(
            f"{self.server_url}/register",
            json={'client_id': self.client_id, 'client_info': client_info},
            timeout=10,
        )
        if resp.status_code == 200:
            logger.info(f"Successfully registered client {self.client_id}")
            return True
        logger.error(f"Registration failed with status {resp.status_code}: {resp.text}")
        self.last_error = f"Registration failed: HTTP {resp.status_code}"
        return False

    def _poll_status(self):
        """Fetch the training status once, then wait before the next poll.

        Errors are logged and stored in ``last_error`` instead of being
        raised, so a temporary outage does not end the polling loop. The wait
        is 5 seconds after a response and 10 seconds after an error.
        """
        delay = 10
        try:
            logger.debug(f"Checking training status from {self.server_url}/training_status")
            status = requests.get(f"{self.server_url}/training_status", timeout=5)
            if status.status_code == 200:
                data = status.json()
                logger.debug(f"Training status: {data}")
                self._record_snapshot(
                    data.get('current_round', 0),
                    data.get('active_clients', 0),
                    data.get('clients_ready', 0),
                )
            else:
                logger.warning(f"Training status returned {status.status_code}: {status.text}")
            delay = 5
        except requests.exceptions.Timeout:
            logger.warning("Timeout while checking training status")
            self.last_error = "Timeout connecting to server"
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Connection error while checking training status: {e}")
            self.last_error = f"Connection error: {e}"
        except Exception as e:
            logger.error(f"Unexpected error in client simulator: {e}")
            self.last_error = f"Unexpected error: {e}"
        # Event.wait returns early when stop() sets the event.
        self._stop_event.wait(delay)

    def _run_client(self):
        """Worker body: register, record an initial snapshot, then poll."""
        try:
            if not self._register():
                # Without this the dead worker would still look "running".
                self.is_running = False
                return
            self._record_snapshot(0, 1, 0)
            while self.is_running:
                self._poll_status()
        except requests.exceptions.ConnectionError as e:
            logger.error(f"Failed to connect to server {self.server_url}: {e}")
            self.last_error = f"Failed to connect to server: {e}"
            self.is_running = False
        except Exception as e:
            logger.error(f"Failed to start client simulator: {e}")
            self.last_error = f"Failed to start: {e}"
            self.is_running = False
def check_server_health(server_url):
    """Probe ``{server_url}/health`` and report the outcome.

    Returns a ``(healthy, detail)`` tuple. On HTTP 200 it is
    ``(True, <parsed JSON body>)``; otherwise ``healthy`` is False and
    ``detail`` is a message string describing the HTTP status, timeout,
    connection failure or other error. No exception is propagated.
    """
    try:
        logger.debug(f"Checking server health at {server_url}/health")
        response = requests.get(f"{server_url}/health", timeout=5)
        if response.status_code != 200:
            logger.warning(f"Server health check returned {response.status_code}")
            return False, f"HTTP {response.status_code}: {response.text}"
        logger.info("Server is healthy")
        return True, response.json()
    except requests.exceptions.Timeout:
        logger.error("Server health check timeout")
        return False, "Timeout"
    except requests.exceptions.ConnectionError as conn_err:
        logger.error(f"Server health check connection error: {conn_err}")
        return False, f"Connection refused: {conn_err}"
    except Exception as err:
        logger.error(f"Server health check unexpected error: {err}")
        return False, f"Unexpected error: {err}"
# Page chrome
st.set_page_config(page_title="Federated Credit Scoring Demo", layout="centered")
st.title("Federated Credit Scoring Demo")

# Sidebar configuration
st.sidebar.header("Configuration")
SERVER_URL = st.sidebar.text_input("Server URL", value="http://localhost:8080")
DEMO_MODE = st.sidebar.checkbox("Demo Mode", value=True)

# Initialize session state: create each key only if it is not already present.
for _state_key, _initial_value in (
    ('client_simulator', None),
    ('training_history', []),
    ('debug_messages', []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Debug section in sidebar
def _clear_debug_logs():
    """Button callback: empty the debug message buffer."""
    st.session_state.debug_messages = []


with st.sidebar.expander("Debug Information"):
    st.write("**Server Status:**")
    if not DEMO_MODE:
        # Only probe the server outside demo mode.
        is_healthy, health_info = check_server_health(SERVER_URL)
        if is_healthy:
            st.success("✅ Server is healthy")
            st.json(health_info)
        else:
            st.error(f"❌ Server error: {health_info}")

    st.write("**Recent Logs:**")
    if st.session_state.debug_messages:
        for msg in st.session_state.debug_messages[-5:]:  # Show last 5 messages
            st.text(msg)
    else:
        st.text("No debug messages yet")

    # Use a callback instead of `if st.button(...)`: callbacks run before the
    # script reruns, so the list above is already empty when it is rendered.
    # Clearing after rendering left the old messages on screen.
    st.button("Clear Debug Logs", on_click=_clear_debug_logs)
# Sidebar educational content
_ABOUT_FEDERATED_LEARNING = """
**Traditional ML:** Banks send data to central server → Privacy risk
**Federated Learning:**
- Banks keep data locally
- Only model updates are shared
- Collaborative learning without data sharing
"""

with st.sidebar.expander("About Federated Learning"):
    st.markdown(_ABOUT_FEDERATED_LEARNING)
# Client Simulator in sidebar
st.sidebar.header("Client Simulator")

if st.sidebar.button("Start Client"):
    if not DEMO_MODE:
        try:
            # Stop a simulator left from an earlier click before replacing it.
            # Otherwise its daemon thread keeps polling the server with no
            # reference left to stop it.
            previous_simulator = st.session_state.client_simulator
            if previous_simulator is not None:
                previous_simulator.stop()
            st.session_state.client_simulator = ClientSimulator(SERVER_URL)
            st.session_state.client_simulator.start()
            st.sidebar.success("Client started!")
            st.session_state.debug_messages.append(f"{datetime.now()}: Client simulator started")
        except Exception as e:
            st.sidebar.error(f"Failed to start client: {e}")
            st.session_state.debug_messages.append(f"{datetime.now()}: Failed to start client - {e}")
    else:
        st.sidebar.warning("Only works in Real Mode")

if st.sidebar.button("Stop Client"):
    # Nothing to do (and nothing reported) when no simulator exists.
    if st.session_state.client_simulator:
        st.session_state.client_simulator.stop()
        st.session_state.client_simulator = None
        st.sidebar.success("Client stopped!")
        st.session_state.debug_messages.append(f"{datetime.now()}: Client simulator stopped")
# Main content - focused on core functionality
def _report_prediction_error(message):
    """Show *message* on the page and add it to the debug message buffer."""
    st.error(message)
    st.session_state.debug_messages.append(f"{datetime.now()}: {message}")


def _prediction_error_detail(resp):
    """Return the error text for a failed /predict response.

    Uses the ``error`` field when the body is a JSON object and falls back to
    the HTTP status when the body is not JSON or not an object (for example
    an HTML error page), so the failure is still reported as a failed
    prediction.
    """
    try:
        payload = resp.json()
    except ValueError:  # body is not valid JSON
        return f"HTTP {resp.status_code}"
    if isinstance(payload, dict):
        return payload.get('error', 'Unknown error')
    return f"HTTP {resp.status_code}"


st.header("Enter Customer Features")
with st.form("feature_form"):
    features = []
    cols = st.columns(4)
    # 32 numeric inputs laid out row by row across the four columns.
    for i in range(32):
        with cols[i % 4]:
            val = st.number_input(f"Feature {i+1}", value=0.0, format="%.4f", key=f"f_{i}")
            features.append(val)
    submitted = st.form_submit_button("Predict Credit Score")

# Prediction results
if submitted:
    logger.info(f"Prediction requested with {len(features)} features")
    if DEMO_MODE:
        # Demo mode contacts no server: the score is a fixed formula of the
        # feature mean, shown after a short simulated delay.
        with st.spinner("Processing..."):
            time.sleep(1)
            demo_prediction = sum(features) / len(features) * 100 + 500
        st.success(f"Predicted Credit Score: {demo_prediction:.2f}")
        st.session_state.debug_messages.append(f"{datetime.now()}: Demo prediction: {demo_prediction:.2f}")
    else:
        try:
            logger.info(f"Sending prediction request to {SERVER_URL}/predict")
            with st.spinner("Connecting to server..."):
                resp = requests.post(f"{SERVER_URL}/predict", json={"features": features}, timeout=10)
            if resp.status_code == 200:
                payload = resp.json()
                prediction = payload.get("prediction") if isinstance(payload, dict) else None
                if isinstance(prediction, (int, float)):
                    st.success(f"Predicted Credit Score: {prediction:.2f}")
                    st.session_state.debug_messages.append(f"{datetime.now()}: Real prediction: {prediction:.2f}")
                    logger.info(f"Prediction successful: {prediction}")
                else:
                    # A missing or non-numeric value would otherwise fail in
                    # the ":.2f" format and be reported as "Unexpected error".
                    _report_prediction_error(
                        "Prediction failed: server response did not contain a numeric prediction"
                    )
                    logger.error(f"Prediction response had no numeric 'prediction': {resp.text}")
            else:
                _report_prediction_error(f"Prediction failed: {_prediction_error_detail(resp)}")
                logger.error(f"Prediction failed with status {resp.status_code}: {resp.text}")
        except requests.exceptions.Timeout:
            _report_prediction_error("Timeout connecting to server")
            logger.error("Prediction request timeout")
        except requests.exceptions.ConnectionError as e:
            _report_prediction_error(f"Connection error: {e}")
            logger.error(f"Prediction connection error: {e}")
        except Exception as e:
            _report_prediction_error(f"Unexpected error: {e}")
            logger.error(f"Prediction unexpected error: {e}")
# Training progress - simplified
def _warn_training_status(message):
    """Show *message* as a warning and add it to the debug message buffer."""
    st.warning(message)
    st.session_state.debug_messages.append(f"{datetime.now()}: {message}")


st.header("Training Progress")
if DEMO_MODE:
    # Fixed placeholder figures; no server is queried in demo mode.
    demo_columns = st.columns(4)
    demo_metrics = (
        ("Round", "3/10"),
        ("Clients", "3"),
        ("Accuracy", "85.2%"),
        ("Status", "Active"),
    )
    for column, (label, value) in zip(demo_columns, demo_metrics):
        with column:
            st.metric(label, value)
else:
    try:
        logger.debug(f"Fetching training status from {SERVER_URL}/training_status")
        status_resp = requests.get(f"{SERVER_URL}/training_status", timeout=5)
        if status_resp.status_code != 200:
            _warn_training_status(f"Training status failed: HTTP {status_resp.status_code}")
            logger.warning(f"Training status returned {status_resp.status_code}: {status_resp.text}")
        else:
            status_data = status_resp.json()
            logger.debug(f"Training status received: {status_data}")
            live_columns = st.columns(4)
            live_metrics = (
                ("Round", f"{status_data.get('current_round', 0)}/{status_data.get('total_rounds', 10)}"),
                ("Clients", status_data.get('active_clients', 0)),
                ("Ready", status_data.get('clients_ready', 0)),
                ("Status", "Active" if status_data.get('training_active', False) else "Inactive"),
            )
            for column, (label, value) in zip(live_columns, live_metrics):
                with column:
                    st.metric(label, value)
    except requests.exceptions.Timeout:
        _warn_training_status("Training status timeout")
        logger.warning("Training status request timeout")
    except requests.exceptions.ConnectionError as conn_err:
        _warn_training_status(f"Training status connection error: {conn_err}")
        logger.error(f"Training status connection error: {conn_err}")
    except Exception as err:
        _warn_training_status(f"Training status unexpected error: {err}")
        logger.error(f"Training status unexpected error: {err}")
# Client status in sidebar
if st.session_state.client_simulator and not DEMO_MODE:
    simulator = st.session_state.client_simulator
    st.sidebar.header("Client Status")
    if simulator.is_running:
        st.sidebar.success("Connected")
        st.sidebar.info(f"ID: {simulator.client_id}")
    else:
        st.sidebar.warning("Disconnected")
    # Show the last error in both states. The simulator's fatal paths set
    # last_error and clear is_running together, so showing it only while
    # connected hid exactly the errors that explain a disconnect.
    if simulator.last_error:
        st.sidebar.error(f"Last Error: {simulator.last_error}")