|
from llama_index.core.objects import ( |
|
SQLTableNodeMapping, |
|
ObjectIndex, |
|
SQLTableSchema, |
|
) |
|
from llama_index.core import SQLDatabase, VectorStoreIndex |
|
from llama_index.core.llms import ChatResponse |
|
from llama_index.core.storage.chat_store import SimpleChatStore |
|
from serpapi import GoogleSearch |
|
from pydantic import BaseModel, Field |
|
from typing import Dict, Any, List, Tuple |
|
from bs4 import BeautifulSoup |
|
import os, requests, re, json |
|
import pymysql |
|
import hashlib |
|
from dotenv import load_dotenv |
|
|
|
# Read configuration from a local .env file before any os.getenv lookups.
load_dotenv()

SERPAPI_KEY = os.getenv("SERPAPI_KEY")  # API key for the SerpAPI search helpers below
CHAT_STORE_PATH = os.getenv("CHAT_STORE_PATH")  # directory where SimpleChatStore JSON files persist

# Natural-language schema descriptions used as retrieval context for the
# SQL table retrievers. Keys must match SQLDatabase.get_usable_table_names().
# NOTE(review): the t_sur_models_info summary repeats "update_time (DATETIME)"
# twice and has unbalanced parentheses — likely a paste artifact; confirm
# before editing, since these strings feed LLM prompts verbatim.
TABLE_SUMMARY = {
    "t_sur_media_sync_es": "This table is about Porn video information:\n\nt_sur_media_sync_es: Columns:id (integer), web_url (string), duration (integer), pattern_per (integer), like_count (integer), dislike_count (integer), view_count (integer), cover_picture (string), title (string), upload_date (datetime), uploader (string), create_time (datetime), update_time (datetime), categories (list of strings), abbreviate_video_url (string), abbreviate_mp4_video_url (string), resource_type (integer), like_count_show (string), stat_version (string), tags (list of strings), model_name (string), publisher_type (string), period (string), sexual_preference (string), country (string), type (string), rank_number (integer), rank_rate (string), has_pattern (boolean), trace (string), manifest_url (string), is_delete (boolean), web_url_md5 (string), view_key (string)",
    "t_sur_models_info": "This table is about Stripchat models' information:\n\nt_sur_models_info: Columns:id (INTEGER), username (VARCHAR(100), image (VARCHAR(500), num_users (INTEGER), pf (VARCHAR(50), pf_model_unite (VARCHAR(50), use_plugin (INTEGER), create_time (DATETIME), update_time (DATETIME), update_time (DATETIME), gender (VARCHAR(50), broadcast_type (VARCHAR(50), common_gender (VARCHAR(50), avatar (VARCHAR(512), age (INTEGER) ",
}
|
|
|
|
|
class TableInfo(BaseModel):
    """Information regarding a structured table."""

    # Exact SQL table identifier — used verbatim in queries, so no spaces.
    table_name: str = Field(
        ..., description="table name (must be underscores and NO spaces)"
    )
    # Human-readable caption stored as retrieval context for the table index.
    table_summary: str = Field(
        ..., description="short, concise summary/caption of the table"
    )
|
|
|
|
|
class SQLResult(BaseModel):
    """Structured form of a SQL query result: column names plus row dicts."""

    # Ordered column names of the result set.
    cols: List[str] = Field(..., description="The columns within the sql result")
    # One dict per row, keyed by column name; values are stringified.
    results: List[Dict[str, Any]] = Field(
        ..., description="The results of the sql query"
    )
|
|
|
|
|
class ProcessStatus(BaseModel):
    """Mutable progress marker serialized as {"processing": {...}} JSON."""

    type: str = Field(..., description="The type of process")
    status: str = Field(..., description="The status of the process")

    def to_json(self):
        """Serialize this status to the wire format consumed by clients."""
        return json.dumps({"processing": {"type": self.type, "status": self.status}})

    def update(self, status: str):
        """Replace the current status string in place."""
        self.status = status
