File size: 9,372 Bytes
f688422
d21cd8b
41f3d00
 
f688422
84bdc0f
711bc31
f688422
d21cd8b
 
 
 
 
711bc31
d21cd8b
 
711bc31
d21cd8b
 
 
711bc31
d21cd8b
 
 
 
 
 
bc61879
 
41f3d00
f688422
 
711bc31
f688422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711bc31
f688422
 
 
 
 
33d270b
 
 
 
f688422
 
33d270b
 
f688422
 
 
 
711bc31
 
f688422
 
 
 
 
711bc31
f688422
 
 
 
d21cd8b
 
e441653
711bc31
84bdc0f
105935e
d21cd8b
105935e
 
 
 
 
 
 
 
 
 
 
 
 
f688422
 
 
 
 
 
 
711bc31
 
 
 
 
 
 
 
 
f688422
 
 
 
711bc31
 
 
 
 
 
 
 
f688422
 
711bc31
 
 
 
 
 
f688422
711bc31
 
 
 
 
 
 
f688422
 
 
 
 
 
 
d21cd8b
 
 
 
 
711bc31
d21cd8b
 
 
 
 
 
711bc31
d21cd8b
 
 
 
105935e
 
bc61879
41f3d00
 
d21cd8b
 
 
 
 
711bc31
d21cd8b
 
bc61879
f688422
711bc31
 
 
 
 
 
 
d21cd8b
 
 
 
 
 
711bc31
d21cd8b
711bc31
d21cd8b
 
711bc31
d21cd8b
 
 
 
711bc31
d21cd8b
 
 
711bc31
d21cd8b
 
 
 
 
 
711bc31
d21cd8b
 
 
 
 
 
 
 
f688422
711bc31
d21cd8b
 
 
 
 
 
f688422
d21cd8b
 
711bc31
d21cd8b
711bc31
 
d21cd8b
 
41f3d00
 
d21cd8b
41f3d00
d21cd8b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import gradio as gr
from typing import TypedDict, List, Optional
import os
import pandas as pd

from climateqa.engine.talk_to_data.main import ask_drias
from climateqa.engine.talk_to_data.drias.config import DRIAS_MODELS, DRIAS_UI_TEXT

class DriasUIElements(TypedDict):
    """Handles to every Gradio component that makes up the DRIAS tab.

    Built by ``create_drias_ui`` and consumed by ``setup_drias_events`` to
    attach event handlers to the components.
    """

    tab: gr.Tab  # the "France - Talk to DRIAS" tab container
    details_accordion: gr.Accordion  # "How to use?" help panel
    examples_hidden: gr.Textbox  # hidden textbox that triggers the query pipeline
    examples: gr.Examples  # clickable example questions
    image_examples: gr.Row  # row of example visualization images
    drias_direct_question: gr.Textbox  # free-text question input
    result_text: gr.Textbox  # status / "no result" text
    table_names_display: gr.Radio  # selector between the figures produced by a query
    query_accordion: gr.Accordion  # wraps the SQL query display
    drias_sql_query: gr.Textbox  # SQL query of the selected figure
    chart_accordion: gr.Accordion  # wraps the model selector and the plot
    plot_information: gr.Markdown  # description of the selected plot
    model_selection: gr.Dropdown  # climate-model filter ("ALL" or one model)
    drias_display: gr.Plot  # chart of the selected figure
    table_accordion: gr.Accordion  # wraps the data table
    drias_table: gr.DataFrame  # data behind the selected figure


async def ask_drias_query(query: str, index_state: int, user_id: str):
    """Forward a DRIAS question to ``ask_drias`` and hand back its raw result.

    Thin async adapter used as a Gradio event handler. The caller maps the
    result onto 11 output components, so ``ask_drias`` presumably returns an
    11-tuple (TODO confirm against ``ask_drias``).
    """
    return await ask_drias(query, index_state, user_id)


