ai-lab / main.py
ClemSummer's picture
Handled reddit env info. Added back buttons.
d32cb1d
# app/main.py
from fastapi import FastAPI, UploadFile, File, Request, Form, Query
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from cbow_logic import MeaningCalculator
from ppo_logic import generate_summary
import numpy as np
import json
import shutil
from pathlib import Path
import uvicorn
import os
import praw
import random
from vit_captioning.generate import CaptionGenerator
from cbow_logic import MeaningCalculator
# Reddit API client. The credentials are looked up in the REDDIT_CLIENT_ID and
# REDDIT_CLIENT_SECRET environment variables (each is None when unset).
reddit = praw.Reddit(
    client_id=os.environ.get("REDDIT_CLIENT_ID"),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET"),
    user_agent="script:ContentDistilleryBot:v0.1 (by u/ClementHa)",
)
# Application object, Jinja2 template loader (reads from ./templates) and a
# module-level CBOW calculator instance.
app = FastAPI()
templates = Jinja2Templates(directory="templates")
calculator = MeaningCalculator()

# Serve the assets that live next to this file under the /static URL prefix.
static_dir = Path(__file__).parent.joinpath("vit_captioning", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Landing page at `/`
@app.get("/", response_class=HTMLResponse)
async def landing():
    """Return the landing page HTML.

    The file path is relative to the current working directory. The file is
    decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.

    Raises:
        FileNotFoundError: if the HTML file is missing.
    """
    return Path("vit_captioning/static/landing.html").read_text(encoding="utf-8")
@app.get("/health")
def health_check():
    """Health probe: always answers with a fixed ``ok`` status payload."""
    payload = {"status": "ok"}
    return payload
# Captioning page at `/captioning`
@app.get("/captioning", response_class=HTMLResponse)
async def captioning():
    """Return the captioning app's HTML page.

    The file path is relative to the current working directory. The file is
    decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.

    Raises:
        FileNotFoundError: if the HTML file is missing.
    """
    return Path("vit_captioning/static/captioning/index.html").read_text(
        encoding="utf-8"
    )
@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery():
    """Return the Content Distillery HTML page.

    The file path is relative to the current working directory. The file is
    decoded explicitly as UTF-8 so the result does not depend on the
    platform's default locale encoding.

    Raises:
        FileNotFoundError: if the HTML file is missing.
    """
    return Path("content_distillery/static/content_distillery.html").read_text(
        encoding="utf-8"
    )
# Caption generator used by the caption generation endpoint of the captioning
# app. It is constructed once, when this module is imported.
# Keep the endpoint path consistent with your JS fetch()!
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False,
)
@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Caption an uploaded image.

    The upload is spooled to a uniquely named temporary file, handed to
    ``caption_generator.generate_caption`` as a path, and the temporary file
    is removed afterwards.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the image.
    """
    import tempfile

    # Never build the on-disk path from the client-supplied filename: it may
    # be missing, absolute, or contain "../" segments. Only the extension is
    # kept, in case the captioning code looks at it.
    suffix = Path(file.filename or "").suffix
    fd, temp_file = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        return caption_generator.generate_caption(temp_file)
    finally:
        # Clean up even when captioning fails so uploads do not pile up.
        Path(temp_file).unlink(missing_ok=True)
@app.get("/cbow", response_class=HTMLResponse)
async def cbow_form(request: Request):
    """Render the CBOW page with only the request in its template context."""
    context = {"request": request}
    return templates.TemplateResponse("cbow.html", context)
@app.post("/cbow")
async def cbow(request: Request, expression: str = Form(...)):
    """Evaluate a submitted word expression and render the results.

    Args:
        request: The incoming request, required by the template response.
        expression: The expression from the form; it is lower-cased before
            evaluation.

    Returns:
        The ``cbow.html`` template rendered with the lower-cased expression
        and the evaluation results.
    """
    expression = expression.lower()
    # Reuse the module-level calculator rather than constructing a new
    # MeaningCalculator on every request.
    # NOTE(review): assumes evaluate_expression keeps no per-call state on
    # the instance -- confirm against cbow_logic.
    results = calculator.evaluate_expression(expression=expression)
    return templates.TemplateResponse(
        "cbow.html",
        {
            "request": request,
            "expression": expression,
            "results": results,
        },
    )
# NOTE(review): a GET handler for "/contentdistillery" is already registered
# earlier in this file and serves a different HTML file. Routes are matched in
# registration order, so this handler is presumably never reached -- confirm
# which page is intended and drop the other registration.
@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery_page():
    """Return the UTF-8 text of ``contentdistillery.html`` from the working directory."""
    page = Path("contentdistillery.html")
    return page.read_text(encoding="utf-8")
@app.post("/contentdistillery", response_class=PlainTextResponse)
async def generate_summary_from_post(post: str = Form(...)):
    """Pass the submitted ``post`` form field to ``generate_summary`` and return its result as plain text."""
    summary = generate_summary(post)
    return summary
# Maps each supported ``source`` value to (subreddit name, listing method).
_REDDIT_SOURCES = {
    "reddit_romance": ("relationships", "top"),
    "reddit_aita": ("AmItheAsshole", "hot"),
    "reddit_careers": ("careerguidance", "hot"),
    "reddit_cars": ("cars", "hot"),
    "reddit_whatcarshouldibuy": ("whatcarshouldibuy", "top"),
    "reddit_nosleep": ("nosleep", "top"),
    "reddit_maliciouscompliance": ("maliciouscompliance", "hot"),
    "reddit_talesfromtechsupport": ("talesfromtechsupport", "top"),
    "reddit_decidingtobebetter": ("decidingtobebetter", "hot"),
    "reddit_askphilosophy": ("askphilosophy", "top"),
}


@app.get("/get_sample", response_class=PlainTextResponse)
def get_sample(source: str = Query(...)):
    """Return the text of a random post from the subreddit selected by ``source``.

    Up to ten submissions are requested from the mapped listing; those whose
    stripped self-text is empty are skipped, and one of the rest is chosen at
    random.

    Args:
        source: One of the keys of ``_REDDIT_SOURCES``.

    Returns:
        The chosen post text, or a plain-text message when the source is not
        supported, no post has text, or an exception occurs while fetching.
    """
    try:
        target = _REDDIT_SOURCES.get(source)
        if target is None:
            return "Unsupported source."

        subreddit_name, listing_name = target
        fetch_listing = getattr(reddit.subreddit(subreddit_name), listing_name)
        submissions = fetch_listing(limit=10)

        posts = []
        for submission in submissions:
            text = submission.selftext.strip()
            if text:
                posts.append(text)

        if not posts:
            return "No suitable post found."
        return random.choice(posts)
    except Exception as e:
        # Best effort: the failure is reported as text in the response body.
        return "Error fetching Reddit post: " + str(e)