Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.prompts.prompt import PromptTemplate | |
| from langchain.retrievers.self_query.base import SelfQueryRetriever | |
| from langchain.retrievers.self_query.myscale import MyScaleTranslator | |
| from langchain.utilities.sql_database import SQLDatabase | |
| from langchain.vectorstores import MyScaleSettings | |
| from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever | |
| from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain | |
| from sqlalchemy import create_engine, MetaData | |
| from backend.constants.myscale_tables import MYSCALE_TABLES | |
| from backend.constants.prompts import MYSCALE_PROMPT | |
| from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG | |
| from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser | |
| from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson | |
| from logger import logger | |
def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
    """Create a ``SelfQueryRetriever`` for the MyScaleDB table registered as ``table_name``.

    The work happens in two steps, each wrapped in its own Streamlit spinner:

    1. A ``MyScaleWithoutMetadataJson`` vector store is configured from the
       connection values in ``GLOBAL_CONFIG`` and the column names stored in
       ``MYSCALE_TABLES[table_name]``.
    2. ``SelfQueryRetriever.from_llm`` is called with a ``ChatOpenAI`` model
       (temperature 0), that vector store and a ``MyScaleTranslator``.

    Args:
        table_name: Key used to index both ``MYSCALE_TABLES`` and
            ``st.session_state[TABLE_EMBEDDINGS_MAPPING]``.

    Returns:
        The retriever produced by ``SelfQueryRetriever.from_llm``.
    """
    with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
        # Connection parameters are read first, then the per-table settings.
        settings_kwargs = dict(
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
        )
        table_spec = MYSCALE_TABLES[table_name]
        settings_kwargs["database"] = table_spec.database
        settings_kwargs["table"] = table_spec.table
        settings_kwargs["column_map"] = {
            "id": "id",
            "text": table_spec.text_col_name,
            "vector": table_spec.vector_col_name,
            # TODO refine MyScaleDB metadata in langchain.
            "metadata": table_spec.metadata_col_name,
        }
        vector_store_config = MyScaleSettings(**settings_kwargs)

        # The embedding model for this table is looked up in the Streamlit session.
        table_embedding = st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name]
        vector_store = MyScaleWithoutMetadataJson(
            embedding=table_embedding,
            config=vector_store_config,
            must_have_cols=table_spec.must_have_col_names,
        )

    with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
        query_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.query_model,
            base_url=GLOBAL_CONFIG.openai_api_base,
            api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0,
        )
        return SelfQueryRetriever.from_llm(
            llm=query_llm,
            vectorstore=vector_store,
            document_contents=table_spec.table_contents,
            metadata_field_info=table_spec.metadata_col_attributes,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator(),
        )
def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
    """Build a ``VectorSQLDatabaseChainRetriever`` for one MyScaleDB table.

    A SQLAlchemy engine is created for the database that holds the table, a
    ``VectorSQLDatabaseChain`` is built on top of it with ``MYSCALE_PROMPT``
    and a ``ChatOpenAI`` model (temperature 0), and the chain is wrapped in a
    retriever whose page content comes from the table's text column. The
    whole build runs inside a Streamlit spinner.

    Args:
        table_name: Key used to index both ``MYSCALE_TABLES`` and
            ``st.session_state[TABLE_EMBEDDINGS_MAPPING]``.

    Returns:
        The retriever wrapping the configured ``VectorSQLDatabaseChain``.
    """
    # Function-scope import: only this function needs it.
    from urllib.parse import quote

    with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
        # Only an explicit False (or 0) selects plain HTTP; any other value,
        # including an unset/None flag, keeps HTTPS -- same as before.
        if GLOBAL_CONFIG.myscale_enable_https == False:  # noqa: E712
            protocol = "http"
        else:
            protocol = "https"

        # Percent-encode the credentials so characters such as '@', ':', '/'
        # or '%' cannot corrupt the URL. `safe=""` also encodes '/'.
        url_user = quote(str(GLOBAL_CONFIG.myscale_user), safe="")
        url_password = quote(str(GLOBAL_CONFIG.myscale_password), safe="")

        table_config = MYSCALE_TABLES[table_name]
        engine = create_engine(
            f'clickhouse://{url_user}:{url_password}@'
            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
            f'/{table_config.database}?protocol={protocol}'
        )

        # NOTE(review): `MetaData(bind=...)` is not accepted by SQLAlchemy 2.0;
        # confirm the pinned SQLAlchemy version before upgrading.
        metadata = MetaData(bind=engine)
        logger.info(f"{table_name} metadata is : {metadata}")

        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=MYSCALE_PROMPT,
        )

        # Custom output parser, built from this table's embedding model and
        # the columns that must be part of the search.
        output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
            model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            # Columns that need to be searched.
            must_have_columns=table_config.must_have_col_names,
        )

        # `vector_sql_db_chain` generates the SQL statement.
        vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0,
            ),
            prompt=prompt,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(
                engine,
                schema=None,
                metadata=metadata,
                include_tables=[table_config.table],
                max_string_length=1024,
            ),
            sql_cmd_parser=output_parser,  # TODO needs update `langchain`, fix return type.
            native_format=True,
        )

        # The retriever searches for a group of documents through the chain.
        vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever(
            sql_db_chain=vector_sql_db_chain,
            page_content_key=table_config.text_col_name,
        )
    return vector_sql_db_chain_retriever