|
|
|
class MySQLChatStore:
    """Chat-history store backed by 32 hash-sharded MySQL tables.

    A user's messages always live in ``t_sur_ai_chat_history_{myhash(user_id) % 32}``.
    NOTE(review): holds a single long-lived pymysql connection; pymysql
    connections are not thread-safe — confirm single-threaded usage.
    """

    def __init__(self, host, port, user, password, database):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.config = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
        }
        self.connection = pymysql.connect(**self.config)

    def _table_name(self, user_id):
        """Shard table for user_id. Hash-derived integer, safe to interpolate."""
        return f"t_sur_ai_chat_history_{myhash(user_id) % 32}"

    def get_chat_history(self, user_id):
        """Return the user's last 4 messages, oldest first, newline-joined.

        Security fix: user_id is now bound as a query parameter instead of
        being f-string-interpolated into the SQL (SQL-injection risk in the
        original). The table name is still formatted in, but it derives only
        from an integer hash, never from raw user input.
        """
        query = (
            f"SELECT user_role, content FROM {self._table_name(user_id)} "
            "WHERE user_id = %s ORDER BY create_time DESC LIMIT 4;"
        )
        chat_history = []
        with self.connection.cursor() as cursor:
            cursor.execute(query, (user_id,))
            result = cursor.fetchall()
            # Rows arrive newest-first; reverse to chronological order.
            for row in reversed(result):
                chat_history.append(f"'{row[0]}': {row[1]}")
        return "\n".join(chat_history)

    def add_message(self, user_id, role, content):
        """Append one message row for the user with a server-side timestamp."""
        query = (
            f"INSERT INTO {self._table_name(user_id)} "
            "(user_id, user_role, content, create_time) VALUES (%s, %s, %s, NOW());"
        )
        with self.connection.cursor() as cursor:
            cursor.execute(query, (user_id, role, content))
            self.connection.commit()

    def del_message(self, user_id, content):
        """Delete every row for the user whose content matches exactly."""
        query = (
            f"DELETE FROM {self._table_name(user_id)} "
            "WHERE user_id = %s AND content = %s;"
        )
        with self.connection.cursor() as cursor:
            cursor.execute(query, (user_id, content))
            self.connection.commit()
|
|
|
class ToyStatusStore:
    """Latest-toy-status store backed by 8 hash-sharded MySQL tables.

    Status rows live in ``t_sur_ai_toy_status_{myhash(user_id) % 8}``.
    NOTE(review): single long-lived pymysql connection — not thread-safe.
    """

    def __init__(self, host, port, user, password, database):
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.config = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "database": self.database,
        }
        self.connection = pymysql.connect(**self.config)

    def _table_name(self, user_id):
        """Shard table for user_id. Hash-derived integer, safe to interpolate."""
        return f"t_sur_ai_toy_status_{myhash(user_id) % 8}"

    def get_latest(self, user_id):
        """Return {"pattern": <decoded JSON>, "toy_name": str} for the newest row.

        Falls back to an empty pattern/name when the user has no rows.
        Security fix: user_id is now bound as a query parameter instead of
        being f-string-interpolated into the SQL (SQL-injection risk in the
        original).
        """
        query = (
            f"SELECT pattern, toy_name FROM {self._table_name(user_id)} "
            "WHERE user_id = %s ORDER BY create_time DESC LIMIT 1;"
        )
        with self.connection.cursor() as cursor:
            cursor.execute(query, (user_id,))
            if cursor.rowcount > 0:
                pattern, toy_name = cursor.fetchall()[0]
            else:
                pattern, toy_name = "[]", ""
            # pattern is stored as a JSON-encoded string.
            return {"pattern": json.loads(pattern), "toy_name": toy_name}

    def update(self, user_id, pattern, toy_name):
        """Insert a new status row with a server-side timestamp."""
        query = (
            f"INSERT INTO {self._table_name(user_id)} "
            "(user_id, pattern, toy_name, create_time) VALUES (%s, %s, %s, NOW());"
        )
        with self.connection.cursor() as cursor:
            cursor.execute(query, (user_id, pattern, toy_name))
            self.connection.commit()
