nick-leland commited on
Commit
e55fb20
·
1 Parent(s): 27e0d07

Added example and updated the loading of the onnx model

Browse files
Files changed (1) hide show
  1. app.py +62 -19
app.py CHANGED
@@ -8,9 +8,31 @@ sys.path.append("rd2l_pred")
8
  from training_data_prep import list_format, modification, league_money, df_gen
9
  from feature_engineering import heroes, hero_information
10
 
11
- # Load the ONNX model
12
- model_path = Path("model/rd2l_forest.onnx")
13
- session = ort.InferenceSession(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
16
  """Process player data similar to training pipeline"""
@@ -36,19 +58,23 @@ def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
36
  # Merge hero stats with player data
37
  player_data.update(hero_stats.to_dict())
38
  except Exception as e:
39
- return f"Error fetching hero data: {str(e)}"
 
 
 
 
 
40
 
41
  # Convert to DataFrame for consistency with training
42
  df = pd.DataFrame([player_data])
43
 
44
- # Ensure columns match training data
45
- required_columns = [col.name for col in session.get_inputs()[0].type.tensor_type.shape.dim]
46
- for col in required_columns:
47
- if col not in df.columns:
48
- df[col] = 0
49
-
50
- # Reorder columns to match model input
51
- df = df[required_columns]
52
 
53
  return df
54
  except Exception as e:
@@ -57,18 +83,29 @@ def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
57
  def predict_cost(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
58
  """Main prediction function for Gradio interface"""
59
  try:
 
 
 
 
 
 
60
  # Process input data
61
  processed_data = process_player_data(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5)
62
 
63
  if isinstance(processed_data, str): # Error occurred
64
  return processed_data
65
-
66
- # Make prediction
67
- input_name = session.get_inputs()[0].name
68
- prediction = session.run(None, {input_name: processed_data.values.astype(np.float32)})[0]
69
 
70
- # Format prediction
71
- predicted_cost = round(float(prediction[0]), 2)
 
 
 
 
 
 
 
 
 
72
 
73
  return f"""Predicted Cost: {predicted_cost}
74
 
@@ -84,7 +121,7 @@ Player Details:
84
  Note: This prediction is based on historical data and player statistics from OpenDota."""
85
 
86
  except Exception as e:
87
- return f"Error making prediction: {str(e)}"
88
 
89
  # Create Gradio interface
90
  demo = gr.Interface(
@@ -99,6 +136,9 @@ demo = gr.Interface(
99
  gr.Slider(1, 5, value=3, step=1, label="Comfort (Pos 4)"),
100
  gr.Slider(1, 5, value=3, step=1, label="Comfort (Pos 5)")
101
  ],
 
 
 
102
  outputs=gr.Textbox(label="Prediction Results"),
103
  title="RD2L Player Cost Predictor",
104
  description="""This tool predicts the auction cost for RD2L players based on their MMR,
@@ -117,5 +157,8 @@ demo = gr.Interface(
117
  - Predictions are estimates and may vary from actual draft results"""
118
  )
119
 
 
 
 
120
  if __name__ == "__main__":
121
  demo.launch()
 
8
  from training_data_prep import list_format, modification, league_money, df_gen
9
  from feature_engineering import heroes, hero_information
10
 
11
+ # Global variables for model and feature columns
12
+ MODEL = None
13
+ FEATURE_COLUMNS = None
14
+
15
+ def load_model():
16
+ """Load the ONNX model and get input features"""
17
+ global MODEL, FEATURE_COLUMNS
18
+ try:
19
+ model_path = Path("model/rd2l_predictor.onnx")
20
+ if not model_path.exists():
21
+ return "Model file not found at: " + str(model_path)
22
+
23
+ MODEL = ort.InferenceSession(str(model_path))
24
+
25
+ # Load feature columns from a saved reference - you'll need to create this
26
+ try:
27
+ FEATURE_COLUMNS = pd.read_csv("model/feature_columns.csv")["columns"].tolist()
28
+ except:
29
+ # Fallback to basic features if feature columns file not found
30
+ FEATURE_COLUMNS = ["player_id", "mmr", "p1", "p2", "p3", "p4", "p5",
31
+ "total_games_played", "total_winrate"]
32
+
33
+ return "Model loaded successfully"
34
+ except Exception as e:
35
+ return f"Error loading model: {str(e)}"
36
 
37
  def process_player_data(player_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
38
  """Process player data similar to training pipeline"""
 
58
  # Merge hero stats with player data
59
  player_data.update(hero_stats.to_dict())
60
  except Exception as e:
61
+ print(f"Warning - Error fetching hero data: {str(e)}")
62
+ # If hero stats fail, add placeholder values
63
+ player_data.update({
64
+ "total_games_played": 0,
65
+ "total_winrate": 0.0
66
+ })
67
 
68
  # Convert to DataFrame for consistency with training
69
  df = pd.DataFrame([player_data])
70
 
71
+ # Ensure all required columns exist
72
+ if FEATURE_COLUMNS:
73
+ for col in FEATURE_COLUMNS:
74
+ if col not in df.columns:
75
+ df[col] = 0
76
+ # Reorder columns to match model input
77
+ df = df[FEATURE_COLUMNS]
 
78
 
79
  return df
80
  except Exception as e:
 
83
  def predict_cost(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5):
84
  """Main prediction function for Gradio interface"""
85
  try:
86
+ # Check if model is loaded
87
+ if MODEL is None:
88
+ result = load_model()
89
+ if not result.startswith("Model loaded"):
90
+ return result
91
+
92
  # Process input data
93
  processed_data = process_player_data(user_id, mmr, comf_1, comf_2, comf_3, comf_4, comf_5)
94
 
95
  if isinstance(processed_data, str): # Error occurred
96
  return processed_data
 
 
 
 
97
 
98
+ # Print debug information
99
+ print("Processed data shape:", processed_data.shape)
100
+ print("Processed data columns:", processed_data.columns.tolist())
101
+
102
+ # Make prediction
103
+ try:
104
+ input_name = MODEL.get_inputs()[0].name
105
+ prediction = MODEL.run(None, {input_name: processed_data.values.astype(np.float32)})[0]
106
+ predicted_cost = round(float(prediction[0]), 2)
107
+ except Exception as e:
108
+ return f"Error during prediction: {str(e)}\nProcessed data shape: {processed_data.shape}"
109
 
110
  return f"""Predicted Cost: {predicted_cost}
111
 
 
121
  Note: This prediction is based on historical data and player statistics from OpenDota."""
122
 
123
  except Exception as e:
124
+ return f"Error in prediction pipeline: {str(e)}"
125
 
126
  # Create Gradio interface
127
  demo = gr.Interface(
 
136
  gr.Slider(1, 5, value=3, step=1, label="Comfort (Pos 4)"),
137
  gr.Slider(1, 5, value=3, step=1, label="Comfort (Pos 5)")
138
  ],
139
+ examples=[
140
+ ["https://www.dotabuff.com/players/188649776", 6812, 5, 5, 4, 2, 1]
141
+ ],
142
  outputs=gr.Textbox(label="Prediction Results"),
143
  title="RD2L Player Cost Predictor",
144
  description="""This tool predicts the auction cost for RD2L players based on their MMR,
 
157
  - Predictions are estimates and may vary from actual draft results"""
158
  )
159
 
160
+ # Load model on startup
161
+ print(load_model())
162
+
163
  if __name__ == "__main__":
164
  demo.launch()