Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import onnxruntime as ort | |
| import sys | |
| from pathlib import Path | |
| sys.path.append("rd2l_pred") | |
| from training_data_prep import list_format, modification, league_money, df_gen | |
| from feature_engineering import heroes, hero_information | |
# Module-level cache filled in by load_model(). predict_cost() treats a None
# MODEL as "not loaded yet" and triggers a (re)load on demand.
MODEL, FEATURE_COLUMNS = None, None
def load_model():
    """Load the ONNX model and the feature-column list into module globals.

    Sets ``MODEL`` (an ``onnxruntime.InferenceSession``) and ``FEATURE_COLUMNS``
    (header of the first readable prediction-data CSV). Both globals are
    assigned together and only after every step succeeds, so a non-None
    ``MODEL`` never exists without its feature columns.

    Returns:
        A status string: ``"Model loaded successfully"`` on success, otherwise
        a human-readable error message. This function never raises.
    """
    global MODEL, FEATURE_COLUMNS
    try:
        model_path = Path("model/rd2l_forest.onnx")
        if not model_path.exists():
            return "Model file not found at: " + str(model_path)
        session = ort.InferenceSession(str(model_path))

        # Try the candidate CSVs in priority order. Only the header is needed,
        # so nrows=0 avoids parsing the data rows.
        feature_columns = None
        for candidate in ("result_prediction_data_prepped.csv", "prediction_data_prepped.csv"):
            try:
                feature_columns = pd.read_csv(candidate, nrows=0).columns.tolist()
                break
            except (OSError, ValueError):
                # OSError: missing/unreadable file; ValueError covers pandas'
                # EmptyDataError/ParserError and UnicodeDecodeError.
                continue
        if feature_columns is None:
            return "Error: Could not find prediction data files to determine feature structure"

        MODEL, FEATURE_COLUMNS = session, feature_columns
        return "Model loaded successfully"
    except Exception as e:
        return f"Error loading model: {str(e)}"
def _clean_player_id(player_id):
    """Return the bare account id from an id string or an OpenDota/Dotabuff URL."""
    if player_id is None:
        raise ValueError("player ID is required")
    # Strip whitespace and trailing slashes so ".../players/123/" yields "123".
    cleaned = str(player_id).strip().rstrip("/")
    if "/" in cleaned:
        cleaned = cleaned.split("/")[-1]
    if not cleaned:
        raise ValueError("player ID is required")
    return cleaned


def _reference_feature_columns(path="prediction_data_prepped.csv", target="Predicted_Cost"):
    """Return the template CSV's column names in file order, excluding the target."""
    # Only the header is needed; nrows=0 skips parsing data rows.
    header = pd.read_csv(path, nrows=0).columns
    return [col for col in header if col != target]


def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
    """Build the single-row feature frame fed to the model.

    Args:
        player_id: Numeric player id, or a profile URL whose last path segment
            is the id.
        mmr: Player MMR (converted with ``float``).
        comf_1..comf_5: Position comfort values (converted with ``int``).

    Returns:
        A one-row ``pandas.DataFrame`` whose columns are exactly those of
        ``prediction_data_prepped.csv`` minus ``Predicted_Cost``, in the same
        order as the file (the model consumes features positionally). Missing
        features are filled with 0. On failure, an error string starting with
        ``"Error processing player data:"`` is returned instead of raising.
    """
    try:
        player_id = _clean_player_id(player_id)
        player_data = {
            "player_id": player_id,
            "mmr": float(mmr),
            "p1": int(comf_1),
            "p2": int(comf_2),
            "p3": int(comf_3),
            "p4": int(comf_4),
            "p5": int(comf_5),
        }

        # The template CSV defines which features exist and their order.
        # Without it the frame cannot be aligned, so fail explicitly.
        try:
            feature_cols = _reference_feature_columns()
        except (OSError, ValueError) as e:
            return (
                "Error processing player data: could not read "
                f"prediction_data_prepped.csv ({e})"
            )
        print("\nReference columns from prediction_data_prepped.csv:")
        print(sorted(feature_cols))
        print(f"Number of reference columns: {len(feature_cols)}")

        # Hero statistics (per the UI text, sourced from the OpenDota API).
        # Best-effort: on failure, fall back to placeholder values.
        try:
            hero_stats = hero_information(player_id)
            # presumably a pandas Series of per-hero stats — confirm against feature_engineering
            player_data.update(hero_stats.to_dict())
            # Season flag column matching the training-data naming scheme.
            # NOTE(review): season 34 is hard-coded as "current" — confirm.
            player_data[f"{player_id}_S34"] = 1.0
        except Exception as e:
            print(f"Warning - Error fetching hero data: {str(e)}")
            player_data.update({
                "total_games_played": 0,
                "total_winrate": 0.0,
            })

        df = pd.DataFrame([player_data])
        print("\nProcessed data columns:")
        print(sorted(df.columns.tolist()))
        print(f"Number of processed columns: {len(df.columns)}")

        expected_cols = set(feature_cols)
        actual_cols = set(df.columns)
        missing_cols = expected_cols - actual_cols
        extra_cols = actual_cols - expected_cols
        if missing_cols:
            print("\nMissing columns:")
            print(sorted(missing_cols))
        if extra_cols:
            print("\nExtra columns:")
            print(sorted(extra_cols))

        # Reindex in template order (NOT set order, which varies between
        # processes because str hashing is randomized): drops extra columns
        # and fills missing features with 0.
        df = df.reindex(columns=feature_cols, fill_value=0)
        print(f"\nFinal number of columns: {len(df.columns)}")
        return df
    except Exception as e:
        return f"Error processing player data: {str(e)}"
