File size: 9,724 Bytes
318db6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import os
import warnings
from typing import Any, Dict, List

import pandas as pd
import pymysql
import Stemmer
from IPython.display import HTML, display
from llama_index.core import PromptTemplate, SQLDatabase, VectorStoreIndex
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.objects import (
    ObjectIndex,
    SQLTableNodeMapping,
    SQLTableSchema,
)
from llama_index.core.objects.base import ObjectRetriever
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.query_pipeline import (
    CustomQueryComponent,
    FnComponent,
    InputComponent,
    Link,
    QueryPipeline,
)
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import IndexNode
from llama_index.llms.ollama import Ollama
from llama_index.retrievers.bm25 import BM25Retriever
from pyvis.network import Network
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

from modules import (
    CustomSQLRetriever,
    create_table_obj_retriever,
    get_table_context_str,
    get_table_obj_retriever,
    parse_response_to_sql,
)
# Prompt template asking for a one-line table summary; ``{table_str}`` is the
# placeholder filled in with the table contents.
PROMPT_STR = (
    "Give me a summary of the table with the following format.\n"
    "- table_summary: Describe what the table is about in short. "
    "Columns: [col1(type), col2(type), ...]\n"
    "Table:\n"
    "{table_str}\n"
)
# Database connection settings, read from the environment so that credentials
# are not committed to source control. Non-secret settings keep their previous
# values as defaults; the password has no default and must be supplied via
# SQL_PIPELINE_DB_PASSWORD.
# NOTE(review): a password was previously hard-coded here and remains in VCS
# history -- rotate it.
db_user = os.environ.get("SQL_PIPELINE_DB_USER", "shenzhen_ai_for_vibemate_eson")
db_password = os.environ.get("SQL_PIPELINE_DB_PASSWORD", "")
db_host = os.environ.get("SQL_PIPELINE_DB_HOST", "192.168.1.99")
db_port = int(os.environ.get("SQL_PIPELINE_DB_PORT", "3306"))
db_name = os.environ.get("SQL_PIPELINE_DB_NAME", "hytto_surfease")
# Hand-written table descriptions, keyed by real table name. They are used to
# build the table index (SQLPipeline.init_schema_table_mapping) and are
# returned as the schema context fed to the text-to-SQL prompt
# (TableRetrieveComponent), so they must list each column exactly once with
# balanced parentheses.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": "This table is about Porn video information:\n\nt_sur_media_sync_es: Columns:id (integer), web_url (string), duration (integer), pattern_per (integer), like_count (integer), dislike_count (integer), view_count (integer), cover_picture (string), title (string), upload_date (datetime), uploader (string), create_time (datetime), update_time (datetime), categories (list of strings), abbreviate_video_url (string), abbreviate_mp4_video_url (string), resource_type (string), like_count_show (integer), stat_version (integer), tags (list of strings), model_name (string), publisher_type (string), period (integer), sexual_preference (string), country (string), type (string), rank_number (integer), rank_rate (float), has_pattern (boolean), trace (string), manifest_url (string), is_delete (boolean), web_url_md5 (string), view_key (string)",
    "t_sur_models_info": (
        "This table is about Stripchat models' information:\n\n"
        "t_sur_models_info: Columns:id (INTEGER), username (VARCHAR(100)), "
        "image (VARCHAR(500)), num_users (INTEGER), pf (VARCHAR(50)), "
        "pf_model_unite (VARCHAR(50)), use_plugin (INTEGER), "
        "create_time (DATETIME), update_time (DATETIME), gender (VARCHAR(50)), "
        "broadcast_type (VARCHAR(50)), common_gender (VARCHAR(50)), "
        "avatar (VARCHAR(512)), age (INTEGER)"
    ),
}
class SQLPipeline:
    """Text-to-SQL question answering over the configured MySQL database.

    The query pipeline built here links: input -> table retriever ->
    text-to-SQL prompt -> LLM -> SQL output parser -> SQL retriever ->
    response-synthesis prompt -> LLM.
    """

    # Default location of the table object index (previously hard-coded in
    # prepare_modules); override it with the ``table_obj_index_path`` argument.
    DEFAULT_TABLE_OBJ_INDEX_PATH = "/home/purui/projects/chatbot/kb/sql/table_obj_index"

    def __init__(self, llm: Ollama, table_obj_index_path: str = DEFAULT_TABLE_OBJ_INDEX_PATH):
        """Connect to the database and build the query pipeline.

        Args:
            llm: LLM used for both SQL generation and response synthesis.
            table_obj_index_path: ``index_path`` passed to
                ``create_table_obj_retriever``.
        """
        self.llm = llm
        self.table_obj_index_path = table_obj_index_path
        # URL.create takes the raw credentials, so characters such as '@',
        # ':' or '/' in the password cannot corrupt the connection URL the way
        # plain f-string interpolation would.
        url = URL.create(
            drivername="mysql+pymysql",
            username=db_user,
            password=db_password,
            host=db_host,
            port=db_port,
            database=db_name,
        )
        self.engine = create_engine(url)
        self.sql_db = SQLDatabase(self.engine)
        self.table_names = self.sql_db.get_usable_table_names()
        self.schema_table_mapping = {}
        self.init_schema_table_mapping()
        self.modules = self.prepare_modules()
        self.pipeline = self.build_pipeline()

    def init_schema_table_mapping(self):
        """Populate ``self.table_infos`` and ``self.schema_table_mapping``.

        Only tables with an entry in ``TABLE_SUMMARY`` are included; other
        usable tables are skipped with a warning instead of raising KeyError.
        """
        self.table_infos = []
        for table in self.table_names:
            summary = TABLE_SUMMARY.get(table)
            if summary is None:
                warnings.warn(
                    f"No TABLE_SUMMARY entry for table {table!r}; it will not be retrievable."
                )
                continue
            table_info = TableInfo(table_name=table, table_summary=summary)
            self.table_infos.append(table_info)
            # Summary table name -> real table name.
            self.schema_table_mapping[table_info.table_name] = table

    def prepare_modules(self):
        """Create the named components used by ``build_pipeline``."""
        modules = {}
        # input
        modules["input"] = InputComponent()
        # table retriever
        retriever = create_table_obj_retriever(
            index_path=self.table_obj_index_path,
            table_infos=self.table_infos,
            sql_db=self.sql_db,
            schema_table_mapping=self.schema_table_mapping,
        )
        modules["table_retriever"] = TableRetrieveComponent(
            retriever=retriever,
            sql_database=self.sql_db,
        )
        # text2sql prompt, pre-filled with the engine's SQL dialect name
        text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
            dialect=self.engine.dialect.name
        )
        modules["text2sql_prompt"] = text2sql_prompt
        # text2sql llm
        modules["text2sql_llm"] = self.llm
        # sql output parser
        modules["sql_output_parser"] = FnComponent(fn=parse_response_to_sql)
        # sql retriever (project-specific replacement for SQLRetriever)
        modules["sql_retriever"] = CustomSQLRetriever(sql_db=self.sql_db)
        # response synthesis prompt
        response_synthesis_prompt_str = (
            "Given an input question, synthesize a response from the query results.\n"
            "Query: {query_str}\n"
            "SQL: {sql_query}\n"
            "SQL Response: {context_str}\n"
            "Response: "
        )
        response_synthesis_prompt = PromptTemplate(
            response_synthesis_prompt_str,
        )
        modules["response_synthesis_prompt"] = response_synthesis_prompt
        # response synthesis llm
        modules["response_synthesis_llm"] = self.llm
        return modules

    def build_pipeline(self):
        """Wire ``self.modules`` together into a ``QueryPipeline``."""
        qp = QueryPipeline(
            modules=self.modules,
            verbose=True,
        )
        # add chains & links
        qp.add_link("input", "table_retriever", dest_key="query")
        qp.add_link("input", "text2sql_prompt", dest_key="query_str")
        qp.add_link("table_retriever", "text2sql_prompt", dest_key="schema")
        qp.add_chain(
            ["text2sql_prompt", "text2sql_llm", "sql_output_parser"]
        )
        qp.add_link(
            "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
        )
        qp.add_link("input", "sql_retriever", dest_key="query_str")
        qp.add_link("sql_output_parser", "sql_retriever", dest_key="sql_query")
        # The custom sql_retriever component is expected to emit an "is_valid"
        # field: True when executing the SQL returned a proper result. That
        # flag is the link condition for sql_retriever -> response_synthesis_prompt.
        # If is_valid is False, control goes back to text2sql_prompt so the
        # SQL is regenerated.
        # NOTE(review): the back-link makes the graph cyclic; confirm that the
        # installed QueryPipeline version supports cycles / condition_fn links.
        qp.add_link(
            "sql_retriever", "response_synthesis_prompt", dest_key="context_str", condition_fn=lambda x: x["is_valid"]
        )
        qp.add_link("sql_retriever", "text2sql_prompt", src_key="query_str", dest_key="query_str", condition_fn=lambda x: not x["is_valid"])
        qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
        qp.add_link("response_synthesis_prompt", "response_synthesis_llm")
        return qp

    def get_vision(self):
        """Render the pipeline DAG with pyvis and display it as inline HTML.

        Side effect: writes ``text2sql_dag.html`` to the current directory.
        """
        net = Network(notebook=True, cdn_resources="in_line", directed=True)
        net.from_nx(self.pipeline.dag)
        net.write_html("text2sql_dag.html")
        with open("text2sql_dag.html", "r") as file:
            html_content = file.read()
        # Display the HTML content
        display(HTML(html_content))

    def run(self, query: str):
        """Run the pipeline on ``query`` and return the response as a string."""
        response = self.pipeline.run(query=query)
        return str(response)
