SexBot / pipeline /modules.py
Pew404's picture
Upload folder using huggingface_hub
318db6e verified
import logging
import os
from typing import List, Dict, Any

from langchain_community.embeddings.ollama import OllamaEmbeddings
from llama_index.core import load_index_from_storage, StorageContext, SQLDatabase, VectorStoreIndex, Settings
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.llms import ChatResponse
from llama_index.core.indices.keyword_table.base import KeywordTableIndex
from llama_index.core.query_pipeline import CustomQueryComponent
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.bridge.pydantic import Field
from pydantic import BaseModel
# Module-wide embedding model. It is kept as the module attribute
# ``embed_model`` and also installed as llama_index's global default through
# ``Settings.embed_model``. Names bind left to right, so ``embed_model`` is
# bound first, as before.
embed_model = Settings.embed_model = OllamaEmbeddings(model="pornchat")
class CustomSQLRetriever(CustomQueryComponent):
    """Query-pipeline component that executes a SQL string against ``sql_db``.

    Inputs:
        query_str: The natural-language question. It is required by the
            component's input keys but not read by ``_run_component``.
        sql_query: The SQL text to execute.

    Outputs:
        output: The result of ``SQLRetriever.retrieve`` on success, or ``""``
            when execution raised.
        is_valid: ``True`` if the query executed without raising, else ``False``.
    """

    sql_db: SQLDatabase = Field(..., description="SQL Database")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component.

        Optional hook. No extra validation is done; inputs pass through unchanged.
        """
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"query_str", "sql_query"}

    @property
    def _output_keys(self) -> set:
        """Output keys: the retrieval result and a success flag."""
        return {"output", "is_valid"}

    def _run_component(self, **kwargs: Any) -> Dict[str, Any]:
        """Execute ``kwargs["sql_query"]`` and report whether it succeeded.

        Failures are not propagated. They are logged, and the component returns
        ``{"output": "", "is_valid": False}`` so downstream pipeline steps can
        branch on ``is_valid``.
        """
        retriever = SQLRetriever(self.sql_db)
        try:
            output = retriever.retrieve(kwargs["sql_query"])
        except Exception as exc:
            # Deliberately broad: the SQL text is presumably model-generated
            # (see parse_response_to_sql), so any database/driver error just
            # means "invalid query". It is logged instead of silently dropped.
            logging.getLogger(__name__).warning(
                "SQL retrieval failed for query %r: %s",
                kwargs.get("sql_query"),
                exc,
            )
            return {"output": "", "is_valid": False}
        return {"output": output, "is_valid": True}
def get_table_obj_retriever(
    index_path: str,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
    sql_db: SQLDatabase,
    similarity_top_k: int = 1,
):
    """Load a persisted table-schema index and wrap it as a table retriever.

    Args:
        index_path: Directory the index was persisted to.
        table_infos: Table descriptions. Each item must expose ``table_name``
            and ``table_summary`` attributes.
        schema_table_mapping: Maps each ``table_info.table_name`` to the table
            name stored in the ``SQLTableSchema`` objects.
        sql_db: Database containing the tables.
        similarity_top_k: Number of tables to retrieve per query. It is passed
            through to the underlying index's retriever.

    Returns:
        A retriever yielding ``SQLTableSchema`` objects, or ``False`` if
        ``index_path`` does not exist.
    """
    if not os.path.exists(index_path):
        return False

    index = load_index_from_storage(StorageContext.from_defaults(persist_dir=index_path))
    table_schema_objs = [
        SQLTableSchema(
            table_name=schema_table_mapping[t.table_name],
            context_str=t.table_summary,
        )
        for t in table_infos
    ]
    # Pass the same SQLTableNodeMapping that create_table_obj_retriever builds
    # the index with. Without it, ObjectIndex picks a generic object mapping
    # that cannot turn the stored table nodes back into SQLTableSchema objects.
    obj_index = ObjectIndex.from_objects_and_index(
        objects=table_schema_objs,
        index=index,
        object_mapping=SQLTableNodeMapping(sql_db),
    )
    # Keyword must be spelled exactly: unknown retriever kwargs are swallowed
    # silently, so a typo here falls back to the library default top_k.
    return obj_index.as_retriever(similarity_top_k=similarity_top_k)
def create_table_obj_retriever(
    index_path: str,
    sql_db: SQLDatabase,
    table_infos: List[BaseModel],
    schema_table_mapping: Dict[str, str],
):
    """Build a keyword index over table schemas, persist it, and return a retriever.

    The index is written to ``index_path`` so a later call to
    ``get_table_obj_retriever(index_path, ...)`` can reload it instead of
    rebuilding it.

    Args:
        index_path: Directory to persist the index to.
        sql_db: Database containing the tables.
        table_infos: Table descriptions. Each item must expose ``table_name``
            and ``table_summary`` attributes.
        schema_table_mapping: Maps each ``table_info.table_name`` to the table
            name stored in the ``SQLTableSchema`` objects.

    Returns:
        A retriever over the table-schema objects.
    """
    table_node_mapping = SQLTableNodeMapping(sql_db)
    table_schema_objs = [
        SQLTableSchema(
            table_name=schema_table_mapping[t.table_name],
            context_str=t.table_summary,
        )
        for t in table_infos
    ]
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        KeywordTableIndex,
    )
    # Persist only the underlying index. The node mapping is rebuilt from
    # sql_db on load. (The previous StorageContext.from_defaults(persist_dir=...)
    # tried to *load* from a not-yet-existing directory and was never used.)
    obj_index.index.storage_context.persist(persist_dir=index_path)
    # NOTE(review): keyword-table retrievers may not honour similarity_top_k;
    # confirm against the installed llama_index version.
    return obj_index.as_retriever(similarity_top_k=2)
def get_table_context_str(
    schema_table_mapping: Dict[str, str],
    table_schema_objs: List[SQLTableSchema],
    sql_database: SQLDatabase,
):
    """Get table context string.

    Each table contributes the schema text from
    ``sql_database.get_single_table_info``. When the table has a description,
    ``" The table description is: <context_str>"`` is appended. Entries are
    joined with blank lines.

    Args:
        schema_table_mapping: Optional renaming from a schema object's
            ``table_name`` to the database table name. Names that are not keys
            are used as-is.
        table_schema_objs: Tables to describe.
        sql_database: Database to read table info from.

    Returns:
        The combined context string (empty if ``table_schema_objs`` is empty).
    """
    context_strs = []
    for table_schema_obj in table_schema_objs:
        name = table_schema_obj.table_name
        # SQLTableSchema objects built in this module already hold the mapped
        # name (a mapping *value*), so fall back to the name instead of
        # raising KeyError when it is not a mapping key.
        table_info = sql_database.get_single_table_info(
            schema_table_mapping.get(name, name)
        )
        if table_schema_obj.context_str:
            table_info += " The table description is: " + table_schema_obj.context_str
        context_strs.append(table_info)
    return "\n\n".join(context_strs)
def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL.

    Processing of the model's reply, in order:

    1. If ``SQLQuery:`` appears, drop everything up to and including it.
    2. If ``SQLResult:`` then appears, drop it and everything after it.
    3. Strip surrounding whitespace and backticks.
    4. Drop a leading ``sql`` code-fence language tag.

    If neither marker is present, the whole message text is used.

    Args:
        response: Chat response whose ``message.content`` holds the model output.

    Returns:
        The bare SQL text (possibly empty).
    """
    text = response.message.content or ""  # content may be None
    sql_query_start = text.find("SQLQuery:")
    if sql_query_start != -1:
        text = text[sql_query_start + len("SQLQuery:"):]
    sql_result_start = text.find("SQLResult:")
    if sql_result_start != -1:
        text = text[:sql_result_start]
    text = text.strip().strip("`").strip()
    # A fenced answer such as ```sql\nSELECT ...``` leaves "sql" behind once
    # the backticks are stripped; remove it so it isn't executed as SQL.
    if text[:3].lower() == "sql" and text[3:4].isspace():
        text = text[3:].strip()
    return text