def show_results(sql_queries_state, dataframes_state, plots_state, table_names):
    """Toggle the result components depending on whether a query produced output.

    Returns five ``gr.update`` objects, in order, for:
    the "no result" text box, the SQL-query accordion, the data-table accordion,
    the chart accordion and the figure-selection radio.

    If ANY of the result lists, or the list of table names used to populate the
    radio, is empty (or None), only the "no result" text is shown. Otherwise the
    result panels are revealed and the radio is filled with ``table_names``,
    with the first entry preselected.
    """
    # ``table_names`` is checked too: the success branch indexes
    # ``table_names[0]``, which would raise if it were empty or None.
    if not (sql_queries_state and dataframes_state and plots_state and table_names):
        # Something is missing: show "No result" and hide every result panel.
        return (
            gr.update(visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )

    # Show the result panels and populate the figure selector.
    return (
        gr.update(visible=False),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=table_names, value=table_names[0], visible=True),
    )


def filter_by_model(dataframes, figures, index_state, model_selection):
    """Filter the currently selected dataframe by climate model and re-plot it.

    Args:
        dataframes: Per-figure dataframes produced by the last query.
        figures: Per-figure callables; ``figures[i](df)`` builds the chart for
            ``dataframes[i]``.
        index_state: Index of the currently selected figure.
        model_selection: Model name to keep, or ``"ALL"`` for no filtering.

    Returns:
        ``(df, figure)``. ``df`` is the (possibly filtered) dataframe. ``figure``
        is None when there is nothing to plot. This covers three cases: no query
        has produced data yet, the dataframe is empty, or the filter removed
        every row.
    """
    try:
        df = dataframes[index_state]
    except (IndexError, TypeError):
        # The dropdown can change before any question has been asked, while
        # the results state is still empty (or the index is unset).
        return pd.DataFrame(), None

    if df.empty:
        return df, None

    # Only dataframes that actually carry a "model" column can be filtered.
    if "model" in df.columns and model_selection != "ALL":
        df = df[df["model"] == model_selection]
        if df.empty:
            return df, None

    return df, figures[index_state](df)


def on_table_click(selected_label, table_names, sql_queries, dataframes, plot_informations, plots):
    """Return everything needed to display the figure named ``selected_label``.

    Looks up the label's position in ``table_names`` and returns, for that
    position: ``(sql_query, dataframe, freshly built figure, plot info, position)``.
    The returned position becomes the new selected-figure index.

    Raises:
        ValueError: if ``selected_label`` is not one of ``table_names``.
    """
    position = table_names.index(selected_label)
    data = dataframes[position]
    chart = plots[position](data)
    return sql_queries[position], data, chart, plot_informations[position], position