|
|
|
class ExtraStatus(BaseModel):
    """Auxiliary per-request flags serialized under an "extraResults" envelope."""

    adultMode: int = Field(..., description="The adult mode status")
    intentionResult: list | None
    sensitiveResult: list | None
    questionIsSex: str | None

    def to_json(self):
        """Serialize to JSON; adultMode is emitted as the string "1" or "0"."""
        payload = {
            "extraResults": {
                "adultMode": "1" if self.adultMode else "0",
                "intentionResult": self.intentionResult,
                "sensitiveResult": self.sensitiveResult,
                "questionIsSex": self.questionIsSex,
            }
        }
        return json.dumps(payload)
|
|
|
|
|
def myhash(string):
    """Deterministic SHA-256 hash of *string* as a non-negative integer.

    Used for shard selection, so it must be stable across processes
    (unlike the builtin hash()).
    """
    hex_digest = hashlib.sha256(string.encode("utf-8")).hexdigest()
    return int(hex_digest, 16)
|
|
|
def create_table_retriever(sql_db: SQLDatabase):
    """
    Create a table retriever that can retrieve table information from the SQL database.

    Builds a fresh in-memory ObjectIndex over every usable table, attaching the
    TABLE_SUMMARY description as retrieval context, and returns a top-1 retriever.
    """
    table_infos = [
        TableInfo(table_name=name, table_summary=TABLE_SUMMARY[name])
        for name in sql_db.get_usable_table_names()
    ]

    node_mapping = SQLTableNodeMapping(sql_db)
    table_schema_objs = [
        SQLTableSchema(table_name=info.table_name, context_str=info.table_summary)
        for info in table_infos
    ]
    obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        object_mapping=node_mapping,
        index_cls=VectorStoreIndex,
    )
    return obj_index.as_retriever(similarity_top_k=1)
|
|
|
|
|
def get_table_retriever(
    sql_db: SQLDatabase,
    persist_dir: str = "/home/purui/projects/chatbot/kb/sql/table_obj_index",
):
    """Load a previously persisted table ObjectIndex and return a top-1 retriever.

    Improvements over the original:
    - ``persist_dir`` is now a parameter (defaulting to the original
      hard-coded path), so the function works outside that one machine;
    - removed a dead loop that built ``TableInfo`` objects which
      ``ObjectIndex.from_persist_dir`` never consumed.

    Args:
        sql_db: the live SQLDatabase the node mapping resolves tables against.
        persist_dir: directory previously written by ObjectIndex.persist().
    """
    node_mapping = SQLTableNodeMapping(sql_db)
    obj_index = ObjectIndex.from_persist_dir(
        persist_dir=persist_dir,
        object_node_mapping=node_mapping,
    )
    return obj_index.as_retriever(similarity_top_k=1)
|
|
|
|
|
def get_table_context_str(
    table_schema_objs: List[SQLTableSchema], sql_database: SQLDatabase
):
    """Get table context string.

    For each schema object, combines the database's own table info with the
    stored description (when present); blocks are joined by blank lines.
    """
    context_strs = []
    for schema_obj in table_schema_objs:
        info = sql_database.get_single_table_info(schema_obj.table_name)
        if schema_obj.context_str:
            info = f"{info} The table description is: {schema_obj.context_str}"
        context_strs.append(info)
    return "\n\n".join(context_strs)
|
|
|
|
|
def parse_response_to_sql(response: ChatResponse) -> str:
    """Parse response to SQL.

    Extracts the text between the "SQLQuery:" marker and the optional
    "SQLResult:" marker, stripping whitespace and markdown code fences.
    """
    text = response.message.content
    marker = "SQLQuery:"
    start = text.find(marker)
    if start != -1:
        text = text[start:]
    if text.startswith(marker):
        text = text[len(marker):]
    end = text.find("SQLResult:")
    if end != -1:
        text = text[:end]
    return text.strip().strip("```").strip()
