# File size: 9,724 Bytes
# 318db6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
import warnings
from typing import Dict, List, Any

import pymysql, pandas as pd
import Stemmer
from IPython.display import display, HTML
from llama_index.core import SQLDatabase, VectorStoreIndex, PromptTemplate
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.objects import (
    SQLTableNodeMapping,
    ObjectIndex,
    SQLTableSchema,
)
from llama_index.core.objects.base import ObjectRetriever
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.prompts.default_prompts import DEFAULT_TEXT_TO_SQL_PROMPT
from llama_index.core.query_pipeline import (
    QueryPipeline,
    Link,
    InputComponent,
    CustomQueryComponent,
)
from llama_index.core.query_pipeline import FnComponent
from llama_index.core.retrievers import SQLRetriever
from llama_index.core.schema import IndexNode
from llama_index.llms.ollama import Ollama
from llama_index.retrievers.bm25 import BM25Retriever
from pyvis.network import Network
from sqlalchemy import create_engine
from sqlalchemy.engine import URL

from modules import (
    get_table_obj_retriever,
    create_table_obj_retriever,
    get_table_context_str,
    parse_response_to_sql,
    CustomSQLRetriever,
)



PROMPT_STR = """\
    Give me a summary of the table with the following format.
    
    - table_summary: Describe what the table is about in short. Columns: [col1(type), col2(type), ...]
    
    Table:
    {table_str}
     """

# Database connection settings.
# Values are read from the environment so credentials are not committed to
# source control. The non-secret settings keep their previous values as
# defaults. The password has no default: the one that used to be hard-coded
# here should be treated as compromised and rotated.
db_user = os.environ.get("DB_USER", "shenzhen_ai_for_vibemate_eson")
db_password = os.environ.get("DB_PASSWORD", "")
db_host = os.environ.get("DB_HOST", "192.168.1.99")
db_port = int(os.environ.get("DB_PORT", "3306"))
db_name = os.environ.get("DB_NAME", "hytto_surfease")

# Hand-written table descriptions, keyed by real table name.
# They are used as TableInfo summaries and as the schema text given to the
# text-to-SQL prompt, so each column list must be well-formed.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": "This table is about Porn video information:\n\nt_sur_media_sync_es: Columns:id (integer), web_url (string), duration (integer), pattern_per (integer), like_count (integer), dislike_count (integer), view_count (integer), cover_picture (string), title (string), upload_date (datetime), uploader (string), create_time (datetime), update_time (datetime), categories (list of strings), abbreviate_video_url (string), abbreviate_mp4_video_url (string), resource_type (string), like_count_show (integer), stat_version (integer), tags (list of strings), model_name (string), publisher_type (string), period (integer), sexual_preference (string), country (string), type (string), rank_number (integer), rank_rate (float), has_pattern (boolean), trace (string), manifest_url (string), is_delete (boolean), web_url_md5 (string), view_key (string)",
    # Fixed: duplicated update_time column, unbalanced "(VARCHAR(n)" parentheses
    # and a trailing space that were previously passed verbatim to the LLM.
    "t_sur_models_info": (
        "This table is about Stripchat models' information:\n\n"
        "t_sur_models_info: Columns:id (INTEGER), username (VARCHAR(100)), "
        "image (VARCHAR(500)), num_users (INTEGER), pf (VARCHAR(50)), "
        "pf_model_unite (VARCHAR(50)), use_plugin (INTEGER), "
        "create_time (DATETIME), update_time (DATETIME), gender (VARCHAR(50)), "
        "broadcast_type (VARCHAR(50)), common_gender (VARCHAR(50)), "
        "avatar (VARCHAR(512)), age (INTEGER)"
    ),
}

