File size: 5,425 Bytes
dd5de47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "896cacc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL:  http://127.0.0.1:7860\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import MultiTaskLasso, Lasso\n",
    "import gradio as gr\n",
    "\n",
    "# Module-level RNG, seeded once for reproducibility. make_plot keeps drawing\n",
    "# from this same stream, so each click in the UI simulates a new dataset.\n",
    "rng = np.random.RandomState(42)\n",
    "\n",
    "\n",
    "def make_plot(n_samples, n_features, n_tasks, n_relevant_features, alpha, lasso_alpha=0.5):\n",
    "    \"\"\"Simulate a multi-task regression problem and compare Lasso with MultiTaskLasso.\n",
    "\n",
    "    Ground-truth coefficients of the first ``n_relevant_features`` features are\n",
    "    sine waves over the tasks (random frequency and phase); all other\n",
    "    coefficients are zero. One ``Lasso`` is fitted per task and a single\n",
    "    ``MultiTaskLasso`` is fitted jointly on all tasks.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_samples, n_features, n_tasks, n_relevant_features : int-like\n",
    "        Problem sizes. Cast to ``int`` because UI sliders may deliver floats.\n",
    "        ``n_relevant_features`` is capped at ``n_features`` (the sliders allow\n",
    "        up to 10 relevant features with as few as 5 features in total).\n",
    "    alpha : float\n",
    "        Regularization strength of the ``MultiTaskLasso``.\n",
    "    lasso_alpha : float, default=0.5\n",
    "        Regularization strength of each per-task ``Lasso``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    matplotlib.figure.Figure\n",
    "        Coefficient of feature 0 across tasks for the ground truth, Lasso\n",
    "        and MultiTaskLasso.\n",
    "    \"\"\"\n",
    "    # Use Figure directly instead of pyplot: pyplot registers every figure in a\n",
    "    # global manager that is never cleared here, so each request would leak one.\n",
    "    from matplotlib.figure import Figure\n",
    "\n",
    "    n_samples = int(n_samples)\n",
    "    n_features = int(n_features)\n",
    "    n_tasks = int(n_tasks)\n",
    "    # Without the cap, coef[:, k] below raises IndexError when k >= n_features.\n",
    "    n_relevant_features = min(int(n_relevant_features), n_features)\n",
    "\n",
    "    # Generate 2D coefficients: sine waves with random frequency and phase\n",
    "    coef = np.zeros((n_tasks, n_features))\n",
    "    times = np.linspace(0, 2 * np.pi, n_tasks)\n",
    "    for k in range(n_relevant_features):\n",
    "        coef[:, k] = np.sin((1.0 + rng.randn(1)) * times + 3 * rng.randn(1))\n",
    "\n",
    "    X = rng.randn(n_samples, n_features)\n",
    "    Y = np.dot(X, coef.T) + rng.randn(n_samples, n_tasks)\n",
    "\n",
    "    # Independent Lasso per task (one column of Y each) vs. one joint fit.\n",
    "    coef_lasso_ = np.array([Lasso(alpha=lasso_alpha).fit(X, y).coef_ for y in Y.T])\n",
    "    coef_multi_task_lasso_ = MultiTaskLasso(alpha=alpha).fit(X, Y).coef_\n",
    "\n",
    "    fig = Figure(figsize=(8, 5))\n",
    "    ax = fig.subplots()\n",
    "    feature_to_plot = 0\n",
    "    lw = 2\n",
    "    ax.plot(coef[:, feature_to_plot], color=\"seagreen\", linewidth=lw, label=\"Ground truth\")\n",
    "    ax.plot(\n",
    "        coef_lasso_[:, feature_to_plot], color=\"cornflowerblue\", linewidth=lw, label=\"Lasso\"\n",
    "    )\n",
    "    ax.plot(\n",
    "        coef_multi_task_lasso_[:, feature_to_plot],\n",
    "        color=\"gold\",\n",
    "        linewidth=lw,\n",
    "        label=\"MultiTaskLasso\",\n",
    "    )\n",
    "    ax.legend(loc=\"upper center\")\n",
    "    ax.axis(\"tight\")\n",
    "    ax.set_ylim([-1.1, 1.1])\n",
    "    ax.set_xlabel(\"Task (time instant)\")\n",
    "    ax.set_ylabel(f\"Coefficient of feature {feature_to_plot}\")\n",
    "    fig.suptitle(\"Lasso, MultiTaskLasso and Ground truth time series\")\n",
    "    return fig\n",
    "\n",
    "\n",
    "# Markdown description of the demo, rendered on the page via gr.Markdown.\n",
    "model_card = \"\"\"\n",
    "## Description\n",
    "The multi-task lasso fits multiple regression problems jointly, enforcing the selected\n",
    "features to be the same across tasks. This example simulates sequential measurements: each task\n",
    "is a time instant, and the relevant features vary in amplitude over time while remaining the same features.\n",
    "The multi-task lasso imposes that features selected at one time point are selected\n",
    "for all time points. This makes feature selection by the Lasso more stable.\n",
    "\n",
    "## Model\n",
    "`sklearn.linear_model.Lasso` (fitted separately for each task) and\n",
    "`sklearn.linear_model.MultiTaskLasso` (fitted jointly on all tasks) are used in this example.\n",
    "The alpha slider sets the regularization of the MultiTaskLasso; each Lasso uses alpha=0.5.\n",
    "The plot shows the Lasso, MultiTaskLasso and ground-truth coefficients of one feature across tasks (time).\n",
    "\"\"\"\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    # Page header, description and credits\n",
    "    gr.Markdown('''\n",
    "            <div>\n",
    "            <h1 style='text-align: center'> Joint feature selection with multi-task Lasso </h1>\n",
    "            </div>\n",
    "        ''')\n",
    "    gr.Markdown(model_card)\n",
    "    gr.Markdown(\"Original example Author: Alexandre Gramfort <[email protected]>\")\n",
    "    gr.Markdown(\n",
    "        \"Iterative conversion by: <a href=\\\"https://github.com/DeaMariaLeon\\\">Dea María Léon</a>\"\n",
    "    )\n",
    "\n",
    "    # Problem-size controls, forwarded positionally to make_plot\n",
    "    n_samples = gr.Slider(minimum=50, maximum=500, value=100, step=50, label='Select number of samples')\n",
    "    n_features = gr.Slider(minimum=5, maximum=50, value=30, step=5, label='Select number of features')\n",
    "    n_tasks = gr.Slider(minimum=5, maximum=50, value=40, step=5, label='Select number of tasks')\n",
    "    n_relevant_features = gr.Slider(\n",
    "        minimum=1, maximum=10, value=5, step=1, label='Select number of relevant_features'\n",
    "    )\n",
    "\n",
    "    # Regularization strength, passed to make_plot as its `alpha` argument\n",
    "    with gr.Column():\n",
    "        with gr.Tab('Select Alpha Range'):\n",
    "            alpha = gr.Slider(minimum=0, maximum=10, value=1.0, step=0.5, label='alpha')\n",
    "\n",
    "    btn = gr.Button(value='Submit')\n",
    "    # Created after the button so the plot renders directly below it\n",
    "    plot_output = gr.Plot()\n",
    "    plot_inputs = [n_samples, n_features, n_tasks, n_relevant_features, alpha]\n",
    "    btn.click(make_plot, inputs=plot_inputs, outputs=[plot_output])\n",
    "\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8043d31",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "scikit-ex",
   "language": "python",
   "name": "scikit-ex"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}