|
import logging

import asyncpg
|
|
|
class NetworkDB:
    """Async data-access layer for the ``text_posts`` table.

    Connections are borrowed from a lazily created ``asyncpg`` pool
    (see ``get_pool``); call ``disconnect`` to close it. The query methods
    never raise: failures are logged and reported through the return value
    (``False`` or an ``[Internal Message: ...]`` string).
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, database_url: str):
        """Remember the connection URL; no connection is opened here.

        Args:
            database_url: DSN handed to ``asyncpg.create_pool`` on first use.
        """
        self.pool = None
        self.database_url = database_url

    async def get_pool(self):
        """Return the shared connection pool, creating it on first call.

        The pool is created with ``min_size=1`` and ``max_size=10``.
        """
        if self.pool is not None:
            return self.pool
        self.pool = await asyncpg.create_pool(
            self.database_url, min_size=1, max_size=10
        )
        return self.pool

    @staticmethod
    def _format_vector(values) -> str:
        """Render a sequence of numbers as a ``[a, b, ...]`` text literal.

        For a ``list`` of floats this is identical to ``str(values)``; unlike
        ``str`` it also yields the bracketed form for tuples or NumPy arrays.
        """
        return "[" + ", ".join(str(value) for value in values) + "]"

    async def post_text(self, content: str, embeddings: list[float]) -> bool:
        """Insert a post together with its embedding vector.

        Args:
            content: Post text.
            embeddings: Embedding, sent to the database as a text literal.

        Returns:
            True if the INSERT returned an id, False otherwise or on any error.
        """
        try:
            pool = await self.get_pool()
            # acquire() as a context manager releases the connection even if
            # the query raises (the old connect/close pair leaked on error).
            async with pool.acquire() as conn:
                post_id = await conn.fetchval(
                    "INSERT INTO text_posts (content, embedding) VALUES ($1, $2) RETURNING id",
                    content,
                    self._format_vector(embeddings),
                )
        except Exception:
            # Best-effort API: callers get a bool, but keep the traceback.
            self._logger.exception("Failed to insert text post")
            return False
        return post_id is not None

    async def get_text_post_random(self) -> str:
        """Return the content of one randomly chosen post.

        Returns:
            The post text, a "No post found" message if the table is empty,
            or a "Server Error" message if the query fails.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                post = await conn.fetchval(
                    "SELECT content from text_posts ORDER BY random() LIMIT 1"
                )
        except Exception:
            self._logger.exception("Failed to fetch a random text post")
            return "[Internal Message: Server Error]"
        return post if post is not None else "[Internal Message: No post found!]"

    async def get_text_post_similar(self, query_embedding: list[float]) -> str:
        """Return the content of the post whose embedding is nearest the query.

        Nearness is whatever the database's ``<->`` operator computes for the
        ``embedding`` column.

        Args:
            query_embedding: Vector to compare against, sent as a text literal.

        Returns:
            The closest post's text, a "No similar post found" message if no
            row matches, or a "Server Error" message if the query fails.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                post = await conn.fetchval(
                    "SELECT content FROM text_posts ORDER BY embedding <-> $1 LIMIT 1",
                    self._format_vector(query_embedding),
                )
        except Exception:
            self._logger.exception("Failed to fetch a similar text post")
            return "[Internal Message: Server Error]"
        if post is None:
            return "[Internal Message: No similar post found!]"
        return post

    async def disconnect(self) -> None:
        """Close the pool, if one was created, and forget it.

        ``Pool.close()`` is a coroutine and must be awaited; clearing
        ``self.pool`` first lets a later ``get_pool`` build a fresh pool.
        """
        if self.pool is not None:
            pool, self.pool = self.pool, None
            await pool.close()
|
|