|
import asyncio

import asyncpg
|
|
|
class NetworkDB:
    """Async data-access layer for text posts and their comments.

    Every query runs on a connection borrowed from one shared asyncpg pool,
    created lazily by ``get_pool`` and released by ``disconnect``.

    Read methods never raise. They return a pre-formatted string, or an
    ``[Internal Message: ...]`` string when nothing matches or an error
    occurs. Write methods return ``True`` on success and ``False`` otherwise.
    """

    def __init__(self, database_url):
        """Store connection settings; no connection is opened here.

        Args:
            database_url: DSN handed to ``asyncpg.create_pool``.
        """
        self.pool = None
        self.database_url = database_url
        # Created lazily inside get_pool so the lock belongs to the running
        # event loop rather than whatever loop exists at construction time.
        self._pool_lock = None

    async def get_pool(self):
        """Return the shared connection pool, creating it on first use.

        Concurrent first callers are serialized by a lock. Otherwise each
        would ``await create_pool`` and all but one pool would leak.
        """
        if self.pool is not None:
            return self.pool
        if self._pool_lock is None:
            # No await between the check and the assignment, so this is
            # race-free within a single event loop.
            self._pool_lock = asyncio.Lock()
        async with self._pool_lock:
            # Re-check: another task may have built the pool while we waited.
            if self.pool is None:
                self.pool = await asyncpg.create_pool(
                    self.database_url, min_size=1, max_size=10
                )
        return self.pool

    @staticmethod
    def _format_post(post_id, content) -> str:
        """Render one post as an ``<|PostId_<id>|>`` tag line followed by its content."""
        return f"<|PostId_{post_id}|>\n{content}"

    async def post_text(self, content: str, embeddings: list[float]) -> bool:
        """Insert a text post together with its embedding.

        The embedding is sent as its ``str()`` form (e.g. ``"[0.1, 0.2]"``).
        NOTE(review): presumably the ``embedding`` column parses that text
        literal (pgvector-style); confirm against the schema.

        Returns:
            True if the INSERT returned an id; False if it did not or an
            error occurred.
        """
        try:
            pool = await self.get_pool()
            # The context manager returns the connection to the pool even
            # when the query raises.
            async with pool.acquire() as conn:
                post_id = await conn.fetchval(
                    "INSERT INTO text_posts (content, embedding) VALUES ($1, $2) RETURNING id",
                    content,
                    str(embeddings),
                )
            return post_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def get_text_post_random(self) -> str:
        """Return one randomly chosen post, formatted with its id tag."""
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY random() LIMIT 1"
                )
            # fetchrow yields None for an empty table. That is a "no post"
            # outcome, not a server error.
            if row is None or row["content"] is None:
                return "[Internal Message: No post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_posts_latest(self) -> str:
        """Return up to five most recently uploaded posts, newest first.

        Each post is formatted with ``_format_post``; posts are separated by
        a blank line.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT id, content FROM text_posts ORDER BY uploaded_at DESC LIMIT 5"
                )
            if not rows:
                return "[Internal Message: No posts in the database]"
            return "\n\n".join(
                self._format_post(row["id"], row["content"]) for row in rows
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_post_similar(self, query_embedding: list[float]) -> str:
        """Return the post whose embedding is nearest to ``query_embedding``.

        Rows are ordered by the ``<->`` operator against the ``embedding``
        column. The query vector is sent in its ``str()`` form, the same
        encoding ``post_text`` uses when storing embeddings.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY embedding <-> $1 LIMIT 1",
                    str(query_embedding),
                )
            # No rows means nothing to compare against, not a failure.
            if row is None or row["content"] is None:
                return "[Internal Message: No similar post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_post_comments(self, post_id: int) -> str:
        """Return up to five newest comments on ``post_id``.

        Each comment is tagged ``<|Comment_<n>|>``, where ``n`` is its
        0-based position in newest-first order. Comments are separated by
        a blank line.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT content FROM text_posts_comments WHERE post_id = $1 ORDER BY uploaded_at DESC LIMIT 5",
                    post_id,
                )
            if not rows:
                return "[Internal Message: No Comments on this post]"
            return "\n\n".join(
                f"<|Comment_{i}|>\n{row['content']}" for i, row in enumerate(rows)
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            # A str, matching the annotation and the sibling read methods.
            return "[Internal Message: Server Error]"

    async def comment_on_text_post(self, post_id: int, content: str) -> bool:
        """Insert a comment on ``post_id``.

        Returns:
            True if the INSERT returned an id; False if it did not or an
            error occurred.
        """
        try:
            pool = await self.get_pool()
            async with pool.acquire() as conn:
                comment_id = await conn.fetchval(
                    "INSERT INTO text_posts_comments (post_id, content) VALUES ($1, $2) RETURNING id",
                    post_id,
                    content,
                )
            return comment_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def disconnect(self) -> None:
        """Gracefully close the shared pool, if one was created.

        ``Pool.close()`` is a coroutine and must be awaited. The reference is
        cleared first, so new callers of ``get_pool`` build a fresh pool
        instead of receiving one that is shutting down.
        """
        if self.pool is not None:
            pool, self.pool = self.pool, None
            await pool.close()
|
|