class TableInfo(BaseModel):
    """Pydantic record pairing a table's name with a short prose summary."""

    # Both fields are required (default=... is pydantic's "no default").
    table_name: str = Field(
        default=...,
        description="table name (must be underscores and NO spaces)",
    )
    table_summary: str = Field(
        default=...,
        description="short, concise summary/caption of the table",
    )
class TableRetrieveComponent(CustomQueryComponent):
    """Query-pipeline component mapping a question to a table description.

    Input key ``query``: the user's question.
    Output key ``output``: the ``TABLE_SUMMARY`` text of the first table the
    retriever returns for that question.
    """

    retriever: ObjectRetriever = Field(..., description="Retriever to use for table info")
    # NOTE(review): not read by this component; kept for interface compatibility.
    sql_database: SQLDatabase = Field(..., description="SQL engine to use for table info")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component (pass-through)."""
        return input

    @property
    def _input_keys(self) -> set:
        """Set of input keys accepted by this component."""
        return {"query"}

    @property
    def _output_keys(self) -> set:
        """Set of output keys produced by this component."""
        return {"output"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Retrieve the first matching table and return its summary.

        Raises:
            ValueError: if the retriever returns no table for the query.
        """
        query = kwargs["query"]
        retrieved = self.retriever.retrieve(query)
        if not retrieved:
            # Previously surfaced as a bare IndexError from ``[0]``.
            raise ValueError(f"No table could be retrieved for query: {query!r}")
        # Only the first retrieved table (presumably the best match) is used.
        table_name = retrieved[0].table_name
        return {"output": TABLE_SUMMARY[table_name]}
if __name__ == '__main__':
    import sys

    # Manual smoke run: the question may be given as command-line arguments;
    # without arguments the original sample question is used.
    default_query = "I want 5 different big tits milf porn with it's title and web url"
    query = " ".join(sys.argv[1:]) or default_query
    sql_pipeline = SQLPipeline(
        llm=Ollama(model="mannix/llama3.1-8b-abliterated", request_timeout=120)
    )
    response = sql_pipeline.run(query)
    print(response)