Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -1,151 +1,43 @@ | |
| 1 | 
            -
             | 
| 2 | 
            -
             | 
| 3 | 
            -
            import matplotlib.colors as mpl_colors
         | 
| 4 |  | 
| 5 | 
            -
            import  | 
| 6 | 
            -
            import seaborn as sns
         | 
| 7 | 
            -
            import shinyswatch
         | 
| 8 |  | 
| 9 | 
            -
            from  | 
| 10 |  | 
| 11 | 
            -
             | 
| 12 | 
            -
             | 
| 13 | 
            -
             | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
         | 
| 17 | 
            -
            species: List[str] = df["Species"].unique().tolist()
         | 
| 18 | 
            -
            species.sort()
         | 
| 19 | 
            -
             | 
| 20 | 
            -
            app_ui = ui.page_fillable(
         | 
| 21 | 
            -
                shinyswatch.theme.minty(),
         | 
| 22 | 
            -
                ui.layout_sidebar(
         | 
| 23 | 
            -
                    ui.sidebar(
         | 
| 24 | 
            -
                        # Artwork by @allison_horst
         | 
| 25 | 
            -
                        ui.input_selectize(
         | 
| 26 | 
            -
                            "xvar",
         | 
| 27 | 
            -
                            "X variable",
         | 
| 28 | 
            -
                            numeric_cols,
         | 
| 29 | 
            -
                            selected="Bill Length (mm)",
         | 
| 30 | 
            -
                        ),
         | 
| 31 | 
            -
                        ui.input_selectize(
         | 
| 32 | 
            -
                            "yvar",
         | 
| 33 | 
            -
                            "Y variable",
         | 
| 34 | 
            -
                            numeric_cols,
         | 
| 35 | 
            -
                            selected="Bill Depth (mm)",
         | 
| 36 | 
            -
                        ),
         | 
| 37 | 
            -
                        ui.input_checkbox_group(
         | 
| 38 | 
            -
                            "species", "Filter by species", species, selected=species
         | 
| 39 | 
            -
                        ),
         | 
| 40 | 
            -
                        ui.hr(),
         | 
| 41 | 
            -
                        ui.input_switch("by_species", "Show species", value=True),
         | 
| 42 | 
            -
                        ui.input_switch("show_margins", "Show marginal plots", value=True),
         | 
| 43 | 
            -
                    ),
         | 
| 44 | 
            -
                    ui.output_ui("value_boxes"),
         | 
| 45 | 
            -
                    ui.output_plot("scatter", fill=True),
         | 
| 46 | 
            -
                    ui.help_text(
         | 
| 47 | 
            -
                        "Artwork by ",
         | 
| 48 | 
            -
                        ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
         | 
| 49 | 
            -
                        class_="text-end",
         | 
| 50 | 
            -
                    ),
         | 
| 51 | 
            -
                ),
         | 
| 52 | 
             
            )
         | 
| 53 |  | 
| 54 |  | 
| 55 | 
            -
             | 
| 56 | 
            -
                 | 
| 57 | 
            -
                 | 
| 58 | 
            -
                     | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
                    # Filter the rows so we only include the desired species
         | 
| 64 | 
            -
                    return df[df["Species"].isin(input.species())]
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                @output
         | 
| 67 | 
            -
                @render.plot
         | 
| 68 | 
            -
                def scatter():
         | 
| 69 | 
            -
                    """Generates a plot for Shiny to display to the user"""
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                    # The plotting function to use depends on whether margins are desired
         | 
| 72 | 
            -
                    plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
         | 
| 73 | 
            -
             | 
| 74 | 
            -
                    plotfunc(
         | 
| 75 | 
            -
                        data=filtered_df(),
         | 
| 76 | 
            -
                        x=input.xvar(),
         | 
| 77 | 
            -
                        y=input.yvar(),
         | 
| 78 | 
            -
                        palette=palette,
         | 
| 79 | 
            -
                        hue="Species" if input.by_species() else None,
         | 
| 80 | 
            -
                        hue_order=species,
         | 
| 81 | 
            -
                        legend=False,
         | 
| 82 | 
             
                    )
         | 
| 83 |  | 
| 84 | 
            -
                 | 
| 85 | 
            -
                 | 
| 86 | 
            -
                 | 
| 87 | 
            -
             | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
             | 
| 91 | 
            -
             | 
| 92 | 
            -
             | 
| 93 | 
            -
             | 
| 94 | 
            -
             | 
| 95 | 
            -
             | 
| 96 | 
            -
             | 
| 97 | 
            -
             | 
| 98 | 
            -
             | 
| 99 | 
            -
             | 
| 100 | 
            -
             | 
| 101 | 
            -
             | 
| 102 | 
            -
                        )
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                    if not input.by_species():
         | 
| 105 | 
            -
                        return penguin_value_box(
         | 
| 106 | 
            -
                            "Penguins",
         | 
| 107 | 
            -
                            len(df.index),
         | 
| 108 | 
            -
                            bg_palette["default"],
         | 
| 109 | 
            -
                            # Artwork by @allison_horst
         | 
| 110 | 
            -
                            showcase_img="penguins.png",
         | 
| 111 | 
            -
                        )
         | 
