File size: 2,139 Bytes
2f090b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
import gradio as gr

# Create a random dataset
# Toy dataset: 100 inputs drawn uniformly from [-100, 100) and sorted, each
# mapped to two targets (pi * sin(x), pi * cos(x)).
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.column_stack((np.pi * np.sin(X[:, 0]), np.pi * np.cos(X[:, 0])))
# Perturb every fifth row (20 of the 100) with uniform noise in (-0.5, 0.5].
y[::5, :] += 0.5 - rng.rand(20, 2)


def plot_multi_tree(d1, d2, d3):
    """Fit three multi-output decision trees and plot their predictions.

    Each tree is trained on the module-level ``X`` / ``y`` dataset with its own
    ``max_depth``. It is then evaluated on a grid over [-100, 100) with step
    0.01. The two predicted targets are scattered against each other next to
    the training targets, so the effect of tree depth is visible.

    Parameters
    ----------
    d1, d2, d3 : int, float or None
        ``max_depth`` of the first, second and third tree. Numeric values are
        cast to ``int`` because UI sliders may deliver floats (e.g. ``2.0``),
        which scikit-learn rejects for ``max_depth``. ``None`` is passed
        through unchanged.

    Returns
    -------
    matplotlib.figure.Figure
        A new figure holding the scatter plot.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps every
    # figure it creates alive in a global registry until explicitly closed, so
    # a long-running app would leak one figure per call. Its global state is
    # also not thread-safe.
    from matplotlib.figure import Figure

    X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
    s = 25

    fig = Figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(y[:, 0], y[:, 1], c="navy", s=s, edgecolor="black", label="data")

    for depth, color in zip((d1, d2, d3), ("cornflowerblue", "red", "orange")):
        if depth is not None:
            depth = int(depth)
        regr = DecisionTreeRegressor(max_depth=depth)
        regr.fit(X, y)
        y_pred = regr.predict(X_test)  # shape (len(X_test), 2): one column per target
        ax.scatter(
            y_pred[:, 0],
            y_pred[:, 1],
            c=color,
            s=s,
            edgecolor="black",
            label=f"max_depth={depth}",
        )

    ax.set_xlim([-6, 6])
    ax.set_ylim([-6, 6])
    ax.set_xlabel("target 1")
    ax.set_ylabel("target 2")
    ax.set_title("Multi-output Decision Tree Regression")
    ax.legend(loc="best")
    return fig




title = " Illustration of multi-output regression with decision tree.🌲 "
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")

    # Sliders start at 1: scikit-learn rejects max_depth=0, so a minimum of 0
    # would let the user trigger an error in plot_multi_tree from the UI.
    with gr.Row():
        d1 = gr.Slider(minimum=1, maximum=20, step=1, value=2, label="Depth 1")
        d2 = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Depth 2")
        d3 = gr.Slider(minimum=1, maximum=20, step=1, value=8, label="Depth 3")

    btn = gr.Button(value="Submit")
    # Declared after the button so it renders in the same position as before.
    output_plot = gr.Plot(label="Multi-output regression with decision trees")
    btn.click(plot_multi_tree, inputs=[d1, d2, d3], outputs=output_plot)

demo.launch()