Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import onnxruntime as ort | |
import sys | |
from pathlib import Path | |
sys.path.append("rd2l_pred") | |
from training_data_prep import list_format, modification, league_money, df_gen | |
from feature_engineering import heroes, hero_information | |
# Module-level state populated by load_model(): the ONNX inference session and
# the ordered list of feature names the model was trained on.
MODEL = None
FEATURE_COLUMNS = None


def load_model():
    """Create the ONNX inference session and record the feature names.

    On success the module globals ``MODEL`` and ``FEATURE_COLUMNS`` are set.

    Returns:
        A human-readable status string: a success message, a "not found"
        message when the model file is absent, or an error message carrying
        the text of whatever exception was raised.
    """
    global MODEL, FEATURE_COLUMNS

    onnx_file = Path("model/rd2l_forest.onnx")
    try:
        # Guard clause: nothing to load when the file is absent.
        if not onnx_file.exists():
            return f"Model file not found at: {onnx_file}"

        MODEL = ort.InferenceSession(str(onnx_file))

        # The feature list is hard-coded rather than read from the model:
        # 14 summary features followed by per-index games/winrate columns
        # for indices 1..138.
        summary_features = [
            'mmr', 'p1', 'p2', 'p3', 'p4', 'p5',
            'count', 'mean', 'std', 'min', 'max', 'sum',
            'total_games_played', 'total_winrate',
        ]
        indices = range(1, 139)
        FEATURE_COLUMNS = [
            *summary_features,
            *(f'games_{i}' for i in indices),
            *(f'winrate_{i}' for i in indices),
        ]

        print(f"Number of features loaded: {len(FEATURE_COLUMNS)}")
        return "Model loaded successfully"
    except Exception as err:
        return f"Error loading model: {err}"
def _extract_player_id(raw_id):
    """Return the account ID from a bare ID or a profile URL."""
    text = str(raw_id).strip()
    if "/" not in text:
        return text
    # Drop any query string or fragment before inspecting path segments.
    path = text.split("?", 1)[0].split("#", 1)[0]
    # Ignoring empty segments makes a trailing slash harmless.
    segments = [segment for segment in path.split("/") if segment]
    # Profile URLs look like ".../players/<id>[/subpage]"; prefer the segment
    # right after "players" so sub-pages still resolve to the ID.
    if "players" in segments[:-1]:
        return segments[segments.index("players") + 1]
    return segments[-1] if segments else ""


def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
    """Build a one-row feature frame for a player, shaped like the template CSV.

    Args:
        player_id: Player ID, or a profile URL containing it. Surrounding
            whitespace, trailing slashes, sub-pages and query strings are
            tolerated.
        mmr: Player MMR; converted with ``float``.
        comf_1..comf_5: Comfort rating per position; converted with ``int``.

    Returns:
        A single-row ``pandas.DataFrame`` whose columns are exactly the
        columns of ``prediction_data_prepped.csv`` (in that order), with
        unknown values filled with 0. On any unrecoverable failure, a string
        starting with ``"Error processing player data:"`` is returned instead
        of raising.
    """
    try:
        player_id = _extract_player_id(player_id)

        player_data = {
            "player_id": player_id,
            "mmr": float(mmr),
            "p1": int(comf_1),
            "p2": int(comf_2),
            "p3": int(comf_3),
            "p4": int(comf_4),
            "p5": int(comf_5),
        }

        # Read the template once. It defines the final column set, so if it
        # is unreadable the call cannot succeed; failing here avoids a
        # pointless network request below.
        pred_data = pd.read_csv("prediction_data_prepped.csv")
        expected_columns = pred_data.columns.tolist()

        if not pred_data.empty:
            # Pre-seed every template column with 0 so later updates only
            # overwrite what is actually known.
            for col in expected_columns:
                player_data.setdefault(col, 0)

        # Hero statistics are best-effort: any failure falls back to
        # placeholder totals instead of aborting the prediction.
        try:
            # NOTE(review): hero_information is assumed to return an object
            # with to_dict() (e.g. a pandas Series) - confirm in
            # feature_engineering.
            hero_stats = hero_information(player_id)
            player_data.update(hero_stats.to_dict())
            # Add season identifier to match training data format
            # (assumes the current season is 34).
            player_data[f"{player_id}_S34"] = 1.0
        except Exception as e:
            print(f"Warning - Error fetching hero data: {str(e)}")
            player_data.update({
                "total_games_played": 0,
                "total_winrate": 0.0,
            })

        df = pd.DataFrame([player_data])

        # Optionally align to the reference data layout. This is
        # best-effort; the template alignment below is authoritative.
        try:
            ref_data = pd.read_csv("result_prediction_data_prepped.csv")
            if not ref_data.empty:
                # Adds missing reference columns as 0 and reorders in one
                # step. Columns absent from the reference file are dropped.
                df = df.reindex(columns=ref_data.columns, fill_value=0)
        except Exception as e:
            print(f"Warning - Error matching reference data structure: {str(e)}")

        # Debug print
        print(f"\nNumber of expected columns: {len(expected_columns)}")
        print(f"Number of current columns: {len(df.columns)}")

        # Sets give O(1) membership tests over several hundred columns.
        current = set(df.columns)
        expected = set(expected_columns)
        missing_columns = [col for col in expected_columns if col not in current]
        extra_columns = [col for col in df.columns if col not in expected]
        print(f"\nMissing columns: {missing_columns}")
        print(f"Extra columns: {extra_columns}")

        # Single reindex: fills missing columns with 0, drops extras and
        # enforces template order without per-column inserts.
        df = df.reindex(columns=expected_columns, fill_value=0)

        print(f"\nFinal number of columns: {len(df.columns)}")
        print(f"First few columns: {list(df.columns)[:5]}")
        return df
    except Exception as e:
        return f"Error processing player data: {str(e)}"
