from climateqa.engine.embeddings import get_embeddings_function
from sentence_transformers import CrossEncoder
import gradio as gr
from gradio_modal import Modal
import pandas as pd
import numpy as np
import os
import time
import re
import json
from gradio import ChatMessage
from io import BytesIO
import base64
from datetime import datetime
from azure.storage.fileshare import ShareServiceClient
from utils import create_user_id
# ClimateQ&A imports
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.reranker import get_reranker
from climateqa.sample_questions import QUESTIONS
from climateqa.constants import POSSIBLE_REPORTS
from climateqa.utils import get_image_from_azure_blob_storage
from climateqa.engine.graph import make_graph_agent
from climateqa.engine.chains.retrieve_papers import find_papers
from front.utils import serialize_docs, process_figures
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer, handle_retrieved_owid_graphs
# Load environment variables in local mode
try:
from dotenv import load_dotenv
load_dotenv()
except Exception:
    pass
# Set up Gradio Theme
theme = gr.themes.Base(
primary_hue="blue",
secondary_hue="red",
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
- **Relevant content sources**: You can choose to search for figures, papers, or graphs that can be relevant for your question.
⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
What do you want to learn?
"""
system_template = {
"role": "system",
"content": init_prompt,
}
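# Azure File Share client used to log questions and feedback.
# Storage account keys are base64-encoded (normally 88 characters); a key read
# without its trailing "==" padding comes out at 86 characters and is re-padded below.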
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
account_key += "=="
credential = {
"account_key": account_key,
"account_name": os.environ["BLOB_ACCOUNT_NAME"],
}
account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)
user_id = create_user_id()
CITATION_LABEL = "BibTeX citation for ClimateQ&A"
CITATION_TEXT = r"""@misc{climateqa,
    author={Théo Alves Da Costa and Timothée Bohe},
    title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
    year={2024},
    howpublished={\url{https://climateqa.com}},
}
@software{climateqa_software,
    author = {Théo Alves Da Costa and Timothée Bohe},
    publisher = {ClimateQ&A},
    title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
}
"""
# Create vectorstore and retriever
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
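# The agent built by make_graph_agent is a LangGraph pipeline: judging by the
# event handling below, it categorizes intent, transforms the query, retrieves
# and reranks passages from the report and OWID-graph vectorstores, then streams
# an answer through one of the answer_* nodes.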
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal and return its new visibility state."""
    new_config_visibility_status = not config_open
    return gr.update(visible=new_config_visibility_status), new_config_visibility_status
async def chat(query, history, audience, sources, reports, relevant_content_sources, search_only):
"""Process a chat query and return response with relevant sources and visualizations.
Args:
query (str): The user's question
history (list): Chat message history
audience (str): Target audience type
sources (list): Knowledge base sources to search
reports (list): Specific reports to search within sources
relevant_content_sources (list): Types of content to retrieve (figures, papers, etc)
search_only (bool): Whether to only search without generating answer
Yields:
tuple: Contains:
- history: Updated chat history
- docs_html: HTML of retrieved documents
- output_query: Processed query
- output_language: Detected language
- related_contents: Related content
- graphs_html: HTML of relevant graphs
"""
# Log incoming question
date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f">> NEW QUESTION ({date_now}) : {query}")
audience_prompt = init_audience(audience)
sources = sources or ["IPCC", "IPBES", "IPOS"]
reports = reports or []
# Prepare inputs for agent
inputs = {
"user_input": query,
"audience": audience_prompt,
"sources_input": sources,
"relevant_content_sources": relevant_content_sources,
"search_only": search_only
}
# Get streaming events from agent
result = agent.astream_events(inputs, version="v1")
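    # astream_events (v1) yields dicts shaped roughly like:
    #   {"event": "on_chain_end", "name": "retrieve_documents",
    #    "data": {...}, "metadata": {"langgraph_node": "..."}}
    # The handlers below dispatch on the "event"/"name" pair and the emitting node.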
# Initialize state variables
docs = []
used_figures = []
related_contents = []
docs_html = ""
output_query = ""
output_language = ""
output_keywords = ""
start_streaming = False
graphs_html = ""
figures = '<div class="figures-container"><p></p> </div>'
used_documents = []
answer_message_content = ""
# Define processing steps
steps_display = {
"categorize_intent": ("🔄️ Analyzing user message", True),
"transform_query": ("🔄️ Thinking step by step to answer the question", True),
"retrieve_documents": ("🔄️ Searching in the knowledge base", False),
}
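    # Maps node name -> (status message shown in the chat, whether to surface the
    # node's output). The boolean is unpacked below but not currently acted upon.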
try:
# Process streaming events
async for event in result:
if "langgraph_node" in event["metadata"]:
node = event["metadata"]["langgraph_node"]
# Handle document retrieval
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents":
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(
event, history, used_documents
)
# Handle intent categorization
elif (event["event"] == "on_chain_end" and
node == "categorize_intent" and
event["name"] == "_write"):
intent = event["data"]["output"]["intent"]
output_language = event["data"]["output"].get("language", "English")
history[-1].content = f"Language identified: {output_language}\nIntent identified: {intent}"
# Handle processing steps display
elif event["name"] in steps_display and event["event"] == "on_chain_start":
event_description, display_output = steps_display[node]
                    if (not hasattr(history[-1], 'metadata') or
                            history[-1].metadata.get("title") != event_description):
history.append(ChatMessage(
role="assistant",
content="",
metadata={'title': event_description}
))
# Handle answer streaming
elif (event["name"] != "transform_query" and
event["event"] == "on_chat_model_stream" and
node in ["answer_rag", "answer_search", "answer_chitchat"]):
history, start_streaming, answer_message_content = stream_answer(
history, event, start_streaming, answer_message_content
)
# Handle graph retrieval
elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
graphs_html = handle_retrieved_owid_graphs(event, graphs_html)
# Handle query transformation
if event["name"] == "transform_query" and event["event"] == "on_chain_end":
if hasattr(history[-1], "content"):
sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
                        history[-1].content += "Decomposed the question into sub-questions:\n\n - " + "\n - ".join(sub_questions)
yield history, docs_html, output_query, output_language, related_contents, graphs_html #,output_query,output_keywords
    except Exception as e:
        # Avoid referencing `event` here: it is undefined if the stream fails
        # before the first event arrives.
        print(f"Event processing failed: {e}")
        raise gr.Error(str(e))
try:
# Log interaction to Azure if not in local environment
if os.getenv("GRADIO_ENV") != "local":
timestamp = str(datetime.now().timestamp())
prompt = history[1]["content"]
logs = {
"user_id": str(user_id),
"prompt": prompt,
"query": prompt,
"question": output_query,
"sources": sources,
"docs": serialize_docs(docs),
"answer": history[-1].content,
"time": timestamp,
}
log_on_azure(f"{timestamp}.json", logs, share_client)
    except Exception as e:
        print(f"Error logging to Azure File Share: {e}")
        error_msg = f"ClimateQ&A Error: {str(e)[:100]} - The error has been logged; please try another question, and contact us if it persists :)"
        raise gr.Error(error_msg)
yield history, docs_html, output_query, output_language, related_contents, graphs_html
def save_feedback(feed: str, user_id):
    """Store user feedback on the Azure file share, skipping near-empty submissions."""
    if len(feed) > 1:
timestamp = str(datetime.now().timestamp())
file = user_id + timestamp + ".json"
logs = {
"user_id": user_id,
"feedback": feed,
"time": timestamp,
}
log_on_azure(file, logs, share_client)
return "Feedback submitted, thank you!"
def log_on_azure(file, logs, share_client):
    """Serialize `logs` to JSON and upload it as `file` on the Azure file share."""
    logs = json.dumps(logs)
    file_client = share_client.get_file_client(file)
    file_client.upload_file(logs)
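# Resulting layout on the share: each interaction is stored as "<timestamp>.json"
# and each feedback entry as "<user_id><timestamp>.json".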
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------
def vote(data: gr.LikeData):
if data.liked:
print(data.value)
else:
print(data)
def save_graph(saved_graphs_state, embedding, category):
print(f"\nCategory:\n{saved_graphs_state}\n")
if category not in saved_graphs_state:
saved_graphs_state[category] = []
if embedding not in saved_graphs_state[category]:
saved_graphs_state[category].append(embedding)
return saved_graphs_state, gr.Button("Graph Saved")
with gr.Blocks(title="Climate Q&A", css_paths=os.path.join(os.getcwd(), "style.css"), theme=theme, elem_id="main-component") as demo:
# State variables
chat_completed_state = gr.State(0)
current_graphs = gr.State([])
saved_graphs = gr.State({})
config_open = gr.State(False)
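    # Cross-component session state: current_graphs feeds the Graphs tab and the
    # tab counters, saved_graphs backs save_graph, and config_open tracks whether
    # the configuration modal is open.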
with gr.Tab("ClimateQ&A"):
with gr.Row(elem_id="chatbot-row"):
# Left column - Chat interface
with gr.Column(scale=2):
chatbot = gr.Chatbot(
value=[ChatMessage(role="assistant", content=init_prompt)],
type="messages",
show_copy_button=True,
show_label=False,
elem_id="chatbot",
layout="panel",
avatar_images=(None, "https://i.ibb.co/YNyd5W2/logo4.png"),
max_height="80vh",
height="100vh"
)
with gr.Row(elem_id="input-message"):
textbox = gr.Textbox(
placeholder="Ask me anything here!",
show_label=False,
scale=7,
lines=1,
interactive=True,
elem_id="input-textbox"
)
config_button = gr.Button("", elem_id="config-button")
# Right column - Content panels
with gr.Column(scale=2, variant="panel", elem_id="right-panel"):
with gr.Tabs(elem_id="right_panel_tab") as tabs:
# Examples tab
with gr.TabItem("Examples", elem_id="tab-examples", id=0):
examples_hidden = gr.Textbox(visible=False)
first_key = list(QUESTIONS.keys())[0]
dropdown_samples = gr.Dropdown(
                            choices=list(QUESTIONS.keys()),
value=first_key,
interactive=True,
label="Select a category of sample questions",
elem_id="dropdown-samples"
)
samples = []
for i, key in enumerate(QUESTIONS.keys()):
examples_visible = (i == 0)
with gr.Row(visible=examples_visible) as group_examples:
examples_questions = gr.Examples(
examples=QUESTIONS[key],
inputs=[examples_hidden],
examples_per_page=8,
run_on_click=False,
elem_id=f"examples{i}",
api_name=f"examples{i}"
)
samples.append(group_examples)
# Sources tab
with gr.Tab("Sources", elem_id="tab-sources", id=1) as tab_sources:
sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
# Recommended content tab
with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=2) as tab_recommended_content:
with gr.Tabs(elem_id="group-subtabs") as tabs_recommended_content:
# Figures subtab
with gr.Tab("Figures", elem_id="tab-figures", id=3) as tab_figures:
sources_raw = gr.State()
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
gallery_component = gr.Gallery(
object_fit='scale-down',
elem_id="gallery-component",
height="80vh"
)
show_full_size_figures = gr.Button(
"Show figures in full size",
elem_id="show-figures",
interactive=True
)
show_full_size_figures.click(
lambda: Modal(visible=True),
None,
figure_modal
)
figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
# Papers subtab
with gr.Tab("Papers", elem_id="tab-citations", id=4) as tab_papers:
with gr.Accordion(
visible=True,
elem_id="papers-summary-popup",
label="See summary of relevant papers",
open=False
) as summary_popup:
papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")
with gr.Accordion(
visible=True,
elem_id="papers-relevant-popup",
label="See relevant papers",
open=False
) as relevant_popup:
papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")
btn_citations_network = gr.Button("Explore papers citations network")
with Modal(visible=False) as papers_modal:
citations_network = gr.HTML(
"<h3>Citations Network Graph</h3>",
visible=True,
elem_id="papers-citations-network"
)
btn_citations_network.click(
lambda: Modal(visible=True),
None,
papers_modal
)
# Graphs subtab
with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
graphs_container = gr.HTML(
"<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",
elem_id="graphs-container"
)
current_graphs.change(
lambda x: x,
inputs=[current_graphs],
outputs=[graphs_container]
)
# Configuration modal
with Modal(visible=False, elem_id="modal-config") as config_modal:
gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
dropdown_sources = gr.CheckboxGroup(
choices=["IPCC", "IPBES", "IPOS"],
label="Select source (by default search in all sources)",
value=["IPCC"],
interactive=True
)
dropdown_reports = gr.Dropdown(
choices=POSSIBLE_REPORTS,
label="Or select specific reports",
multiselect=True,
value=None,
interactive=True
)
dropdown_external_sources = gr.CheckboxGroup(
choices=["IPCC figures", "OpenAlex", "OurWorldInData"],
label="Select database to search for relevant content",
value=["IPCC figures"],
interactive=True
)
search_only = gr.Checkbox(
label="Search only for recommended content without chating",
value=False,
interactive=True,
elem_id="checkbox-chat"
)
dropdown_audience = gr.Dropdown(
choices=["Children", "General public", "Experts"],
label="Select audience",
value="Experts",
interactive=True
)
after = gr.Slider(
minimum=1950,
maximum=2023,
step=1,
value=1960,
label="Publication date",
show_label=True,
interactive=True,
elem_id="date-papers",
visible=False
)
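        # Publication-year filter for the OpenAlex paper search; only made visible
        # when "OpenAlex" is ticked (see the dropdown_external_sources.change handler below).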
output_query = gr.Textbox(
label="Query used for retrieval",
show_label=True,
elem_id="reformulated-query",
lines=2,
interactive=False,
visible=False
)
output_language = gr.Textbox(
label="Language",
show_label=True,
elem_id="language",
lines=1,
interactive=False,
visible=False
)
dropdown_external_sources.change(
lambda x: gr.update(visible="OpenAlex" in x),
inputs=[dropdown_external_sources],
outputs=[after]
)
close_config_modal = gr.Button("Validate and Close", elem_id="close-config-modal")
close_config_modal.click(
fn=update_config_modal_visibility,
inputs=[config_open],
outputs=[config_modal, config_open]
)
config_button.click(
fn=update_config_modal_visibility,
inputs=[config_open],
outputs=[config_modal, config_open]
)
#---------------------------------------------------------------------------------------
# OTHER TABS
#---------------------------------------------------------------------------------------
with gr.Tab("About",elem_classes = "max-height other-tabs"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown(
"""
### More info
- See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
- Share your feedback via this [form](https://forms.office.com/e/1Yzgxm6jbp)
### Citation
"""
)
                with gr.Accordion(CITATION_LABEL, elem_id="citation", open=False):
                    # Display the BibTeX citation with a copy button
gr.Textbox(
value=CITATION_TEXT,
label="",
interactive=False,
show_copy_button=True,
lines=len(CITATION_TEXT.split('\n')),
)
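    # start_chat appends the user message and switches the right panel:
    # selected=1 opens the Sources tab while the answer streams, selected=2
    # jumps straight to Recommended content in search-only mode (tab ids above).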
    def start_chat(query, history, search_only):
        history = history + [ChatMessage(role="user", content=query)]
        if not search_only:
            return (gr.update(interactive=False), gr.update(selected=1), history)
        else:
            return (gr.update(interactive=False), gr.update(selected=2), history)
    def finish_chat():
        return gr.update(interactive=True, value="")
# Initialize visibility states
summary_visible = False
relevant_visible = False
# Functions to toggle visibility
def toggle_summary_visibility():
global summary_visible
summary_visible = not summary_visible
return gr.update(visible=summary_visible)
def toggle_relevant_visibility():
global relevant_visible
relevant_visible = not relevant_visible
return gr.update(visible=relevant_visible)
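    # Note: these module-level flags are shared across all user sessions in Gradio;
    # the toggles are only referenced by the commented-out buttons at the end of this file.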
def change_completion_status(current_state):
current_state = 1 - current_state
return current_state
def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
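        # Each source/figure/paper card is rendered with an <h2> heading and each
        # OWID graph as an <iframe>, so counting those tags gives the per-tab item counts.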
sources_number = sources_textbox.count("<h2>")
figures_number = figures_cards.count("<h2>")
graphs_number = current_graphs.count("<iframe")
papers_number = papers_html.count("<h2>")
sources_notif_label = f"Sources ({sources_number})"
figures_notif_label = f"Figures ({figures_number})"
graphs_notif_label = f"Graphs ({graphs_number})"
papers_notif_label = f"Papers ({papers_number})"
recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
        return gr.update(label=recommended_content_notif_label), gr.update(label=sources_notif_label), gr.update(label=figures_notif_label), gr.update(label=graphs_notif_label), gr.update(label=papers_notif_label)
    (textbox
        .submit(start_chat, [textbox, chatbot, search_only], [textbox, tabs, chatbot], queue=False, api_name="start_chat_textbox")
        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, sources_raw, current_graphs], concurrency_limit=8, api_name="chat_textbox")
        .then(finish_chat, None, [textbox], api_name="finish_chat_textbox")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_sources, tab_figures, tab_graphs, tab_papers])
    )
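    # The examples flow mirrors the textbox flow above, driven by the hidden
    # textbox that clicked gr.Examples questions are written into.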
    (examples_hidden
        .change(start_chat, [examples_hidden, chatbot, search_only], [textbox, tabs, chatbot], queue=False, api_name="start_chat_examples")
        .then(chat, [examples_hidden, chatbot, dropdown_audience, dropdown_sources, dropdown_reports, dropdown_external_sources, search_only], [chatbot, sources_textbox, output_query, output_language, sources_raw, current_graphs], concurrency_limit=8, api_name="chat_examples")
        .then(finish_chat, None, [textbox], api_name="finish_chat_examples")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs, papers_html], [tab_sources, tab_figures, tab_graphs, tab_papers])
    )
def change_sample_questions(key):
index = list(QUESTIONS.keys()).index(key)
visible_bools = [False] * len(samples)
visible_bools[index] = True
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
# update sources numbers
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )
# other questions examples
    dropdown_samples.change(change_sample_questions, dropdown_samples, samples)
# search for papers
    textbox.submit(find_papers, [textbox, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
    examples_hidden.change(find_papers, [examples_hidden, after, dropdown_external_sources], [papers_html, citations_network, papers_summary])
# btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
# btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
demo.queue()
demo.launch(ssr_mode=False)