|
|
|
|
|
def parse_web_search_content(content: List[Dict[str, Any]]):
    """Render web search results as markdown links, one line per result.

    Results missing either key contribute an empty line, preserving the
    original one-line-per-result shape.

    Bug fix: the original condition `if "title" and "link" in keys` only
    tested for "link" ("title" is a truthy literal), so a result with a
    "link" but no "title" raised KeyError. Both keys are now checked.
    Also stopped shadowing the `content` parameter inside the loop.
    """
    lines = []
    for res in content:
        if "title" in res and "link" in res:
            lines.append(f"-[{res['title']}]({res['link']})")
        else:
            lines.append("")
    return "\n".join(lines)
|
|
|
|
|
def parse_video_content(content: List[Dict[str, Any]]):
    """Render video search results as a markdown list under a "Videos:" header.

    A result missing "title" or "link" contributes an empty line.
    """
    lines = ["Videos:"]
    for item in content:
        try:
            lines.append(f"- [{item['title']}]({item['link']})")
        except Exception:
            # Malformed result: keep a placeholder line, matching list length.
            lines.append("")
    return "\n".join(lines)
|
|
|
|
|
def parse_image_content(content: List[Dict[str, Any]]):
    """Render image search results as a markdown list under an "Images:" header.

    Links point at the "original" (full-size) image URL; a result missing
    "title" or "original" contributes an empty line.
    """
    lines = ["Images:"]
    for item in content:
        try:
            lines.append(f"- [{item['title']}]({item['original']})")
        except Exception:
            # Malformed result: keep a placeholder line, matching list length.
            lines.append("")
    return "\n".join(lines)
|
|
|
|
|
def pares_sql_result(
    sql_result: List[Tuple[str, ...]], sql_query: str, col_keys: List[str]
):
    """Convert raw SQL rows into (SQLResult, tab-separated string) forms.

    Duplicate rows (same column/value pairs) are emitted only once.

    Bug fix: the original built a `seen` set and tested membership but never
    added anything to it, so duplicate rows were never actually skipped.

    Args:
        sql_result: raw fetched rows, positionally aligned with col_keys.
        sql_query: the originating query (currently unused; kept for
            interface compatibility with callers).
        col_keys: ordered column names.

    Returns:
        (SQLResult model, "col1\\tcol2...\\nval1\\tval2..." display string).
    """
    result_list = []
    seen = set()
    for row in sql_result:
        row_dict = {str(col): str(row[idx]) for idx, col in enumerate(col_keys)}
        dedup_key = tuple(sorted(row_dict.items()))
        if dedup_key in seen:
            continue
        seen.add(dedup_key)
        result_list.append(row_dict)

    str_rows = ["\t".join(str(col) for col in col_keys)]
    for row_dict in result_list:
        str_rows.append("\t".join(row_dict.values()))
    str_result = "\n".join(str_rows)

    result = SQLResult(cols=col_keys, results=result_list)
    return result, str_result
|
|
|
|
|
def load_chat_store(chat_store_name: str):
    """Get user's chat history by sessionId.

    Loads the persisted SimpleChatStore for this name when one exists;
    otherwise creates an empty store and persists it immediately so the
    file exists for the next call.
    """
    persist_path = f"{CHAT_STORE_PATH}/{chat_store_name}.json"
    if not os.path.exists(persist_path):
        store = SimpleChatStore()
        store.persist(persist_path=persist_path)
        return store
    return SimpleChatStore.from_persist_path(persist_path)
|
|
|
|
|
def video_search(q: str, mode: str):
    """Search Google Videos through SerpAPI.

    Args:
        q: search query.
        mode: SafeSearch setting ("active" / "off").

    Returns:
        The list of video result dicts, or False when there are none.

    Fixes over the original: removed the bare ``except:`` (which swallowed
    even KeyboardInterrupt) and the implicit ``None`` return when
    "video_results" was present but empty — callers testing truthiness now
    consistently get False in every no-result case.
    """
    params = {
        "engine": "google_videos",
        "q": q,
        "google_domain": "google.com",
        "gl": "us",
        "hl": "en",
        "safe": mode,
        "num": 5,
        "api_key": SERPAPI_KEY,
    }
    result = GoogleSearch(params).get_dict()
    videos = result.get("video_results")
    return videos if videos else False
