"""Text-to-SQL query pipeline over the MySQL database, built with llama_index.

The pipeline:
1. retrieves the most relevant table for the user's question;
2. asks the LLM to write SQL for it;
3. executes the SQL through ``CustomSQLRetriever``, looping back to SQL
   generation when the result is marked invalid;
4. asks the LLM to synthesize a natural-language answer from the rows.
"""
import os
from typing import Any, Dict, List

import pandas as pd
import pymysql
import Stemmer
from IPython.display import HTML, display
from llama_index.core import PromptTemplate, SQLDatabase, VectorStoreIndex
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.objects import (
    ObjectIndex,
    SQLTableNodeMapping,
    SQLTableSchema,
)
from llama_index.core.objects.base import ObjectRetriever
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.query_pipeline import (
    CustomQueryComponent,
    FnComponent,
    InputComponent,
    Link,
    QueryPipeline,
)
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import IndexNode
from llama_index.llms.ollama import Ollama
from llama_index.retrievers.bm25 import BM25Retriever
from pyvis.network import Network
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

from modules import (
    CustomSQLRetriever,
    create_table_obj_retriever,
    get_table_context_str,
    get_table_obj_retriever,
    parse_response_to_sql,
)

PROMPT_STR = """\
Give me a summary of the table with the following format.

- table_summary: Describe what the table is about in short. Columns: [col1(type), col2(type), ...]

Table:
{table_str}
"""

# Database connection settings.
# Secrets must not live in source control, so every value can be supplied via
# the environment. Non-secret settings fall back to their previous defaults.
# NOTE(review): the password that used to be hard-coded here was committed in
# plain text and should be rotated.
db_user = os.environ.get("DB_USER", "shenzhen_ai_for_vibemate_eson")
db_password = os.environ.get("DB_PASSWORD", "")
db_host = os.environ.get("DB_HOST", "192.168.1.99")
db_port = int(os.environ.get("DB_PORT", "3306"))
db_name = os.environ.get("DB_NAME", "hytto_surfease")

# Hand-written table summaries (including column lists).
# These are used both as table-retriever metadata and as the schema context
# given to the text-to-SQL prompt. Only tables listed here are exposed to the
# pipeline.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": (
        "This table is about Porn video information:\n\n"
        "t_sur_media_sync_es: Columns:id (integer), web_url (string), "
        "duration (integer), pattern_per (integer), like_count (integer), "
        "dislike_count (integer), view_count (integer), "
        "cover_picture (string), title (string), upload_date (datetime), "
        "uploader (string), create_time (datetime), "
        "update_time (datetime), categories (list of strings), "
        "abbreviate_video_url (string), abbreviate_mp4_video_url (string), "
        "resource_type (string), like_count_show (integer), "
        "stat_version (integer), tags (list of strings), "
        "model_name (string), publisher_type (string), period (integer), "
        "sexual_preference (string), country (string), type (string), "
        "rank_number (integer), rank_rate (float), has_pattern (boolean), "
        "trace (string), manifest_url (string), is_delete (boolean), "
        "web_url_md5 (string), view_key (string)"
    ),
    # Fixed: the column list previously had unbalanced parentheses
    # ("VARCHAR(100)" with no closing ")") and listed update_time twice,
    # both of which degrade the schema context given to the LLM.
    "t_sur_models_info": (
        "This table is about Stripchat models' information:\n\n"
        "t_sur_models_info: Columns:id (INTEGER), "
        "username (VARCHAR(100)), image (VARCHAR(500)), "
        "num_users (INTEGER), pf (VARCHAR(50)), "
        "pf_model_unite (VARCHAR(50)), use_plugin (INTEGER), "
        "create_time (DATETIME), update_time (DATETIME), "
        "gender (VARCHAR(50)), broadcast_type (VARCHAR(50)), "
        "common_gender (VARCHAR(50)), avatar (VARCHAR(512)), age (INTEGER)"
    ),
}


class TableInfo(BaseModel):
    """Information regarding a structured table."""

    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )


