|
import os |
|
from llama_index.core import load_index_from_storage, StorageContext, SQLDatabase, VectorStoreIndex, Settings |
|
from llama_index.core.objects import ( |
|
SQLTableNodeMapping, |
|
ObjectIndex, |
|
SQLTableSchema, |
|
) |
|
from typing import List, Dict, Any |
|
from pydantic import BaseModel |
|
from langchain_community.embeddings.ollama import OllamaEmbeddings |
|
from llama_index.core.llms import ChatResponse |
|
from llama_index.core.indices.keyword_table.base import KeywordTableIndex |
|
from llama_index.core.query_pipeline import CustomQueryComponent |
|
from llama_index.core.retrievers import SQLRetriever |
|
from llama_index.core.bridge.pydantic import Field |
|
|
|
# Name of the Ollama model used to compute embeddings.
_EMBED_MODEL_NAME = "pornchat"

# LangChain's Ollama embedder, registered as LlamaIndex's global default
# embedding model via ``Settings.embed_model``.
embed_model = OllamaEmbeddings(model=_EMBED_MODEL_NAME)
Settings.embed_model = embed_model
|
|
|
class CustomSQLRetriever(CustomQueryComponent):
    """Query-pipeline component that executes a SQL query against ``sql_db``.

    Inputs: ``query_str`` (declared but not used by this component) and
    ``sql_query`` (the SQL to run).
    Outputs: ``output`` -- the result of ``SQLRetriever.retrieve`` on success,
    ``""`` on failure -- and ``is_valid`` -- ``True`` iff retrieval did not raise.
    """

    sql_db: SQLDatabase = Field(..., description="SQL Database")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component."""
        # No extra validation: inputs are passed through unchanged.
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"query_str", "sql_query"}

    @property
    def _output_keys(self) -> set:
        """Output keys: the retrieval result and a success flag."""
        return {"output", "is_valid"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Run ``sql_query`` and report whether it succeeded.

        Failures are not raised. The component returns
        ``{"output": "", "is_valid": False}`` so downstream steps can branch on
        ``is_valid``. The exception is logged with its traceback instead of
        being discarded.
        """
        import logging

        retriever = SQLRetriever(self.sql_db)
        # A missing/None query fails inside retrieve() below and is reported
        # as invalid, matching the component's best-effort contract.
        sql_query = kwargs.get("sql_query")
        try:
            output = retriever.retrieve(sql_query)
        except Exception:  # deliberately broad: any execution error => invalid query
            logging.getLogger(__name__).warning(
                "SQL query failed: %r", sql_query, exc_info=True
            )
            return {"output": "", "is_valid": False}
        return {"output": output, "is_valid": True}
|
|
|
|
|
def get_table_obj_retriever(
    index_path: str,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
    sql_db: SQLDatabase,
    *,
    similarity_top_k: int = 1,
):
    """Load a persisted table-schema index and wrap it in an object retriever.

    Args:
        index_path: Directory holding a persisted LlamaIndex storage context.
        table_infos: Table descriptions; each item must expose ``table_name``
            and ``table_summary`` attributes.
        schema_table_mapping: Maps each ``table_info.table_name`` to the table
            name used in ``sql_db``.
        sql_db: Database containing the tables.
        similarity_top_k: Number of tables to retrieve per query. Applied by
            vector-index retrievers; other index types may ignore it.

    Returns:
        A retriever yielding ``SQLTableSchema`` objects, or ``False`` when
        ``index_path`` does not exist (kept for callers that test for it).
    """
    if not os.path.exists(index_path):
        return False

    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=index_path)
    )
    table_schema_objs = [
        SQLTableSchema(
            table_name=schema_table_mapping[t.table_name],
            context_str=t.table_summary,
        )
        for t in table_infos
    ]
    # Pass the SQL-table mapping explicitly so retrieved nodes are converted
    # back into SQLTableSchema objects; without it ObjectIndex picks a generic
    # mapping that does not understand nodes created by SQLTableNodeMapping.
    obj_index = ObjectIndex.from_objects_and_index(
        objects=table_schema_objs,
        index=index,
        object_mapping=SQLTableNodeMapping(sql_db),
    )
    # Keyword spelling matters: retrievers accept **kwargs, so a misspelt
    # option is silently ignored and the library default is used instead.
    return obj_index.as_retriever(similarity_top_k=similarity_top_k)