| 112 | 
            -
             | 
| 113 | 
            -
                    value_boxes = [
         | 
| 114 | 
            -
                        penguin_value_box(
         | 
| 115 | 
            -
                            name,
         | 
| 116 | 
            -
                            len(df[df["Species"] == name]),
         | 
| 117 | 
            -
                            bg_palette[name],
         | 
| 118 | 
            -
                            # Artwork by @allison_horst
         | 
| 119 | 
            -
                            showcase_img=f"{name}.png",
         | 
| 120 | 
            -
                        )
         | 
| 121 | 
            -
                        for name in species
         | 
| 122 | 
            -
                        # Only include boxes for _selected_ species
         | 
| 123 | 
            -
                        if name in input.species()
         | 
| 124 | 
            -
                    ]
         | 
| 125 | 
            -
             | 
| 126 | 
            -
                    return ui.layout_column_wrap(*value_boxes, width = 1 / len(value_boxes))
         | 
| 127 | 
            -
             | 
| 128 | 
            -
             | 
| 129 | 
            -
            # "darkorange", "purple", "cyan4"
         | 
| 130 | 
            -
            colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
         | 
| 131 | 
            -
            colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]
         | 
| 132 | 
            -
             | 
| 133 | 
            -
            palette: Dict[str, Tuple[float, float, float]] = {
         | 
| 134 | 
            -
                "Adelie": colors[0],
         | 
| 135 | 
            -
                "Chinstrap": colors[1],
         | 
| 136 | 
            -
                "Gentoo": colors[2],
         | 
| 137 | 
            -
                "default": sns.color_palette()[0],  # type: ignore
         | 
| 138 | 
            -
            }
         | 
| 139 | 
            -
             | 
| 140 | 
            -
            bg_palette = {}
         | 
| 141 | 
            -
            # Use `sns.set_style("whitegrid")` to help find approx alpha value
         | 
| 142 | 
            -
            for name, col in palette.items():
         | 
| 143 | 
            -
                # Adjusted n_colors until `axe` accessibility did not complain about color contrast
         | 
| 144 | 
            -
                bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1])  # type: ignore
         | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
            app = App(
         | 
| 148 | 
            -
                app_ui,
         | 
| 149 | 
            -
                server,
         | 
| 150 | 
            -
                static_assets=str(www_dir),
         | 
| 151 | 
            -
            )
         | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import sys
         | 
|  | |
| 3 |  | 
| 4 | 
            +
            from shiny.express import input, ui
         | 
|  | |
|  | |
| 5 |  | 
| 6 | 
            +
            from all_rag_fns import do_rag
         | 
| 7 |  | 
| 8 | 
            +
            oai_api_key = os.getenv("OPENAI_API_KEY")
         | 
| 9 | 
            +
            ui.page_opts(
         | 
| 10 | 
            +
                title="Use Shiny to Run RAG on the previous R/Gov Talks",
         | 
| 11 | 
            +
                fillable=True,
         | 
| 12 | 
            +
                fillable_mobile=True,
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 13 | 
             
            )
         | 
| 14 |  | 
| 15 |  | 
| 16 | 
            +
            with ui.layout_sidebar():
         | 
| 17 | 
            +
                # Add radio buttons in the sidebar
         | 
| 18 | 
            +
                with ui.sidebar():
         | 
| 19 | 
            +
                    ui.input_radio_buttons(
         | 
| 20 | 
            +
                        "model_choice",
         | 
| 21 | 
            +
                        "Select Model:",
         | 
| 22 | 
            +
                        choices={"gpt-4o-mini": "Cheaper", "gpt-4o": "More Accurate"},
         | 
| 23 | 
            +
                        selected="gpt-4o-mini",
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 24 | 
             
                    )
         | 
| 25 |  | 
| 26 | 
            +
                # Create a chat instance and display it in the main panel
         | 
| 27 | 
            +
                chat = ui.Chat(id="chat")
         | 
| 28 | 
            +
                chat.ui()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            # Define a callback to run when the user submits a message
         | 
| 32 | 
            +
            @chat.on_user_submit
         | 
| 33 | 
            +
            async def _():
         | 
| 34 | 
            +
                user_message = chat.user_input()
         | 
| 35 | 
            +
                response, _ = do_rag(
         | 
| 36 | 
            +
                    user_input=user_message,
         | 
| 37 | 
            +
                    n_results=3,
         | 
| 38 | 
            +
                    stream=True,
         | 
| 39 | 
            +
                    oai_api_key=oai_api_key,
         | 
| 40 | 
            +
                    model_name=input.model_choice(),
         | 
| 41 | 
            +
                )
         | 
| 42 | 
            +
                # Append the response into the chat
         | 
| 43 | 
            +
                await chat.append_message_stream(response)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  |