class TableRetrieveComponent(CustomQueryComponent):
    """Query-pipeline component that picks the table relevant to a query.

    Input key ``query``: the user's question.
    Output key ``output``: the ``TABLE_SUMMARY`` text of the top-ranked table.
    """

    retriever: ObjectRetriever = Field(
        ..., description="Retriever to use for table info"
    )
    sql_database: SQLDatabase = Field(
        ..., description="SQL engine to use for table info"
    )

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component (no-op)."""
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys dict."""
        return {"query"}

    @property
    def _output_keys(self) -> set:
        """Output keys dict."""
        return {"output"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Retrieve the best-matching table and return its summary.

        Raises:
            ValueError: if the retriever returns no table for the query.
        """
        query = kwargs["query"]
        results = self.retriever.retrieve(query)
        if not results:
            # Previously an opaque IndexError; make the failure explicit.
            raise ValueError(f"No table matched the query: {query!r}")
        # Only the top-ranked table is used as schema context.
        table_name = results[0].table_name
        return {"output": TABLE_SUMMARY[table_name]}


class SQLPipeline:
    """Text-to-SQL query pipeline over the configured MySQL database."""

    def __init__(self, llm: Ollama):
        """Connect to the database and build the query pipeline.

        Args:
            llm: LLM used both to write SQL and to synthesize the answer.
        """
        self.llm = llm
        # URL.create escapes special characters in each component, unlike an
        # f-string URL, which breaks when the password contains '@', '/' or ':'.
        self.engine = create_engine(
            URL.create(
                "mysql+pymysql",
                username=db_user,
                password=db_password,
                host=db_host,
                port=db_port,
                database=db_name,
            )
        )
        self.sql_db = SQLDatabase(self.engine)
        self.table_names = self.sql_db.get_usable_table_names()
        self.schema_table_mapping = {}
        self.init_schema_table_mapping()
        self.modules = self.prepare_modules()
        self.pipeline = self.build_pipeline()

    def init_schema_table_mapping(self):
        """Populate ``self.table_infos`` and ``self.schema_table_mapping``.

        Only database tables that have an entry in ``TABLE_SUMMARY`` are
        registered. Other tables are skipped rather than raising ``KeyError``.
        """
        self.table_infos = []
        for table in self.table_names:
            summary = TABLE_SUMMARY.get(table)
            if summary is None:
                # No hand-written summary: the table is not exposed to the
                # table retriever.
                continue
            table_info = TableInfo(table_name=table, table_summary=summary)
            self.table_infos.append(table_info)
            # Summary table name -> real table name.
            self.schema_table_mapping[table_info.table_name] = table

    def prepare_modules(self):
        """Create the named pipeline components used by ``build_pipeline``."""
        modules = {}
        # input
        modules["input"] = InputComponent()
        # table retriever
        table_obj_index_path = "/home/purui/projects/chatbot/kb/sql/table_obj_index"
        retriever = create_table_obj_retriever(
            index_path=table_obj_index_path,
            table_infos=self.table_infos,
            sql_db=self.sql_db,
            schema_table_mapping=self.schema_table_mapping,
        )
        modules["table_retriever"] = TableRetrieveComponent(
            retriever=retriever, sql_database=self.sql_db
        )
        # text-to-SQL prompt, specialised to the engine's SQL dialect
        text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
            dialect=self.engine.dialect.name
        )
        modules["text2sql_prompt"] = text2sql_prompt
        # text-to-SQL LLM
        modules["text2sql_llm"] = self.llm
        # SQL output parser
        modules["sql_output_parser"] = FnComponent(fn=parse_response_to_sql)
        # SQL retriever (custom variant that reports whether the SQL was valid)
        modules["sql_retriever"] = CustomSQLRetriever(sql_db=self.sql_db)
        # response synthesis prompt
        response_synthesis_prompt_str = (
            "Given an input question, synthesize a response from the query results.\n"
            "Query: {query_str}\n"
            "SQL: {sql_query}\n"
            "SQL Response: {context_str}\n"
            "Response: "
        )
        response_synthesis_prompt = PromptTemplate(
            response_synthesis_prompt_str,
        )
        modules["response_synthesis_prompt"] = response_synthesis_prompt
        # response synthesis LLM
        modules["response_synthesis_llm"] = self.llm
        return modules

    def build_pipeline(self):
        """Wire the modules from ``prepare_modules`` into a QueryPipeline."""
        qp = QueryPipeline(
            modules=self.modules,
            verbose=True,
        )
        # add chains & links
        qp.add_link("input", "table_retriever", dest_key="query")
        qp.add_link("input", "text2sql_prompt", dest_key="query_str")
        qp.add_link("table_retriever", "text2sql_prompt", dest_key="schema")
        qp.add_chain(
            ["text2sql_prompt", "text2sql_llm", "sql_output_parser"]
        )
        qp.add_link(
            "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
        )
        qp.add_link("input", "sql_retriever", dest_key="query_str")
        qp.add_link("sql_output_parser", "sql_retriever", dest_key="sql_query")
        # The custom sql_retriever emits an ``is_valid`` flag.
        # - True (the SQL ran and returned a proper result): follow the
        #   sql_retriever -> response_synthesis_prompt link.
        # - False: route back into text2sql_prompt so the SQL is regenerated.
        qp.add_link(
            "sql_retriever",
            "response_synthesis_prompt",
            dest_key="context_str",
            condition_fn=lambda x: x["is_valid"],
        )
        qp.add_link(
            "sql_retriever",
            "text2sql_prompt",
            src_key="query_str",
            dest_key="query_str",
            condition_fn=lambda x: not x["is_valid"],
        )
        qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
        qp.add_link("response_synthesis_prompt", "response_synthesis_llm")
        return qp

    def get_vision(self):
        """Render the pipeline DAG with pyvis and display it in a notebook.

        Side effect: writes ``text2sql_dag.html`` to the working directory.
        """
        net = Network(notebook=True, cdn_resources="in_line", directed=True)
        net.from_nx(self.pipeline.dag)
        net.write_html("text2sql_dag.html")
        with open("text2sql_dag.html", "r") as file:
            html_content = file.read()
        # Display the HTML content
        display(HTML(html_content))

    def run(self, query: str):
        """Run the pipeline on ``query`` and return the answer as a string."""
        response = self.pipeline.run(query=query)
        return str(response)


if __name__ == '__main__':
    sql_pipeline = SQLPipeline(
        llm=Ollama(model="mannix/llama3.1-8b-abliterated", request_timeout=120)
    )
    response = sql_pipeline.run(
        "I want 5 different big tits milf porn with it's title and web url"
    )
    print(response)