Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow.keras.utils import to_categorical
|
4 |
+
from sklearn.preprocessing import LabelEncoder
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Load the saved model
|
8 |
+
model = tf.keras.models.load_model("enhanced_adaptive_model.keras", custom_objects={'EnhancedTransformerBlock': EnhancedTransformerBlock})
|
9 |
+
|
10 |
+
# Initialize global variables
|
11 |
+
sequence_length = 10
|
12 |
+
data = [] # This will store the recent outcomes
|
13 |
+
encoder = LabelEncoder()
|
14 |
+
encoder.classes_ = np.load('label_encoder_classes.npy', allow_pickle=True)
|
15 |
+
|
16 |
+
def update_data(data, new_outcome):
|
17 |
+
data.append(new_outcome)
|
18 |
+
if len(data) > sequence_length:
|
19 |
+
data.pop(0)
|
20 |
+
return data
|
21 |
+
|
22 |
+
def enhanced_predict_next(model, data, sequence_length, encoder):
|
23 |
+
last_sequence = data[-sequence_length:]
|
24 |
+
last_sequence = np.array(encoder.transform(last_sequence)).reshape((1, sequence_length))
|
25 |
+
|
26 |
+
# Monte Carlo Dropout for uncertainty estimation
|
27 |
+
predictions = []
|
28 |
+
for _ in range(100):
|
29 |
+
prediction = model(last_sequence, training=True)
|
30 |
+
predictions.append(prediction)
|
31 |
+
|
32 |
+
mean_prediction = np.mean(predictions, axis=0)
|
33 |
+
std_prediction = np.std(predictions, axis=0)
|
34 |
+
|
35 |
+
predicted_label = encoder.inverse_transform([np.argmax(mean_prediction)])
|
36 |
+
uncertainty = np.mean(std_prediction)
|
37 |
+
|
38 |
+
return predicted_label[0], uncertainty
|
39 |
+
|
40 |
+
def gradio_predict(outcome):
|
41 |
+
global data
|
42 |
+
|
43 |
+
if outcome not in encoder.classes_:
|
44 |
+
return "Invalid outcome. Please try again."
|
45 |
+
|
46 |
+
data = update_data(data, outcome)
|
47 |
+
|
48 |
+
if len(data) < sequence_length:
|
49 |
+
return f"Not enough data to make a prediction. Please enter {sequence_length - len(data)} more outcomes."
|
50 |
+
|
51 |
+
predicted_next, uncertainty = enhanced_predict_next(model, data, sequence_length, encoder)
|
52 |
+
return f'Predicted next outcome: {predicted_next} (Uncertainty: {uncertainty:.4f})'
|
53 |
+
|
54 |
+
def gradio_update(actual_next):
|
55 |
+
global data, model
|
56 |
+
|
57 |
+
if actual_next not in encoder.classes_:
|
58 |
+
return "Invalid outcome. Please try again."
|
59 |
+
|
60 |
+
data = update_data(data, actual_next)
|
61 |
+
|
62 |
+
if len(data) < sequence_length:
|
63 |
+
return f"Not enough data to update the model. Please enter {sequence_length - len(data)} more outcomes."
|
64 |
+
|
65 |
+
encoded_actual_next = encoder.transform([actual_next])[0]
|
66 |
+
new_X = np.array(encoder.transform(data[-sequence_length:])).reshape((1, sequence_length))
|
67 |
+
new_y = to_categorical(encoded_actual_next, num_classes=len(encoder.classes_))
|
68 |
+
|
69 |
+
model.fit(new_X, new_y, epochs=1, verbose=0)
|
70 |
+
return "Model updated with new data."
|
71 |
+
|
72 |
+
# Gradio interface
|
73 |
+
with gr.Blocks() as demo:
|
74 |
+
gr.Markdown("## Enhanced Outcome Prediction Model")
|
75 |
+
gr.Markdown(f"Enter a sequence of {sequence_length} outcomes to get started.")
|
76 |
+
gr.Markdown(f"Valid outcomes: {', '.join(encoder.classes_)}")
|
77 |
+
|
78 |
+
with gr.Row():
|
79 |
+
outcome_input = gr.Textbox(label="Enter an outcome")
|
80 |
+
predict_button = gr.Button("Predict Next")
|
81 |
+
predicted_output = gr.Textbox(label="Prediction")
|
82 |
+
|
83 |
+
with gr.Row():
|
84 |
+
actual_input = gr.Textbox(label="Enter actual next outcome")
|
85 |
+
update_button = gr.Button("Update Model")
|
86 |
+
update_output = gr.Textbox(label="Update Status")
|
87 |
+
|
88 |
+
predict_button.click(gradio_predict, inputs=outcome_input, outputs=predicted_output)
|
89 |
+
update_button.click(gradio_update, inputs=actual_input, outputs=update_output)
|
90 |
+
|
91 |
+
demo.launch()
|