|
|
|
def create_table_obj_retriever(
    index_path: str,
    sql_db: SQLDatabase,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
    *,
    similarity_top_k: int = 2,
):
    """Build a table-schema index, persist it to ``index_path``, return a retriever.

    Args:
        index_path: Directory the new index's storage context is written to,
            so ``get_table_obj_retriever`` can reload it later.
        sql_db: Database containing the tables.
        table_infos: Table descriptions; each item must expose ``table_name``
            and ``table_summary`` attributes.
        schema_table_mapping: Maps each ``table_info.table_name`` to the table
            name used in ``sql_db``.
        similarity_top_k: Forwarded to ``as_retriever``; keyword-table
            retrievers may ignore it.

    Returns:
        A retriever yielding ``SQLTableSchema`` objects.
    """
    table_node_mapping = SQLTableNodeMapping(sql_db)
    table_schema_objs = [
        SQLTableSchema(
            table_name=schema_table_mapping[t.table_name],
            context_str=t.table_summary,
        )
        for t in table_infos
    ]
    # Start from an empty storage context: from_defaults(persist_dir=...) would
    # try to *load* an existing store from index_path, which fails when the
    # index has not been created yet.
    storage_context = StorageContext.from_defaults()
    # index_cls is passed by keyword because newer ObjectIndex.from_objects
    # signatures place other parameters before it positionally.
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=KeywordTableIndex,
        storage_context=storage_context,
    )
    storage_context.persist(persist_dir=index_path)
    return obj_index.as_retriever(similarity_top_k=similarity_top_k)
|
|
|
def get_table_context_str(
    schema_table_mapping: Dict[str, str],
    table_schema_objs: List[SQLTableSchema],
    sql_database: SQLDatabase,
) -> str:
    """Get table context string.

    For each schema object, fetch the table's description from
    ``sql_database`` and append its ``context_str`` (if any).

    Args:
        schema_table_mapping: Optional renaming applied to each schema
            object's ``table_name`` before lookup. Names that are not keys of
            the mapping are used as-is.
        table_schema_objs: Tables to describe.
        sql_database: Database providing ``get_single_table_info``.

    Returns:
        The per-table descriptions joined by blank lines.
    """
    context_strs = []
    for table_schema_obj in table_schema_objs:
        name = table_schema_obj.table_name
        # Schema objects built in this module already carry the mapped
        # (database) name, so fall back to the name itself rather than
        # raising KeyError when it is not a key of the mapping.
        table_info = sql_database.get_single_table_info(
            schema_table_mapping.get(name, name)
        )
        if table_schema_obj.context_str:
            table_info += " The table description is: " + table_schema_obj.context_str
        context_strs.append(table_info)
    return "\n\n".join(context_strs)
|
|
|
def parse_response_to_sql(response: "ChatResponse") -> str:
    """Parse response to SQL.

    Keeps the text between an optional ``SQLQuery:`` marker and an optional
    ``SQLResult:`` marker. Surrounding whitespace and Markdown code-fence
    backticks are removed, including the ``sql`` language tag a
    ```` ```sql ```` fence leaves behind.

    Args:
        response: Chat response whose ``message.content`` holds the LLM text;
            ``None`` content is treated as an empty string.

    Returns:
        The bare SQL query (possibly empty).
    """
    text = response.message.content or ""

    # Keep only what follows the "SQLQuery:" marker, if present.
    marker = "SQLQuery:"
    start = text.find(marker)
    if start != -1:
        text = text[start + len(marker):]

    # Drop everything from a "SQLResult:" marker onward, if present.
    end = text.find("SQLResult:")
    if end != -1:
        text = text[:end]

    # strip("`") removes fence backticks at both ends, but a ```sql fence
    # leaves its language tag as the first line; drop that line too.
    text = text.strip().strip("`").strip()
    first_line, newline, rest = text.partition("\n")
    if newline and first_line.strip().lower() == "sql":
        text = rest.strip()
    return text
|
|