def predict_cost(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
    """Main prediction function for the Gradio interface.

    Loads the model on first use, builds the feature row for the player and
    runs it through the ONNX session.

    Args:
        user_id: Player ID or profile URL, passed to process_player_data.
        mmr: Player MMR.
        comf_1..comf_5: Comfort rating per position.

    Returns:
        A display string. On success it holds the predicted cost (rounded to
        two decimals) and an echo of the inputs; on failure it holds the
        error message from whichever stage failed. This function does not
        raise.
    """
    try:
        # Lazy load: the startup load may have failed or not run yet.
        if MODEL is None:
            result = load_model()
            if not result.startswith("Model loaded"):
                return result

        processed_data = process_player_data(
            user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5
        )
        # process_player_data signals failure by returning a message string.
        if isinstance(processed_data, str):
            return processed_data

        # Print debug information
        print("Processed data shape:", processed_data.shape)
        print("Processed data columns:", processed_data.columns.tolist())

        try:
            input_name = MODEL.get_inputs()[0].name
            features = processed_data.values.astype(np.float32)
            prediction = MODEL.run(None, {input_name: features})[0]
            # The output may be 1-D (N,) or 2-D (N, 1) depending on how the
            # model was exported. Flatten first so float() always receives a
            # scalar; converting a 1-element array is deprecated in NumPy.
            first_value = np.asarray(prediction).ravel()[0]
            predicted_cost = round(float(first_value), 2)
        except Exception as e:
            return f"Error during prediction: {str(e)}\nProcessed data shape: {processed_data.shape}"

        return f"""Predicted Cost: {predicted_cost}
Player Details:
- MMR: {mmr}
- Position Comfort:
* Pos 1: {comf_1}
* Pos 2: {comf_2}
* Pos 3: {comf_3}
* Pos 4: {comf_4}
* Pos 5: {comf_5}
Note: This prediction is based on historical data and player statistics from OpenDota."""
    except Exception as e:
        return f"Error in prediction pipeline: {str(e)}"
# Assemble the Gradio UI. Components are created in the same order as the
# inputs appear on the form: ID box, MMR, then one comfort slider per position.
_player_input = gr.Textbox(
    label="Player ID or Link to OpenDota/Dotabuff",
    placeholder="Enter player ID or full profile URL",
)
_mmr_input = gr.Number(label="MMR", value=3000)
_comfort_inputs = [
    gr.Slider(1, 5, value=3, step=1, label=f"Comfort (Pos {position})")
    for position in range(1, 6)
]

# One pre-filled example row: profile URL, MMR, then comfort for positions 1-5.
_example_rows = [
    ["https://www.dotabuff.com/players/188649776", 6812, 5, 5, 4, 2, 1],
]

_description_text = """This tool predicts the auction cost for RD2L players based on their MMR,
position comfort levels, and historical performance data from OpenDota.
Enter a player's OpenDota ID or profile URL along with their current stats
to get a predicted cost."""

_article_text = """### How it works
- The predictor uses machine learning trained on historical RD2L draft data
- Player statistics are fetched from OpenDota API
- Position comfort levels range from 1 (least comfortable) to 5 (most comfortable)
- Predictions are based on both current stats and historical performance
### Notes
- MMR should be the player's current solo MMR
- Position comfort should reflect actual role experience
- Predictions are estimates and may vary from actual draft results"""

demo = gr.Interface(
    fn=predict_cost,
    inputs=[_player_input, _mmr_input, *_comfort_inputs],
    examples=_example_rows,
    outputs=gr.Textbox(label="Prediction Results"),
    title="RD2L Player Cost Predictor",
    description=_description_text,
    article=_article_text,
)

# Load the model at import time and log the resulting status message.
print(load_model())

if __name__ == "__main__":
    demo.launch()