class SQLPipeline:
    """Text-to-SQL question answering over the configured MySQL database.

    Construction does four things, in order:

    1. Connects to the database.
    2. Registers the usable tables that have an entry in ``TABLE_SUMMARY``.
    3. Builds the pipeline modules.
    4. Links the modules into a llama_index ``QueryPipeline`` (``self.pipeline``).
    """

    # Default location of the persisted table-object index that is passed to
    # ``create_table_obj_retriever`` as ``index_path``.
    DEFAULT_TABLE_OBJ_INDEX_PATH = "/home/purui/projects/chatbot/kb/sql/table_obj_index"

    def __init__(self, llm: Ollama, table_obj_index_path: str = DEFAULT_TABLE_OBJ_INDEX_PATH):
        """Create the engine, table metadata, modules and pipeline.

        Args:
            llm: LLM used both to generate SQL and to synthesise the final answer.
            table_obj_index_path: Index location handed to
                ``create_table_obj_retriever``. It defaults to the previously
                hard-coded path.
        """
        self.llm = llm
        self.table_obj_index_path = table_obj_index_path
        # URL.create escapes the components itself. Characters such as "@",
        # ":" or "/" in the user name or password therefore cannot corrupt
        # the connection URL, as they could with plain f-string interpolation.
        self.engine = create_engine(
            URL.create(
                drivername="mysql+pymysql",
                username=db_user,
                password=db_password,
                host=db_host,
                port=db_port,
                database=db_name,
            )
        )
        self.sql_db = SQLDatabase(self.engine)
        self.table_names = self.sql_db.get_usable_table_names()
        self.schema_table_mapping = {}
        self.init_schema_table_mapping()
        self.modules = self.prepare_modules()
        self.pipeline = self.build_pipeline()

    def init_schema_table_mapping(self):
        """Populate ``self.table_infos`` and ``self.schema_table_mapping``.

        One ``TableInfo`` is created for each usable table that has an entry in
        ``TABLE_SUMMARY``. Tables without a summary are skipped, and a single
        warning lists them; previously construction aborted with ``KeyError``.
        """
        self.table_infos = []
        missing = []
        for table in self.table_names:
            summary = TABLE_SUMMARY.get(table)
            if summary is None:
                missing.append(table)
                continue
            table_info = TableInfo(table_name=table, table_summary=summary)
            self.table_infos.append(table_info)
            # Summary table name -> real table name (currently identical).
            self.schema_table_mapping[table_info.table_name] = table
        if missing:
            warnings.warn(
                f"No TABLE_SUMMARY entry for tables {missing}; they are not "
                "passed to the table retriever."
            )

    def prepare_modules(self):
        """Build and return the dict of named components used by the pipeline."""
        modules = {}
        # input
        modules["input"] = InputComponent()
        # table retriever
        retriever = create_table_obj_retriever(
            index_path=self.table_obj_index_path,
            table_infos=self.table_infos,
            sql_db=self.sql_db,
            schema_table_mapping=self.schema_table_mapping,
        )
        modules["table_retriever"] = TableRetrieveComponent(
            retriever=retriever,
            sql_database=self.sql_db,
        )
        # text2sql prompt, with the SQL dialect pre-filled from the engine
        text2sql_prompt = DEFAULT_TEXT_TO_SQL_PROMPT.partial_format(
            dialect=self.engine.dialect.name
        )
        modules["text2sql_prompt"] = text2sql_prompt
        # text2sql llm
        modules["text2sql_llm"] = self.llm
        # sql output parser
        modules["sql_output_parser"] = FnComponent(fn=parse_response_to_sql)
        # sql retriever (custom variant replacing llama_index's SQLRetriever)
        modules["sql_retriever"] = CustomSQLRetriever(sql_db=self.sql_db)
        # response synthesis prompt
        response_synthesis_prompt_str = (
            "Given an input question, synthesize a response from the query results.\n"
            "Query: {query_str}\n"
            "SQL: {sql_query}\n"
            "SQL Response: {context_str}\n"
            "Response: "
        )
        response_synthesis_prompt = PromptTemplate(
            response_synthesis_prompt_str,
        )
        modules["response_synthesis_prompt"] = response_synthesis_prompt
        # response synthesis llm
        modules["response_synthesis_llm"] = self.llm

        return modules

    def build_pipeline(self):
        """Wire ``self.modules`` into a ``QueryPipeline`` and return it."""
        qp = QueryPipeline(
            modules=self.modules,
            verbose=True,
        )
        # add chains & links
        qp.add_link("input", "table_retriever", dest_key="query")
        qp.add_link("input", "text2sql_prompt", dest_key="query_str")
        qp.add_link("table_retriever", "text2sql_prompt", dest_key="schema")
        qp.add_chain(
            ["text2sql_prompt", "text2sql_llm", "sql_output_parser"]
        )
        qp.add_link(
            "sql_output_parser", "response_synthesis_prompt", dest_key="sql_query"
        )
        qp.add_link("input", "sql_retriever", dest_key="query_str")
        qp.add_link("sql_output_parser", "sql_retriever", dest_key="sql_query")
        # The custom sql_retriever component is meant to expose an "is_valid"
        # field. It is True when executing the SQL returned a correct result,
        # and gates the sql_retriever -> response_synthesis_prompt link.
        # When is_valid is False, flow returns to text2sql_prompt so the SQL
        # is regenerated.
        # NOTE(review): this back-edge makes the graph cyclic. Confirm that the
        # installed QueryPipeline version supports conditional cycles, and
        # check what value condition_fn actually receives.
        qp.add_link(
            "sql_retriever", "response_synthesis_prompt", dest_key="context_str", condition_fn=lambda x: x["is_valid"]
        )
        qp.add_link("sql_retriever", "text2sql_prompt", src_key="query_str", dest_key="query_str", condition_fn=lambda x: not x["is_valid"])
        qp.add_link("input", "response_synthesis_prompt", dest_key="query_str")
        qp.add_link("response_synthesis_prompt", "response_synthesis_llm")
        return qp

    def get_vision(self):
        """Render the pipeline DAG with pyvis.

        The graph is written to ``text2sql_dag.html`` and displayed inline.
        """
        net = Network(notebook=True, cdn_resources="in_line", directed=True)
        net.from_nx(self.pipeline.dag)
        net.write_html("text2sql_dag.html")
        with open("text2sql_dag.html", "r") as file:
            html_content = file.read()
        # Display the HTML content
        display(HTML(html_content))

    def run(self, query: str):
        """Run the pipeline on ``query`` and return the response as a string."""
        response = self.pipeline.run(query=query)
        return str(response)


