File size: 3,282 Bytes
0bd62e5
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: outbreak_forecast\n",
    "### Generate a plot based on 5 inputs.\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio numpy pandas matplotlib plotly altair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import altair\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "from matplotlib.figure import Figure\n",
    "\n",
    "# Selectable months; each one covers a 30-day window of the simulation.\n",
    "MONTHS = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n",
    "# Per-country population figure used to scale each curve (units not stated in the demo).\n",
    "POP_COUNT = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n",
    "\n",
    "\n",
    "def outbreak(plot_type, r, month, countries, social_distancing):\n",
    "    \"\"\"Plot a toy outbreak curve for the selected countries over one month.\n",
    "\n",
    "    Cases on day d are modelled as d ** r * (population + 1).\n",
    "\n",
    "    Args:\n",
    "        plot_type: \"Matplotlib\", \"Plotly\" or \"Altair\".\n",
    "        r: growth exponent of the curve.\n",
    "        month: one of MONTHS; selects days 30*i through 30*(i+1) inclusive,\n",
    "            where i is the month's index in MONTHS.\n",
    "        countries: keys of POP_COUNT to plot.\n",
    "        social_distancing: if truthy, r is replaced by sqrt(r).\n",
    "\n",
    "    Returns:\n",
    "        A matplotlib Figure, a plotly Figure or an altair Chart.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if month is not one of MONTHS or plot_type is not recognized.\n",
    "    \"\"\"\n",
    "    if month not in MONTHS:\n",
    "        raise ValueError(\"A month must be selected\")\n",
    "    m = MONTHS.index(month)\n",
    "    start_day = 30 * m\n",
    "    final_day = 30 * (m + 1)\n",
    "    x = np.arange(start_day, final_day + 1)\n",
    "    if social_distancing:\n",
    "        r = sqrt(r)\n",
    "    df = pd.DataFrame({\"day\": x})\n",
    "    for country in countries:\n",
    "        df[country] = x ** r * (POP_COUNT[country] + 1)\n",
    "\n",
    "    title = \"Outbreak in \" + month\n",
    "    if plot_type == \"Matplotlib\":\n",
    "        # Create the Figure directly rather than via pyplot: pyplot keeps every\n",
    "        # figure it creates alive in a global registry, so a long-running app\n",
    "        # would accumulate one figure per request.\n",
    "        fig = Figure()\n",
    "        ax = fig.subplots()\n",
    "        ax.plot(df[\"day\"], df[countries].to_numpy())\n",
    "        ax.set_title(title)\n",
    "        ax.set_ylabel(\"Cases\")\n",
    "        ax.set_xlabel(\"Days since Day 0\")\n",
    "        ax.legend(countries)\n",
    "        return fig\n",
    "    elif plot_type == \"Plotly\":\n",
    "        fig = px.line(df, x=\"day\", y=countries)\n",
    "        # x is the day column and y the case counts, so label the axes to match.\n",
    "        fig.update_layout(\n",
    "            title=title,\n",
    "            xaxis_title=\"Days since Day 0\",\n",
    "            yaxis_title=\"Cases\",\n",
    "        )\n",
    "        return fig\n",
    "    elif plot_type == \"Altair\":\n",
    "        # Altair wants long-form data: one row per (day, country) pair.\n",
    "        long_df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n",
    "        fig = (\n",
    "            altair.Chart(long_df)\n",
    "            .mark_line()\n",
    "            .encode(\n",
    "                x=altair.X(\"day\", title=\"Days since Day 0\"),\n",
    "                y=altair.Y(\"value\", title=\"Cases\"),\n",
    "                color=\"country\",\n",
    "            )\n",
    "            .properties(title=title)\n",
    "        )\n",
    "        return fig\n",
    "    else:\n",
    "        raise ValueError(\"A plot type must be selected\")\n",
    "\n",
    "\n",
    "inputs = [\n",
    "    gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n",
    "    gr.Slider(1, 4, 3.2, label=\"R\"),\n",
    "    gr.Dropdown(MONTHS, label=\"Month\"),\n",
    "    gr.CheckboxGroup(list(POP_COUNT), label=\"Countries\", value=[\"USA\", \"Canada\"]),\n",
    "    gr.Checkbox(label=\"Social Distancing?\"),\n",
    "]\n",
    "outputs = gr.Plot()\n",
    "\n",
    "demo = gr.Interface(\n",
    "    fn=outbreak,\n",
    "    inputs=inputs,\n",
    "    outputs=outputs,\n",
    "    examples=[\n",
    "        [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n",
    "        [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n",
    "        [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n",
    "    ],\n",
    "    cache_examples=True,\n",
    ")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}