diff --git a/.gitattributes b/.gitattributes
index e436d7c5d104668b7969e23dc6688b3dafbe9c99..0eb8c2e05739c1d905f2c2a19356d46372635988 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -44,4 +44,3 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
 documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
-data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index 8288a2228a648af2e94d03ef1375299785bfe0c8..810e6d2a5f4099116c3b903da22346a544afee54 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,16 +5,3 @@
 __pycache__/utils.cpython-38.pyc
 notebooks/
 *.pyc
-
-**/.ipynb_checkpoints/
-**/.flashrank_cache/
-
-data/
-sandbox/
-
-climateqa/talk_to_data/database/
-*.db
-
-data_ingestion/
-.vscode
-*old/
diff --git a/README.md b/README.md
index 4bc553e88e65fd5201809ec9ebd12312f96a9816..b1f4ec3bf80bc3b00e8a839e1d0052789970b96a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ emoji: 🌍
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 5.0.2
+sdk_version: 4.19.1
 app_file: app.py
 fullWidth: true
 pinned: false
diff --git a/app.py b/app.py
index f75557f2fa27e84ad3d270dd0fc143f83eb0e3d0..ab849993528f591307fdd3a6b5be50730fc147f4 100644
--- a/app.py
+++ b/app.py
@@ -1,33 +1,44 @@
-# Import necessary libraries
-import os
-import gradio as gr
+from climateqa.engine.embeddings import get_embeddings_function
+embeddings_function = get_embeddings_function()
-from azure.storage.fileshare import ShareServiceClient
+from climateqa.papers.openalex import OpenAlex
+from sentence_transformers import CrossEncoder
-# Import custom modules
-from climateqa.engine.embeddings import get_embeddings_function
-from climateqa.engine.llm import get_llm
-from climateqa.engine.vectorstore import get_pinecone_vectorstore
-from climateqa.engine.reranker import get_reranker
-from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
-from climateqa.engine.chains.retrieve_papers import find_papers
-from climateqa.chat import start_chat, chat_stream, finish_chat
-from climateqa.engine.talk_to_data.main import ask_vanna
-from climateqa.engine.talk_to_data.myVanna import MyVanna
+reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
+oa = OpenAlex()
-from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
-from front.utils import process_figures
-from gradio_modal import Modal
+import gradio as gr
+import pandas as pd
+import numpy as np
+import os
+import time
+import re
+import json
+
+# from gradio_modal import Modal
+from io import BytesIO
+import base64
+
+from datetime import datetime
+from azure.storage.fileshare import ShareServiceClient
 
 from utils import create_user_id
 
-import logging
-logging.basicConfig(level=logging.WARNING)
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppresses INFO and WARNING logs
-logging.getLogger().setLevel(logging.WARNING)
+# ClimateQ&A imports
+from climateqa.engine.llm import get_llm
+from climateqa.engine.rag import make_rag_chain
+from climateqa.engine.vectorstore import get_pinecone_vectorstore
+from climateqa.engine.retriever import ClimateQARetriever
+from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.prompts import audience_prompts
+from climateqa.sample_questions import QUESTIONS
+from climateqa.constants import POSSIBLE_REPORTS
+from climateqa.utils import get_image_from_azure_blob_storage
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.rag import make_rag_papers_chain
 
 # Load environment variables in local mode
 try:
@@ -36,7 +47,6 @@ try:
 except Exception as e:
     pass
 
-
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -44,7 +54,15 @@
     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 
-# Azure Blob Storage credentials
+
+
+init_prompt = ""
+
+system_template = {
+    "role": "system",
+    "content": init_prompt,
+}
+
 account_key = os.environ["BLOB_ACCOUNT_KEY"]
 if len(account_key) == 86:
     account_key += "=="
@@ -63,284 +81,597 @@
 user_id = create_user_id()
 
-# Create vectorstore and retriever
-embeddings_function = get_embeddings_function()
-vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
-vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
-vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
+def parse_output_llm_with_sources(output):
+    # Split the content into a list of text and "[Doc X]" references
+    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
+    parts = []
+    for part in content_parts:
+        if part.startswith("Doc"):
+            subparts = part.split(",")
+            subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
+            subparts = [f"""<a href="#doc{subpart}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{subpart}</sup></span></a>""" for subpart in subparts]
+            parts.append("".join(subparts))
+        else:
+            parts.append(part)
+    content_parts = "".join(parts)
+    return content_parts
+
+# Create vectorstore and retriever
+vectorstore = get_pinecone_vectorstore(embeddings_function)
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
 
-if os.environ["GRADIO_ENV"] == "local":
-    reranker = get_reranker("nano")
-else :
-    reranker = get_reranker("large")
-agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
-agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")  # TODO put back default 0.2
 
-#Vanna object
+def make_pairs(lst):
+    """from a list of even length, make tuple pairs"""
+    return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
 
-vn = MyVanna(config = {"temperature": 0, "api_key": os.getenv('THEO_API_KEY'), 'model': os.getenv('VANNA_MODEL'), 'pc_api_key': os.getenv('VANNA_PINECONE_API_KEY'), 'index_name': os.getenv('VANNA_INDEX_NAME'), "top_k" : 4})
-db_vanna_path = os.path.join(os.getcwd(), "data/drias/drias.db")
-vn.connect_to_sqlite(db_vanna_path)
 
-def ask_vanna_query(query):
-    return ask_vanna(vn, db_vanna_path, query)
+def serialize_docs(docs):
+    new_docs = []
+    for doc in docs:
+        new_doc = {}
+        new_doc["page_content"] = doc.page_content
+        new_doc["metadata"] = doc.metadata
+        new_docs.append(new_doc)
+    return new_docs
 
-async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
-    print("chat cqa - message received")
-    async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
-        yield event
+
+
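+# Main chat entry point: runs the RAG pipeline and streams intermediate state
+# (reformulated question, retrieved sources, answer tokens) back to the Gradio UI.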
"""taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of: + (messages in gradio format, messages in langchain format, source documents)""" + + print(f">> NEW QUESTION : {query}") + + if audience == "Children": + audience_prompt = audience_prompts["children"] + elif audience == "General public": + audience_prompt = audience_prompts["general"] + elif audience == "Experts": + audience_prompt = audience_prompts["experts"] + else: + audience_prompt = audience_prompts["experts"] + + # Prepare default values + if len(sources) == 0: + sources = ["IPCC"] + + if len(reports) == 0: + reports = [] + + retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5) + rag_chain = make_rag_chain(retriever,llm) + + inputs = {"query": query,"audience": audience_prompt} + result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]}) + # result = rag_chain.stream(inputs) + + path_reformulation = "/logs/reformulation/final_output" + path_keywords = "/logs/keywords/final_output" + path_retriever = "/logs/find_documents/final_output" + path_answer = "/logs/answer/streamed_output_str/-" + + docs_html = "" + output_query = "" + output_language = "" + output_keywords = "" + gallery = [] + + try: + async for op in result: + + op = op.ops[0] + + if op['path'] == path_reformulation: # reforulated question + try: + output_language = op['value']["language"] # str + output_query = op["value"]["question"] + except Exception as e: + raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)") + + if op["path"] == path_keywords: + try: + output_keywords = op['value']["keywords"] # str + output_keywords = " AND ".join(output_keywords) + except Exception as e: + pass + + + elif op['path'] == path_retriever: # documents + try: + docs = op['value']['docs'] # List[Document] + docs_html = [] + for i, d in enumerate(docs, 1): + docs_html.append(make_html_source(d, i)) + docs_html = "".join(docs_html) + except TypeError: + print("No documents found") + print("op: ",op) + continue + + elif op['path'] == path_answer: # final answer + new_token = op['value'] # str + # time.sleep(0.01) + previous_answer = history[-1][1] + previous_answer = previous_answer if previous_answer is not None else "" + answer_yet = previous_answer + new_token + answer_yet = parse_output_llm_with_sources(answer_yet) + history[-1] = (query,answer_yet) + + + + else: + continue + + history = [tuple(x) for x in history] + yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords + + except Exception as e: + raise gr.Error(f"{e}") + + + try: + # Log answer on Azure Blob Storage + if os.getenv("GRADIO_ENV") != "local": + timestamp = str(datetime.now().timestamp()) + file = timestamp + ".json" + prompt = history[-1][0] + logs = { + "user_id": str(user_id), + "prompt": prompt, + "query": prompt, + "question":output_query, + "sources":sources, + "docs":serialize_docs(docs), + "answer": history[-1][1], + "time": timestamp, + } + log_on_azure(file, logs, share_client) + except Exception as e: + print(f"Error logging on Azure Blob Storage: {e}") + raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") + + image_dict = {} + for i,doc in enumerate(docs): -async def chat_poc(query, history, audience, sources, reports, 
+    image_dict = {}
+    for i,doc in enumerate(docs):
-async def chat_poc(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
-    print("chat poc - message received")
-    async for event in chat_stream(agent_poc, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
-        yield event
+        if doc.metadata["chunk_type"] == "image":
+            try:
+                key = f"Image {i+1}"
+                image_path = doc.metadata["image_path"].split("documents/")[1]
+                img = get_image_from_azure_blob_storage(image_path)
+
+                # Convert the image to a byte buffer
+                buffered = BytesIO()
+                img.save(buffered, format="PNG")
+                img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                # Embedding the base64 string in Markdown
+                markdown_image = f"![Alt text](data:image/png;base64,{img_str})"
+                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
+            except Exception as e:
+                print(f"Skipped adding image {i} because of {e}")
+
+    if len(image_dict) > 0:
+
+        gallery = [x["img"] for x in list(image_dict.values())]
+        img = list(image_dict.values())[0]
+        img_md = img["md"]
+        img_caption = img["caption"]
+        img_code = img["figure_code"]
+        if img_code != "N/A":
+            img_name = f"{img['key']} - {img['figure_code']}"
+        else:
+            img_name = f"{img['key']}"
+
+        answer_yet = history[-1][1] + f"\n\n{img_md}\n{img_name} - {img_caption}"
" + history[-1] = (history[-1][0],answer_yet) + history = [tuple(x) for x in history] + + # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])] + # if len(gallery) > 0: + # gallery = list(set("|".join(gallery).split("|"))) + # gallery = [get_image_from_azure_blob_storage(x) for x in gallery] + + yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords + + +def make_html_source(source,i): + meta = source.metadata + # content = source.page_content.split(":",1)[1].strip() + content = source.page_content.strip() + + toc_levels = [] + for j in range(2): + level = meta[f"toc_level{j}"] + if level != "N/A": + toc_levels.append(level) + else: + break + toc_levels = " > ".join(toc_levels) + + if len(toc_levels) > 0: + name = f"{toc_levels}{content}
+    if meta["chunk_type"] == "text":
+        card = f"""
+    <div class="card" id="doc{i}">
+        <div class="card-content">
+            <h2>Doc {i} - {meta['short_name']} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+        </div>
+        <div class="card-footer">
+            <span>{name}</span>
+        </div>
+    </div>
+    """
+    else:
+        if meta["figure_code"] != "N/A":
+            title = f"{meta['figure_code']} - {meta['short_name']}"
+        else:
+            title = f"{meta['short_name']}"
+
+        card = f"""
+    <div class="card card-image">
+        <div class="card-content">
+            <h2>Image {i} - {title} - Page {int(meta['page_number'])}</h2>
+            <p>{content}</p>
+            <p class='ai-generated'>AI-generated description</p>
+
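For reference, a minimal sketch of the patch objects that the `chat` coroutine above consumes from `rag_chain.astream_log`: each `RunLogPatch` carries a jsonpatch-style `ops` list, and `chat` matches `ops[0]["path"]` against the `path_*` constants. The values below are illustrative only, not captured output; the real stream interleaves many other paths.

```python
# Hypothetical astream_log patch contents, for illustration only.
example_ops = [
    {"op": "add", "path": "/logs/reformulation/final_output",
     "value": {"language": "English", "question": "What is climate change?"}},
    {"op": "add", "path": "/logs/answer/streamed_output_str/-",
     "value": "Climate change refers to"},
]
```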