# Viral-808 / utils/database.py
# Last change: commit e53edf3 "Update database.py" by Sam Fred (3.54 kB)
import json
import os
from typing import List, Dict

from sqlalchemy import create_engine, Column, Integer, String, Float, Boolean, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
# SQLAlchemy connection URL. It is read from the environment so that database
# credentials are never committed to source control (the previously hardcoded
# credential should be treated as leaked and rotated). Example, placeholders only:
#   export DATABASE_URL="postgresql://user:password@host:5432/dbname"
DATABASE_URL = os.environ.get("DATABASE_URL", "")
if not DATABASE_URL:
    # Fail at import time with an actionable message instead of letting
    # create_engine() fail later with an obscure URL-parsing error.
    raise RuntimeError(
        "DATABASE_URL environment variable is not set; expected a SQLAlchemy "
        "URL such as postgresql://user:password@host:5432/dbname"
    )

# Declarative base class that every ORM model in this module subclasses.
Base = declarative_base()
# Define the posts table using SQLAlchemy ORM.
# NOTE: column declaration order below is also the column order used when
# Base.metadata.create_all() creates the table, so keep it stable.
class Post(Base):
    """ORM model for one row of the ``posts`` table."""

    __tablename__ = "posts"

    # Integer primary key, also explicitly indexed.
    id = Column(Integer, primary_key=True, index=True)
    # Required; fetch_posts_from_db() filters rows on this column.
    username = Column(String, nullable=False)
    caption = Column(Text, nullable=True)
    # JSON-encoded text: save_to_db() writes json.dumps(<hashtag list>) and
    # fetch_posts_from_db() decodes it with json.loads.
    hashtags = Column(Text, nullable=True)  # Store as JSON string
    likes = Column(Integer, default=0)
    comments = Column(Integer, default=0)
    # Stored verbatim as a string; no date format is enforced here.
    # TODO(review): confirm the expected format with the data producer.
    date = Column(String, nullable=True)
    # Unique and required: save_to_db()/post_exists() use it as the
    # de-duplication key for posts.
    image_url = Column(String, unique=True, nullable=False)
    # Metric values are taken as-is from the dicts passed to save_to_db().
    engagement_rate = Column(Float, default=0.0)
    viral_score = Column(Float, default=0.0)
    promote = Column(Boolean, default=False)
# Shared SQLAlchemy engine plus a session factory bound to it. Sessions created
# by SessionLocal do not autoflush, so pending objects are only sent to the
# database on an explicit flush/commit.
engine = create_engine(
    DATABASE_URL,
)
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
def init_db():
    """Create every table registered on ``Base.metadata`` (e.g. ``posts``).

    Tables that already exist are left as they are; this does not migrate
    schema changes. Prints a confirmation line when done.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    print("Database initialized.")
def post_exists(session: Session, image_url: str) -> bool:
    """Return True when a row in ``posts`` already has this ``image_url``.

    Only the ``id`` column is selected and ``first()`` limits the result to a
    single row, so the lookup stays cheap. Objects added to *session* but not
    yet flushed are not seen by this query when autoflush is disabled (as it
    is for sessions made by ``SessionLocal``).
    """
    first_match = (
        session.query(Post.id)
        .filter(Post.image_url == image_url)
        .first()
    )
    return first_match is not None
def save_to_db(data: List[Dict]):
    """Insert posts into the ``posts`` table, skipping ones already stored.

    ``image_url`` is the de-duplication key (it is UNIQUE and NOT NULL in the
    table). A post is skipped when its ``image_url`` is missing or empty, when
    it already exists in the database, or when it appeared earlier in the same
    *data* batch. All inserts are committed together in one transaction.

    Args:
        data: post dicts. Recognised keys are username, caption, hashtags
            (list, stored as JSON text), likes, comments, date, image_url,
            engagement_rate, viral_score and promote; missing keys fall back
            to the defaults used below.
    """
    # image_urls queued during this call. SessionLocal uses autoflush=False, so
    # post_exists() cannot see these pending rows; without this set, duplicates
    # inside one batch would violate the UNIQUE constraint at commit time and
    # the whole batch would be rolled back.
    queued_urls = set()
    with SessionLocal() as session:
        for post in data:
            image_url = post.get("image_url")
            # Without an image_url the row can neither be stored (NOT NULL)
            # nor de-duplicated, so skip it instead of failing the batch.
            if not image_url:
                continue
            if image_url in queued_urls or post_exists(session, image_url):
                continue
            queued_urls.add(image_url)
            session.add(
                Post(
                    username=post.get("username", ""),
                    caption=post.get("caption", ""),
                    hashtags=json.dumps(post.get("hashtags", [])),  # Convert list to JSON string
                    likes=post.get("likes", 0),
                    comments=post.get("comments", 0),
                    date=post.get("date", ""),
                    image_url=image_url,
                    engagement_rate=post.get("engagement_rate", 0.0),
                    viral_score=post.get("viral_score", 0.0),
                    promote=post.get("promote", False),
                )
            )
        session.commit()
    print("Data saved to database.")
def fetch_posts_from_db(username: str) -> List[Dict]:
    """Return every stored post for *username* as a list of plain dicts.

    The dict keys mirror the ``Post`` columns (``id`` excluded). ``hashtags``
    is decoded from its JSON text back into a Python object; a NULL or empty
    value is returned as ``[]``. Returns an empty list when the user has no
    posts.
    """
    with SessionLocal() as session:
        posts = session.query(Post).filter(Post.username == username).all()
        return [
            {
                "username": post.username,
                "caption": post.caption,
                # hashtags is a nullable column: json.loads(None) raises
                # TypeError and json.loads("") raises JSONDecodeError, so an
                # absent value is reported as "no hashtags".
                "hashtags": json.loads(post.hashtags) if post.hashtags else [],
                "likes": post.likes,
                "comments": post.comments,
                "date": post.date,
                "image_url": post.image_url,
                "engagement_rate": post.engagement_rate,
                "viral_score": post.viral_score,
                "promote": post.promote,
            }
            for post in posts
        ]
def get_db():
    """Yield one session from ``SessionLocal`` and close it afterwards.

    Written as a generator, presumably for use as a framework dependency
    (e.g. FastAPI ``Depends(get_db)``; confirm with callers). The session is
    closed when the generator finishes, including when the consumer raises.
    """
    # Session.__exit__ calls close(), matching an explicit try/finally.
    with SessionLocal() as db:
        yield db