def predict_cost(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
    """Gradio callback: predict a player's auction cost and format a report.

    Lazily loads the model if ``MODEL`` is still None, builds the feature row
    via ``process_player_data`` and runs the ONNX session on it.

    Returns:
        A human-readable string in every case. Failures are reported as text
        (not raised) so the Gradio output box can display them.
    """
    # Reject blank input up front instead of running the pipeline on it.
    if user_id is None or not str(user_id).strip():
        return "Error: please enter a player ID or an OpenDota/Dotabuff profile URL."
    try:
        if MODEL is None:
            result = load_model()
            if not result.startswith("Model loaded"):
                return result

        processed_data = process_player_data(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5)
        if isinstance(processed_data, str):  # process_player_data reports errors as strings
            return processed_data

        print("Processed data shape:", processed_data.shape)
        print("Processed data columns:", processed_data.columns.tolist())

        try:
            input_name = MODEL.get_inputs()[0].name
            features = processed_data.to_numpy(dtype=np.float32)
            outputs = MODEL.run(None, {input_name: features})
            # ONNX regressor outputs may be shaped (n,) or (n, 1); flatten so
            # the scalar extraction works for both (float() on a 1-element
            # array is deprecated in NumPy >= 1.25).
            predicted_cost = round(float(np.asarray(outputs[0]).ravel()[0]), 2)
        except Exception as e:
            return f"Error during prediction: {str(e)}\nProcessed data shape: {processed_data.shape}"

        return f"""Predicted Cost: {predicted_cost}
Player Details:
- MMR: {mmr}
- Position Comfort:
* Pos 1: {comf_1}
* Pos 2: {comf_2}
* Pos 3: {comf_3}
* Pos 4: {comf_4}
* Pos 5: {comf_5}
Note: This prediction is based on historical data and player statistics from OpenDota."""
    except Exception as e:
        return f"Error in prediction pipeline: {str(e)}"
# Gradio UI: a player id/URL box, an MMR number and one 1-5 comfort slider
# per position, passed positionally to predict_cost in that order.
_DESCRIPTION = """This tool predicts the auction cost for RD2L players based on their MMR,
position comfort levels, and historical performance data from OpenDota.
Enter a player's OpenDota ID or profile URL along with their current stats
to get a predicted cost."""

_ARTICLE = """### How it works
- The predictor uses machine learning trained on historical RD2L draft data
- Player statistics are fetched from OpenDota API
- Position comfort levels range from 1 (least comfortable) to 5 (most comfortable)
- Predictions are based on both current stats and historical performance
### Notes
- MMR should be the player's current solo MMR
- Position comfort should reflect actual role experience
- Predictions are estimates and may vary from actual draft results"""

_inputs = [
    gr.Textbox(
        label="Player ID or Link to OpenDota/Dotabuff",
        placeholder="Enter player ID or full profile URL",
    ),
    gr.Number(label="MMR", value=3000),
]
# Positions 1..5 share identical slider settings; only the label differs.
_inputs.extend(
    gr.Slider(1, 5, value=3, step=1, label=f"Comfort (Pos {pos})")
    for pos in range(1, 6)
)

demo = gr.Interface(
    fn=predict_cost,
    inputs=_inputs,
    outputs=gr.Textbox(label="Prediction Results"),
    examples=[["https://www.dotabuff.com/players/188649776", 6812, 5, 5, 4, 2, 1]],
    title="RD2L Player Cost Predictor",
    description=_DESCRIPTION,
    article=_ARTICLE,
)
# Load eagerly at import time so problems show up in the startup log;
# predict_cost() calls load_model() again while MODEL is still None.
_startup_status = load_model()
print(_startup_status)

if __name__ == "__main__":
    demo.launch()