class TableInfo(BaseModel):
    """Name and short summary describing one structured database table."""

    # Real table name. Both fields are required (no default value).
    table_name: str = Field(default=..., description="table name (must be underscores and NO spaces)")

    # Short caption. SQLPipeline passes these TableInfo objects to the table retriever.
    table_summary: str = Field(default=..., description="short, concise summary/caption of the table")


class TableRetrieveComponent(CustomQueryComponent):
    """Query-pipeline component that maps a natural-language query to a table description.

    The top-ranked table returned by ``retriever`` is looked up in
    ``TABLE_SUMMARY``. If that table has no hand-written summary, the schema
    reported by ``sql_database`` is used instead.
    """

    retriever: ObjectRetriever = Field(..., description="Retriever to use for table info")
    sql_database: SQLDatabase = Field(..., description="SQL engine to use for table info")

    def _validate_component_inputs(
        self, input: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate component inputs during run_component (currently a pass-through)."""
        # NOTE: this is OPTIONAL but we show you here how to do validation as an example
        return input

    @property
    def _input_keys(self) -> set:
        """Input keys: the natural-language ``query``."""
        return {"query"}

    @property
    def _output_keys(self) -> set:
        """Output keys: ``output``, the description of the selected table."""
        # can do multi-outputs too
        return {"output"}

    def _run_component(self, **kwargs) -> Dict[str, Any]:
        """Retrieve the best-matching table and return its description.

        Raises:
            IndexError: if the retriever returns no tables. This is the same
                exception type as the original unchecked ``[0]`` access, but
                with a descriptive message.
        """
        query = kwargs["query"]
        retrieved = self.retriever.retrieve(query)
        if not retrieved:
            raise IndexError(f"Table retriever returned no tables for query {query!r}")
        table_name = retrieved[0].table_name
        table_info = TABLE_SUMMARY.get(table_name)
        if table_info is None:
            # No hand-written summary: fall back to the schema text that the
            # database wrapper reports for this table.
            table_info = self.sql_database.get_single_table_info(table_name)
        return {"output": table_info}



if __name__ == '__main__':
    import sys

    # The question can be given on the command line. Without arguments, the
    # original demo question is used.
    default_query = "I want 5 different big tits milf porn with it's title and web url"
    query = " ".join(sys.argv[1:]) or default_query

    sql_pipeline = SQLPipeline(llm=Ollama(model="mannix/llama3.1-8b-abliterated",
                                          request_timeout=120))
    response = sql_pipeline.run(query)
    print(response)