SocialNetwork / database.py
inventwithdean
fix post id formatting
abae45a
import contextlib

import asyncpg
class NetworkDB:
    """Async data-access layer for text posts and their comments.

    Every query method opens a dedicated asyncpg connection for the duration of
    the call and always closes it again (see ``_connect``). ``get_pool`` can
    create a shared pool, but the query methods below do not use it.

    Read methods return human/LLM-readable strings. On failure they return an
    ``[Internal Message: ...]`` string rather than raising. Write methods
    return ``True``/``False``.
    """

    def __init__(self, database_url: str):
        """Store the DSN; no connection is opened here.

        Args:
            database_url: Connection string passed verbatim to
                ``asyncpg.connect`` / ``asyncpg.create_pool``.
        """
        self.pool = None
        self.database_url = database_url

    async def get_pool(self):
        """Return the shared connection pool, creating it (1-10 conns) on first use."""
        if self.pool is not None:
            return self.pool
        self.pool = await asyncpg.create_pool(
            self.database_url, min_size=1, max_size=10
        )
        return self.pool

    @contextlib.asynccontextmanager
    async def _connect(self):
        """Yield a fresh connection and close it even if the query raises."""
        conn = await asyncpg.connect(self.database_url)
        try:
            yield conn
        finally:
            await conn.close()

    @staticmethod
    def _format_post(post_id, content) -> str:
        """Render one post in the ``<|PostId_N|>`` tagged text format."""
        return f"<|PostId_{post_id}|>\n{content}"

    async def post_text(self, content: str, embeddings: list[float]) -> bool:
        """Insert a text post together with its embedding vector.

        Args:
            content: Post body.
            embeddings: Embedding vector. It is sent as its Python ``str`` form
                (e.g. ``"[0.1, 0.2]"``). NOTE(review): this assumes the
                ``embedding`` column parses that text form (presumably a
                pgvector column) — confirm against the schema.

        Returns:
            True if the INSERT returned an id, False otherwise or on any error.
        """
        try:
            async with self._connect() as conn:
                post_id = await conn.fetchval(
                    "INSERT INTO text_posts (content, embedding) VALUES ($1, $2) RETURNING id",
                    content,
                    f"{embeddings}",
                )
            return post_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def get_text_post_random(self) -> str:
        """Return one randomly chosen post, formatted with its id tag.

        Returns:
            The formatted post, a "No post found" message when the table is
            empty (or the content is NULL), or a "Server Error" message.
        """
        try:
            async with self._connect() as conn:
                # fetchrow returns None on an empty table instead of raising.
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY random() LIMIT 1"
                )
            if row is None or row["content"] is None:
                return "[Internal Message: No post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_posts_latest(self) -> str:
        """Return the 5 most recently uploaded posts, newest first.

        Posts are formatted with their id tags and separated by a blank line.
        """
        try:
            async with self._connect() as conn:
                rows = await conn.fetch(
                    "SELECT id, content FROM text_posts ORDER BY uploaded_at DESC LIMIT 5"
                )
            if not rows:
                return "[Internal Message: No posts in the database]"
            return "\n\n".join(
                self._format_post(row["id"], row["content"]) for row in rows
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_post_similar(self, query_embedding: list[float]) -> str:
        """Return the single post whose embedding is closest to ``query_embedding``.

        Ordering uses the SQL ``<->`` operator. NOTE(review): this is presumably
        pgvector's distance operator — confirm. The embedding is sent in its
        Python ``str`` form, as in ``post_text``.
        """
        try:
            async with self._connect() as conn:
                row = await conn.fetchrow(
                    "SELECT id, content FROM text_posts ORDER BY embedding <-> $1 LIMIT 1",
                    f"{query_embedding}",
                )
            if row is None or row["content"] is None:
                return "[Internal Message: No similar post found!]"
            return self._format_post(row["id"], row["content"])
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return "[Internal Message: Server Error]"

    async def get_text_post_comments(self, post_id: int) -> str:
        """Return up to 5 comments on ``post_id``, newest first.

        Each comment is tagged ``<|Comment_i|>``, with i counting from 0 for
        the newest. Comments are separated by a blank line, so there is no
        leading or trailing separator.

        Returns:
            The formatted comments, a "No Comments" message, or a
            "Server Error" message. The result is always a ``str``.
        """
        try:
            async with self._connect() as conn:
                rows = await conn.fetch(
                    "SELECT content FROM text_posts_comments "
                    "WHERE post_id = $1 ORDER BY uploaded_at DESC LIMIT 5",
                    post_id,
                )
            if not rows:
                return "[Internal Message: No Comments on this post]"
            return "\n\n".join(
                f"<|Comment_{i}|>\n{row['content']}" for i, row in enumerate(rows)
            )
        except Exception as e:
            print(f"Unexpected Error: {e}")
            # Previously returned a list here, contradicting the ``-> str`` contract.
            return "[Internal Message: Server Error]"

    async def comment_on_text_post(self, post_id: int, content: str) -> bool:
        """Insert a comment on ``post_id``.

        Returns:
            True if the INSERT returned an id, False otherwise or on any error.
        """
        try:
            async with self._connect() as conn:
                comment_id = await conn.fetchval(
                    "INSERT INTO text_posts_comments (post_id, content) VALUES ($1, $2) RETURNING id",
                    post_id,
                    content,
                )
            return comment_id is not None
        except Exception as e:
            print(f"Unexpected Error: {e}")
            return False

    async def disconnect(self) -> None:
        """Close the shared pool, if one was created, and forget it."""
        if self.pool is not None:
            # Pool.close() is a coroutine; without await the pool is never closed.
            await self.pool.close()
            self.pool = None