"""Text-to-SQL helpers built on LlamaIndex.

Provides a query-pipeline component that executes SQL, helpers that build or
reload a retriever over table-schema objects, a helper that renders table
context for prompts, and a parser that extracts SQL from an LLM response.
"""
import logging
import os
from typing import Any, Dict, List

from langchain_community.embeddings.ollama import OllamaEmbeddings
from llama_index.core import (
    Settings,
    SQLDatabase,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.indices.keyword_table.base import KeywordTableIndex
from llama_index.core.llms import ChatResponse
from llama_index.core.objects import (
    ObjectIndex,
    SQLTableNodeMapping,
    SQLTableSchema,
)
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.retrievers import SQLRetriever
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Module-level side effect: set LlamaIndex's global default embedding model to
# an Ollama-served model. The model tag is deployment-specific.
embed_model = OllamaEmbeddings(model="pornchat")
Settings.embed_model = embed_model


class CustomSQLRetriever(CustomQueryComponent):
    """Query-pipeline component that executes a SQL query against ``sql_db``.

    Inputs: ``query_str`` and ``sql_query``. Outputs: ``output`` (the result of
    ``SQLRetriever.retrieve``, or ``""`` when the query raised) and
    ``is_valid`` (whether the query ran without raising).
    """

    sql_db: SQLDatabase = Field(..., description="SQL Database")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # NOTE: this is OPTIONAL; inputs are passed through unchanged.
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"query_str", "sql_query"}

    @property
    def _output_keys(self) -> set:
        """Output keys; this component emits more than one output."""
        return {"output", "is_valid"}

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Execute ``kwargs["sql_query"]`` and report whether it succeeded.

        Errors are not raised. The component reports them through
        ``is_valid=False`` (presumably so a downstream step can branch on it).
        They are logged rather than silently discarded.
        """
        retriever = SQLRetriever(self.sql_db)
        try:
            output = retriever.retrieve(kwargs["sql_query"])
        except Exception:
            logger.warning("SQL query failed; marking it invalid.", exc_info=True)
            return {"output": "", "is_valid": False}
        return {"output": output, "is_valid": True}


def _build_table_schema_objs(
    table_infos: List[BaseModel], schema_table_mapping: Dict[str, str]
) -> List[SQLTableSchema]:
    """Build one SQLTableSchema per table info, named by the mapped table name."""
    # Each info is assumed to expose ``table_name`` and ``table_summary``;
    # TODO confirm against the model definition used by callers.
    return [
        SQLTableSchema(
            table_name=schema_table_mapping[info.table_name],
            context_str=info.table_summary,
        )
        for info in table_infos
    ]


def get_table_obj_retriever(
    index_path: str,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
    sql_db: SQLDatabase,
    similarity_top_k: int = 1,
):
    """Reload a persisted table-schema index and wrap it as an object retriever.

    Args:
        index_path: Directory the index was persisted to.
        table_infos: Table summaries used to build the schema objects.
        schema_table_mapping: Maps ``table_info.table_name`` to the real table name.
        sql_db: Database the schema nodes describe.
        similarity_top_k: Number of tables to retrieve per query.

    Returns:
        An ``ObjectRetriever`` yielding ``SQLTableSchema`` objects, or ``False``
        when nothing exists at ``index_path``.
    """
    if not os.path.exists(index_path):
        return False

    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=index_path)
    )
    # Stored nodes carry the table name/context in metadata, and
    # SQLTableNodeMapping reads them back from there. Without an explicit
    # mapping, ObjectIndex picks a generic text-keyed mapping that cannot
    # resolve these nodes.
    node_mapping = SQLTableNodeMapping(sql_db)
    obj_index = ObjectIndex.from_objects_and_index(
        objects=_build_table_schema_objs(table_infos, schema_table_mapping),
        index=index,
        object_mapping=node_mapping,
    )
    # The keyword must be spelled exactly: retrievers accept **kwargs, so a
    # misspelt name is silently ignored.
    return obj_index.as_retriever(similarity_top_k=similarity_top_k)


def create_table_obj_retriever(
    index_path: str,
    sql_db: SQLDatabase,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
    similarity_top_k: int = 2,
    index_cls: type = KeywordTableIndex,
    persist: bool = True,
):
    """Build a table-schema index and return an object retriever over it.

    Args:
        index_path: Directory to persist the new index to.
        sql_db: Database whose tables are indexed.
        table_infos: Table summaries used to build the schema objects.
        schema_table_mapping: Maps ``table_info.table_name`` to the real table name.
        similarity_top_k: Forwarded to the retriever. Some index types, such
            as keyword-table retrievers, may ignore it.
        index_cls: LlamaIndex index class used to store the schema nodes.
        persist: When true, write the index to ``index_path`` so that
            ``get_table_obj_retriever`` can reload it later.

    Returns:
        An ``ObjectRetriever`` yielding ``SQLTableSchema`` objects.
    """
    table_node_mapping = SQLTableNodeMapping(sql_db)
    obj_index = ObjectIndex.from_objects(
        _build_table_schema_objs(table_infos, schema_table_mapping),
        table_node_mapping,
        index_cls,
    )
    if persist:
        # StorageContext.from_defaults(persist_dir=...) *loads* from disk and
        # fails on a missing directory. Persist the freshly built index instead.
        obj_index.index.storage_context.persist(persist_dir=index_path)
    return obj_index.as_retriever(similarity_top_k=similarity_top_k)


def get_table_context_str(
    schema_table_mapping: Dict[str, str],
    table_schema_objs: List[SQLTableSchema],
    sql_database: SQLDatabase,
):
    """Render schema info (plus optional description) for each table.

    Table blocks are separated by a blank line.
    """
    context_strs = []
    for table_schema_obj in table_schema_objs:
        name = table_schema_obj.table_name
        # Schema objects built in this module already carry the mapped name,
        # which need not itself be a mapping key. In that case, use it as-is.
        table_info = sql_database.get_single_table_info(
            schema_table_mapping.get(name, name)
        )
        if table_schema_obj.context_str:
            table_info += " The table description is: " + table_schema_obj.context_str
        context_strs.append(table_info)
    return "\n\n".join(context_strs)


def parse_response_to_sql(response: ChatResponse) -> str:
    """Extract the SQL statement from an LLM chat response.

    The function applies these steps in order:

    1. Keep only the text after a ``SQLQuery:`` marker, if one is present.
    2. Drop everything from a ``SQLResult:`` marker onwards.
    3. Strip surrounding whitespace and Markdown code fences.
    4. Remove a leading ``sql`` language tag, as left by "```sql" fences.
    """
    text = response.message.content or ""

    sql_query_start = text.find("SQLQuery:")
    if sql_query_start != -1:
        text = text[sql_query_start + len("SQLQuery:"):]

    sql_result_start = text.find("SQLResult:")
    if sql_result_start != -1:
        text = text[:sql_result_start]

    text = text.strip().strip("`")
    # Once the backticks of a "```sql" fence are gone, the language tag is
    # left on its own first line.
    first_line, newline, rest = text.partition("\n")
    if newline and first_line.strip().lower() == "sql":
        text = rest
    return text.strip().strip("`").strip()