|
|
|
|
|
def image_search(q: str, mode: str):
    """Search Google Images through SerpAPI.

    Args:
        q: search query.
        mode: SafeSearch setting ("active" / "off").

    Returns:
        Up to 20 image result dicts, or False when there are none.

    Fixes over the original: removed the bare ``except:`` and the implicit
    ``None`` return when "images_results" was present but empty — callers
    testing truthiness now consistently get False in every no-result case.
    """
    params = {
        "engine": "google_images",
        "q": q,
        "google_domain": "google.com",
        "gl": "us",
        "hl": "en",
        "safe": mode,
        "num": 20,
        "api_key": SERPAPI_KEY,
    }
    result = GoogleSearch(params).get_dict()
    images = result.get("images_results")
    return images[:20] if images else False
|
|
|
|
|
def general_search(q: str, mode: str):
    """Run a general Google (google_light) search through SerpAPI.

    Args:
        q: search query.
        mode: SafeSearch setting ("active" / "off").

    Returns:
        The list of organic result dicts, or False when there are none.

    Fixes over the original: removed the bare ``except:`` and the implicit
    ``None`` return when "organic_results" was present but empty — callers
    testing truthiness now consistently get False in every no-result case.
    """
    params = {
        "engine": "google_light",
        "q": q,
        "google_domain": "google.com",
        "gl": "us",
        "hl": "en",
        "safe": mode,
        "num": 5,
        "api_key": SERPAPI_KEY,
    }
    result = GoogleSearch(params).get_dict()
    organic = result.get("organic_results")
    return organic if organic else False
|
|
|
|
|
def web_reader(url: str):
    """Fetch a page and return its visible text with collapsed newlines.

    Returns " " for a page whose extracted text is empty, and False when the
    request fails for any reason (timeout, HTTP error, parse error).
    """
    try:
        print(f"parsing {url}...")
        headers = {
            "User-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
        }
        response = requests.get(url=url, headers=headers, timeout=2.0)
        response.raise_for_status()
        text = BeautifulSoup(response.content, "lxml").get_text()
        # Collapse runs of blank lines so downstream prompts stay compact.
        cleaned_text = re.sub(r"\n+", "\n", text)
        return cleaned_text if cleaned_text else " "
    except Exception as e:
        print(e)
        return False
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test for the SerpAPI wrapper.
    result = general_search("Chanell Heart", "off")
    print(result)

    # Sample result shapes (previously pasted into the file as bare dict
    # literals — dead no-op expressions that evaluated and were discarded
    # on every run; kept here as reference documentation instead):
    #
    # general_search organic result:
    #   {
    #       "position": 1,
    #       "title": "Victoria SnakeySmut | Fansly",
    #       "link": "https://fansly.com/SnakeySmut",
    #       "displayed_link": "fansly.com/SnakeySmut",
    #       "snippet": "SnakeySmut conjures audio roleplays. ...",
    #   }
    #
    # image_search result:
    #   {
    #       "position": 1,
    #       "thumbnail": "https://cdn.lovense-api.com/UploadFiles/surfease/x3/chanell-heart.png",
    #       "related_content_id": "WkNzSFNndkhqVlBrOU1cIixcIk16bG1veURtUndJemZN",
    #       "serpapi_related_content_link": "https://cdn.lovense-api.com/UploadFiles/surfease/x3/chanell-heart.png",
    #       "source": "http://www.vibemate.com",
    #       "source_logo": "",
    #       "title": "Chanell Heart",
    #       "link": "https://cdn.lovense-api.com/UploadFiles/surfease/x3/chanell-heart.png",
    #       "original": "https://cdn.lovense-api.com/UploadFiles/surfease/x3/chanell-heart.png",
    #       "original_width": 2160,
    #       "original_height": 2700,
    #       "is_product": False,
    #   }
|
|
|
|
|
def prRed(skk):
    """Print *skk* in red via ANSI escape codes."""
    print(f"\033[91m{skk}\033[00m")


def prGreen(skk):
    """Print *skk* in green via ANSI escape codes."""
    print(f"\033[92m{skk}\033[00m")


def prYellow(skk):
    """Print *skk* in yellow via ANSI escape codes."""
    print(f"\033[93m{skk}\033[00m")
|
|