|
import psycopg2 |
|
from sentence_transformers import SentenceTransformer |
|
|
|
class ProductDatabase:
    """Store and query sentence embeddings for rows of the ``diamondprice`` table.

    The class wraps a single psycopg2 connection and a SentenceTransformer
    model.  Embeddings are written to, and searched in, the ``vector_col``
    column of ``diamondprice`` (a pgvector ``vector(384)`` column created by
    :meth:`setup_vector_extension_and_column`).

    Call :meth:`connect` before any method that touches the database and
    :meth:`close` when finished.
    """

    def __init__(self, database_url: str) -> None:
        """Remember the connection string and load the embedding model.

        Args:
            database_url: Connection string handed to ``psycopg2.connect``.
        """
        self.database_url = database_url
        # No connection is opened until connect() is called.
        self.conn = None
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def connect(self) -> None:
        """Open the database connection and keep it on ``self.conn``."""
        self.conn = psycopg2.connect(self.database_url)

    def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly."""
        if self.conn is not None:
            self.conn.close()
            # Drop the reference so later calls do not reuse a closed handle.
            self.conn = None

    def setup_vector_extension_and_column(self) -> None:
        """Enable pgvector and make sure ``diamondprice.vector_col`` exists.

        Both statements use ``IF NOT EXISTS`` so the method can be re-run.
        """
        with self.conn.cursor() as cursor:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            # The column must live on the table that insert_vector() and the
            # search methods use (diamondprice), otherwise those queries fail.
            cursor.execute(
                "ALTER TABLE diamondprice ADD COLUMN IF NOT EXISTS vector_col vector(384);"
            )
        self.conn.commit()

    def get_embedding(self, text):
        """Return the model's embedding for *text*.

        The value is whatever ``self.model.encode`` returns; callers in this
        class convert it with ``.tolist()`` before sending it to the database.
        """
        return self.model.encode(text)

    def insert_vector(self, product_id, text) -> None:
        """Embed *text* and store it in ``vector_col`` of one row.

        Args:
            product_id: Value matched against ``diamondprice.id``.
            text: Text to embed.
        """
        vector = self.get_embedding(text).tolist()
        with self.conn.cursor() as cursor:
            # Explicit cast mirrors the search query and does not rely on an
            # implicit array-to-vector assignment cast.
            cursor.execute(
                "UPDATE diamondprice SET vector_col = %s::vector WHERE id = %s",
                (vector, product_id),
            )
        self.conn.commit()

    def search_similar_vectors(self, query_text, top_k=5):
        """Return the rows whose stored vector is closest to *query_text*.

        Args:
            query_text: Text to embed and compare against ``vector_col``.
            top_k: Maximum number of rows to return.

        Returns:
            A list of ``(id, distance)`` tuples ordered by ascending distance,
            where distance comes from pgvector's ``<=>`` operator.
        """
        query_vector = self.get_embedding(query_text).tolist()
        with self.conn.cursor() as cursor:
            cursor.execute(
                """
                SELECT id, vector_col <=> %s::vector AS distance
                FROM diamondprice
                ORDER BY distance
                LIMIT %s;
                """,
                (query_vector, top_k),
            )
            return cursor.fetchall()

    def search_similar_all(self, query_text, top_k=5):
        """Fetch the descriptive columns of the ``diamondprice`` row with id 1.

        Args:
            query_text: Accepted for interface compatibility; not used.
            top_k: Accepted for interface compatibility; not used.

        Returns:
            A list of ``(id, carat, cut, color, clarity, depth, table, x, y, z)``
            tuples.
        """
        # NOTE(review): the lookup is pinned to id = 1 and ignores both
        # arguments, exactly as before -- confirm whether a broader query
        # (all rows, or a similarity search) was intended.
        with self.conn.cursor() as cursor:
            # Column names are unquoted identifiers: single quotes would make
            # them string literals.  "table" is a reserved word, so it is
            # double-quoted.  No parameters are passed because the statement
            # has no placeholders.
            cursor.execute(
                """
                SELECT id, carat, cut, color, clarity, depth,
                       diamondprice."table", x, y, z
                FROM diamondprice
                WHERE id = 1
                """
            )
            return cursor.fetchall()
|
|
|
def main():
    """Embed the fetched ``diamondprice`` rows, then run a sample similarity search.

    Steps:
      1. Read the connection string from the ``DATABASE_URL`` environment
         variable and connect.
      2. Ensure the pgvector extension and the vector column exist.
      3. Fetch rows via ``search_similar_all``, build a text from columns
         1-6 of each row, and store its embedding on that row.
      4. Run ``search_similar_vectors`` for a sample query and print the hits.

    Raises:
        SystemExit: If ``DATABASE_URL`` is not set.
    """
    # Local import keeps this block self-contained.
    import os

    # Credentials must not be committed to source; take them from the
    # environment instead of a hard-coded connection string.
    database_url = os.environ.get("DATABASE_URL")
    if not database_url:
        raise SystemExit(
            "Set the DATABASE_URL environment variable to a PostgreSQL connection string."
        )

    db = ProductDatabase(database_url)
    db.connect()

    try:
        db.setup_vector_extension_and_column()
        print("Vector extension installed and column added successfully.")

        query_text = "1"
        results = db.search_similar_all(query_text)
        print("Search results:")
        for result in results:
            print(result)
            # Column 0 is the row id; columns 1-6 are concatenated into the
            # text that gets embedded for that row.
            product_id = result[0]
            sample_text = "".join(str(value) for value in result[1:7])
            db.insert_vector(product_id, sample_text)

        query_text = "12.03"
        results = db.search_similar_vectors(query_text)
        print("Search results:")
        for result in results:
            print(result)
    finally:
        # Always release the connection, even if a step above failed.
        db.close()
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|