# app/main.py
"""FastAPI entry point: static pages, image captioning, CBOW and summary demos."""

import json
import os
import random
import shutil
from pathlib import Path

import numpy as np
import praw
import uvicorn
from fastapi import FastAPI, UploadFile, File, Request, Form, Query
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from cbow_logic import MeaningCalculator
from ppo_logic import generate_summary
from vit_captioning.generate import CaptionGenerator

# Directory containing this file. Page lookups are anchored here (like the
# static mount below) so they do not depend on the process working directory.
BASE_DIR = Path(__file__).parent

# Reddit client; credentials are read from the environment at import time.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
    user_agent="script:ContentDistilleryBot:v0.1 (by u/ClementHa)",
)

app = FastAPI()
templates = Jinja2Templates(directory="templates")
calculator = MeaningCalculator()

# Serve static files
static_dir = BASE_DIR / "vit_captioning" / "static"
app.mount("/static", StaticFiles(directory=static_dir), name="static")


# Landing page at `/`
@app.get("/", response_class=HTMLResponse)
async def landing():
    """Return the landing page HTML."""
    return (static_dir / "landing.html").read_text(encoding="utf-8")


@app.get("/health")
def health_check():
    """Return a fixed payload so callers can verify the app is responding."""
    return {"status": "ok"}


# ✅ Captioning page at `/captioning`
@app.get("/captioning", response_class=HTMLResponse)
async def captioning():
    """Return the captioning page HTML."""
    return (static_dir / "captioning" / "index.html").read_text(encoding="utf-8")


# NOTE(review): a second GET handler for "/contentdistillery" is registered
# further down in this file. FastAPI evaluates path operations in registration
# order, so this one is the handler that actually serves the page.
@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery():
    """Return the Content Distillery page HTML."""
    page = BASE_DIR / "content_distillery" / "static" / "content_distillery.html"
    return page.read_text(encoding="utf-8")


# ✅ Caption generation endpoint for captioning app
# Keep the path consistent with your JS fetch()!
# Captioning model, loaded once at import time and shared by all requests.
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False,
)


@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Caption an uploaded image.

    The upload is spooled to a uniquely named temporary file, handed to the
    caption generator by path, and removed afterwards.

    Args:
        file: The uploaded image.

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the file.
    """
    import tempfile  # function-scope import: not part of the module's import block

    # Only the extension of the client-supplied name is reused. Joining the raw
    # filename onto a directory would let names such as "../../x" or "/etc/x"
    # escape the temp directory, and identical names would overwrite each other.
    suffix = Path(file.filename or "").suffix
    spooled = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with spooled:
            shutil.copyfileobj(file.file, spooled)
        # The handle is closed here, so the generator can reopen the path.
        return caption_generator.generate_caption(spooled.name)
    finally:
        Path(spooled.name).unlink(missing_ok=True)


@app.get("/cbow", response_class=HTMLResponse)
async def cbow_form(request: Request):
    """Render the empty CBOW expression form."""
    return templates.TemplateResponse("cbow.html", {"request": request})


@app.post("/cbow")
async def cbow(request: Request, expression: str = Form(...)):
    """Evaluate a submitted expression and render the results.

    Args:
        request: Incoming request, required by the template response.
        expression: Expression from the form; lower-cased before evaluation.

    Returns:
        The ``cbow.html`` template rendered with the expression and results.
    """
    expression = expression.lower()
    # Reuse the module-level calculator instead of constructing a new
    # MeaningCalculator on every request.
    results = calculator.evaluate_expression(expression=expression)
    # NOTE: results are passed to the template as-is. An earlier draft kept the
    # top five and showed "Irrelevant result" for scores below 0.4.
    context = {
        "request": request,
        "expression": expression,
        "results": results,
    }
    return templates.TemplateResponse("cbow.html", context)


# NOTE(review): GET "/contentdistillery" is already registered earlier in this
# file. FastAPI evaluates path operations in registration order, so this
# handler is not reached while the earlier one exists. Confirm which of the two
# pages is intended and remove the other.
@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery_page():
    """Return ``contentdistillery.html`` read from the working directory."""
    return Path("contentdistillery.html").read_text(encoding="utf-8")


@app.post("/contentdistillery", response_class=PlainTextResponse)
async def generate_summary_from_post(post: str = Form(...)):
    """Return ``generate_summary`` applied to the submitted post text."""
    return generate_summary(post)


# Maps the `source` query value to (subreddit name, listing method name).
_SAMPLE_SOURCES = {
    "reddit_romance": ("relationships", "top"),
    "reddit_aita": ("AmItheAsshole", "hot"),
    "reddit_careers": ("careerguidance", "hot"),
    "reddit_cars": ("cars", "hot"),
    "reddit_whatcarshouldibuy": ("whatcarshouldibuy", "top"),
    "reddit_nosleep": ("nosleep", "top"),
    "reddit_maliciouscompliance": ("maliciouscompliance", "hot"),
    "reddit_talesfromtechsupport": ("talesfromtechsupport", "top"),
    "reddit_decidingtobebetter": ("decidingtobebetter", "hot"),
    "reddit_askphilosophy": ("askphilosophy", "top"),
}


@app.get("/get_sample", response_class=PlainTextResponse)
def get_sample(source: str = Query(...)):
    """Return the text of a random post from the subreddit behind ``source``.

    Args:
        source: One of the keys of ``_SAMPLE_SOURCES``.

    Returns:
        The stripped self-text of one randomly chosen post out of up to ten
        fetched, or a plain-text message when the source is unknown, no post
        has text, or the Reddit request fails.
    """
    target = _SAMPLE_SOURCES.get(source)
    if target is None:
        return "Unsupported source."

    subreddit_name, listing = target
    try:
        submissions = getattr(reddit.subreddit(subreddit_name), listing)(limit=10)
        # The fetch happens while iterating, so the iteration stays inside the try.
        texts = (submission.selftext.strip() for submission in submissions)
        posts = [text for text in texts if text]
    except Exception as e:
        # Deliberate best-effort: the endpoint reports failures as plain text.
        return f"Error fetching Reddit post: {str(e)}"

    if posts:
        return random.choice(posts)
    return "No suitable post found."