def create_drias_ui() -> DriasUIElements:
    """Create and return all UI elements for the DRIAS tab.

    Builds the "France - Talk to DRIAS" tab, which contains a help accordion,
    example questions, the direct-question box, example images and (initially
    hidden) result panels. Handles to the components are returned as a
    ``DriasUIElements`` dict so event handlers can be attached to them
    separately.

    Gradio lays components out in the order they are created inside each
    ``with`` context, so the statements below must not be reordered.
    """
    # NOTE(review): elem_ids "tab-vanna", "vanna-plot" and "vanna-table" reuse
    # "vanna" naming, presumably to share CSS with another tab — confirm.
    with gr.Tab("France - Talk to DRIAS", elem_id="tab-vanna", id=6) as tab:
        # Collapsible help panel; the event wiring closes it once a question is asked.
        with gr.Accordion(label="❓ How to use?", elem_id="details") as details_accordion:
            gr.Markdown(DRIAS_UI_TEXT)

        # Add examples for common questions.
        # The hidden textbox is the entry point of the query pipeline: clicking
        # an example writes the question into it (inputs/outputs below).
        examples_hidden = gr.Textbox(visible=False, elem_id="drias-examples-hidden")
        examples = gr.Examples(
            examples=[
                ["What will the temperature be like in Paris?"],
                ["What will be the total rainfall in France in 2030?"],
                ["How frequent will extreme events be in Lyon?"],
                ["Comment va évoluer la température en France entre 2030 et 2050 ?"]
            ],
            label="Example Questions",
            inputs=[examples_hidden],
            outputs=[examples_hidden],
        )

        # Free-text question box.
        with gr.Row():
            drias_direct_question = gr.Textbox(
                label="Direct Question",
                placeholder="You can write direct question here",
                elem_id="direct-question",
                interactive=True,
            )


        # Static gallery of sample charts, shown until the first question runs.
        # NOTE(review): the image paths are relative, so they presumably resolve
        # against the process working directory — confirm the launch directory.
        with gr.Row(visible=True, elem_id="example-img-container") as image_examples:
            gr.Markdown("### Examples of possible visualizations")

            with gr.Row():
                gr.Image("./front/assets/talk_to_drias_winter_temp_paris_example.png", label="Evolution of Mean Winter Temperature in Paris", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_drias_annual_temperature_france_example.png", label="Mean Annual Temperature in 2030 in France", elem_classes=["example-img"])
                gr.Image("./front/assets/talk_to_drias_frequency_remarkable_precipitation_lyon_example.png", label="Frequency of Remarkable Daily Precipitation in Lyon", elem_classes=["example-img"])

        # Status / "no result" text; its visibility is toggled after each query.
        result_text = gr.Textbox(
            label="", elem_id="no-result-label", interactive=False, visible=True
        )

        with gr.Row():
            # Selector between the figures a query produced; starts hidden and empty.
            table_names_display = gr.Radio(
                choices=[], 
                label="Relevant figures created",
                interactive=True,
                elem_id="table-names",
                visible=False
            )

            # SQL behind the selected figure; hidden until results exist.
            with gr.Accordion(label="SQL Query Used", visible=False) as query_accordion:
                drias_sql_query = gr.Textbox(
                    label="", elem_id="sql-query", interactive=False
                )


        # Chart panel: model filter, plot description and the plot itself.
        with gr.Accordion(label="Chart", visible=False) as chart_accordion:
            with gr.Row():
                # "ALL" disables model filtering.
                model_selection = gr.Dropdown(
                    label="Model", choices=DRIAS_MODELS, value="ALL", interactive=True
                )
                with gr.Accordion(label="Informations about the plot", open=False):
                    plot_information = gr.Markdown(value = "")

            drias_display = gr.Plot(elem_id="vanna-plot")

        # Raw data behind the selected figure; collapsed and hidden by default.
        with gr.Accordion(
            label="Data used", open=False, visible=False
        ) as table_accordion:
            drias_table = gr.DataFrame([], elem_id="vanna-table")

        # Returning from inside the Tab context is fine: the `with` still exits normally.
        return DriasUIElements(
            tab=tab,
            details_accordion=details_accordion,
            examples_hidden=examples_hidden,
            examples=examples,
            image_examples=image_examples,
            drias_direct_question=drias_direct_question,
            result_text=result_text,
            table_names_display=table_names_display,
            query_accordion=query_accordion,
            drias_sql_query=drias_sql_query,
            chart_accordion=chart_accordion,
            plot_information=plot_information,
            model_selection=model_selection,
            drias_display=drias_display,
            table_accordion=table_accordion,
            drias_table=drias_table,
        )



def setup_drias_events(ui_elements: DriasUIElements, share_client=None, user_id=None) -> None:
    """Set up all event handlers for the DRIAS tab.

    Creates ``gr.State`` holders for the query results and wires these listeners:

    * direct-question submit: copy the text into the hidden examples textbox;
    * hidden examples textbox change: collapse the help panel, mirror the
      question into the question box, hide the example images, run
      ``ask_drias_query`` and then ``show_results``;
    * model dropdown change: ``filter_by_model``;
    * figure radio change: ``on_table_click``.

    Listener registration and ``.then`` chain order are significant; keep them.

    Args:
        ui_elements: Components returned by ``create_drias_ui``.
        share_client: Accepted for interface compatibility; not used here.
        user_id: Initial value stored in a ``gr.State`` and passed to
            ``ask_drias_query``.
    """
    # Create state variables
    sql_queries_state = gr.State([])
    dataframes_state = gr.State([])
    plots_state = gr.State([])
    plot_informations_state = gr.State([])
    index_state = gr.State(0)
    table_names_list = gr.State([])
    # Rebinds the parameter: from here on `user_id` is the State component.
    user_id = gr.State(user_id)

    # Handle direct question submission - trigger the same workflow by setting examples_hidden
    # NOTE(review): the pipeline runs on examples_hidden.change, so re-submitting
    # the exact same text may not re-run the query — confirm intended.
    ui_elements["drias_direct_question"].submit(
        lambda x: gr.update(value=x),
        inputs=[ui_elements["drias_direct_question"]],
        outputs=[ui_elements["examples_hidden"]],
    )

    # Handle example selection
    # Each `.then` step runs after the previous one has finished.
    ui_elements["examples_hidden"].change(
        lambda x: (gr.Accordion(open=False), gr.Textbox(value=x)),
        inputs=[ui_elements["examples_hidden"]],
        outputs=[ui_elements["details_accordion"], ui_elements["drias_direct_question"]]
    ).then(
        # Hide the example images once a question is being answered.
        lambda : gr.update(visible=False),
        inputs=None,
        outputs=ui_elements["image_examples"]
    ).then(
        # Run the query and store every per-figure result in state.
        ask_drias_query,
        inputs=[ui_elements["examples_hidden"], index_state, user_id],
        outputs=[
            ui_elements["drias_sql_query"],
            ui_elements["drias_table"],
            ui_elements["drias_display"],
            ui_elements["plot_information"],
            sql_queries_state,
            dataframes_state,
            plots_state,
            plot_informations_state,
            index_state,
            table_names_list,
            ui_elements["result_text"],
        ],
    ).then(
        # Reveal the result panels (or the "no result" text) and fill the figure selector.
        show_results,
        inputs=[sql_queries_state, dataframes_state, plots_state, table_names_list],
        outputs=[
            ui_elements["result_text"],
            ui_elements["query_accordion"],
            ui_elements["table_accordion"],
            ui_elements["chart_accordion"],
            ui_elements["table_names_display"],
        ],
    )


    # Handle model selection change
    ui_elements["model_selection"].change(
        filter_by_model,
        inputs=[dataframes_state, plots_state, index_state, ui_elements["model_selection"]],
        outputs=[ui_elements["drias_table"], ui_elements["drias_display"]],
    )


    # Handle table selection
    # Presumably this also fires when show_results sets the radio's value.
    # NOTE(review): on_table_click shows the unfiltered dataframe, so after a
    # figure switch the model dropdown may not match the displayed data.
    ui_elements["table_names_display"].change(
        fn=on_table_click,
        inputs=[ui_elements["table_names_display"], table_names_list, sql_queries_state, dataframes_state, plot_informations_state, plots_state],
        outputs=[ui_elements["drias_sql_query"], ui_elements["drias_table"], ui_elements["drias_display"], ui_elements["plot_information"], index_state],
    )

def create_drias_tab(share_client=None, user_id=None):
    """Build the DRIAS tab and attach its event handlers in one call.

    ``share_client`` and ``user_id`` are forwarded unchanged to
    ``setup_drias_events``. Presumably called while a ``gr.Blocks`` context is
    active, since the listeners are registered on it. Returns None.
    """
    setup_drias_events(
        create_drias_ui(),
        share_client=share_client,
        user_id=user_id,
    )