File size: 3,973 Bytes
0bd62e5
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: line_plot\n",
    "\n",
    "Choose a dataset from the dropdown to render it as a `gr.LinePlot`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q gradio vega_datasets pandas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "from vega_datasets import data\n",
    "\n",
    "stocks = data.stocks()\n",
    "gapminder = data.gapminder()\n",
    "gapminder = gapminder.loc[\n",
    "    gapminder.country.isin([\"Argentina\", \"Australia\", \"Afghanistan\"])\n",
    "]\n",
    "climate = data.climate()\n",
    "seattle_weather = data.seattle_weather()\n",
    "\n",
    "## Or generate your own fake data, here's an example for stocks:\n",
    "#\n",
    "# import pandas as pd\n",
    "# import random\n",
    "#\n",
    "# stocks = pd.DataFrame(\n",
    "#     {\n",
    "#         \"symbol\": [\n",
    "#             random.choice(\n",
    "#                 [\n",
    "#                     \"MSFT\",\n",
    "#                     \"AAPL\",\n",
    "#                     \"AMZN\",\n",
    "#                     \"IBM\",\n",
    "#                     \"GOOG\",\n",
    "#                 ]\n",
    "#             )\n",
    "#             for _ in range(120)\n",
    "#         ],\n",
    "#         \"date\": [\n",
    "#             pd.Timestamp(year=2000 + i, month=j, day=1)\n",
    "#             for i in range(10)\n",
    "#             for j in range(1, 13)\n",
    "#         ],\n",
    "#         \"price\": [random.randint(10, 200) for _ in range(120)],\n",
    "#     }\n",
    "# )\n",
    "\n",
    "def line_plot_fn(dataset):\n",
    "    \"\"\"Build the line plot for the dataset selected in the dropdown.\n",
    "\n",
    "    Args:\n",
    "        dataset: Name of the dataset to plot; one of \"stocks\", \"climate\",\n",
    "            \"seattle_weather\" or \"gapminder\".\n",
    "\n",
    "    Returns:\n",
    "        A ``gr.LinePlot`` configured for that dataset, used as the update\n",
    "        for the ``plot`` output component.\n",
    "\n",
    "    Raises:\n",
    "        gr.Error: If ``dataset`` is not one of the names above.\n",
    "    \"\"\"\n",
    "    if dataset == \"stocks\":\n",
    "        return gr.LinePlot(\n",
    "            stocks,\n",
    "            x=\"date\",\n",
    "            y=\"price\",\n",
    "            color=\"symbol\",\n",
    "            color_legend_position=\"bottom\",\n",
    "            title=\"Stock Prices\",\n",
    "            tooltip=[\"date\", \"price\", \"symbol\"],\n",
    "            height=300,\n",
    "            width=500,\n",
    "        )\n",
    "    elif dataset == \"climate\":\n",
    "        return gr.LinePlot(\n",
    "            climate,\n",
    "            x=\"DATE\",\n",
    "            y=\"HLY-TEMP-NORMAL\",\n",
    "            y_lim=[250, 500],\n",
    "            title=\"Climate\",\n",
    "            tooltip=[\"DATE\", \"HLY-TEMP-NORMAL\"],\n",
    "            height=300,\n",
    "            width=500,\n",
    "        )\n",
    "    elif dataset == \"seattle_weather\":\n",
    "        return gr.LinePlot(\n",
    "            seattle_weather,\n",
    "            x=\"date\",\n",
    "            y=\"temp_min\",\n",
    "            # Include the plotted y column so hovering shows the temperature,\n",
    "            # as the other datasets' tooltips do.\n",
    "            tooltip=[\"weather\", \"date\", \"temp_min\"],\n",
    "            overlay_point=True,\n",
    "            title=\"Seattle Weather\",\n",
    "            height=300,\n",
    "            width=500,\n",
    "        )\n",
    "    elif dataset == \"gapminder\":\n",
    "        return gr.LinePlot(\n",
    "            gapminder,\n",
    "            x=\"year\",\n",
    "            y=\"life_expect\",\n",
    "            color=\"country\",\n",
    "            title=\"Life expectancy for countries\",\n",
    "            stroke_dash=\"cluster\",\n",
    "            x_lim=[1950, 2010],\n",
    "            tooltip=[\"country\", \"life_expect\"],\n",
    "            stroke_dash_legend_title=\"Country Cluster\",\n",
    "            height=300,\n",
    "            width=500,\n",
    "        )\n",
    "    # The dropdown only offers the names handled above; fail visibly instead\n",
    "    # of silently returning None (which would blank the plot) if the two drift.\n",
    "    raise gr.Error(f\"Unknown dataset: {dataset!r}\")\n",
    "\n",
    "with gr.Blocks() as line_plot:\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            dataset = gr.Dropdown(\n",
    "                choices=[\"stocks\", \"climate\", \"seattle_weather\", \"gapminder\"],\n",
    "                value=\"stocks\",\n",
    "            )\n",
    "        with gr.Column():\n",
    "            plot = gr.LinePlot()\n",
    "    dataset.change(line_plot_fn, inputs=dataset, outputs=plot)\n",
    "    line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    line_plot.launch()\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}