File size: 4,779 Bytes
2b3d156
 
79a071f
 
2b3d156
79a071f
 
 
 
 
 
2b3d156
 
84f04ee
405db35
79a071f
 
2b3d156
7b2eca8
79a071f
 
 
d32cb1d
 
 
79a071f
7b2eca8
 
79a071f
 
 
2b3d156
 
 
 
 
623450c
5f629ed
31cc6ff
 
 
54c9e50
 
 
 
2b3d156
 
 
 
 
79a071f
 
 
3aa3d4c
2b3d156
 
 
 
 
79a071f
 
2b3d156
 
 
 
405db35
2b3d156
 
 
 
 
5f629ed
79a071f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# app/main.py

from fastapi import FastAPI, UploadFile, File, Request, Form, Query
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from cbow_logic import MeaningCalculator
from ppo_logic import generate_summary

import numpy as np
import json
import shutil
from pathlib import Path
import uvicorn
import os
import praw
import random

from vit_captioning.generate import CaptionGenerator
from cbow_logic import MeaningCalculator

# Reddit API client used by the /get_sample endpoint. The client id and
# secret are read from the environment (None when a variable is unset).
reddit = praw.Reddit(
    user_agent="script:ContentDistilleryBot:v0.1 (by u/ClementHa)",
    client_id=os.environ.get("REDDIT_CLIENT_ID"),
    client_secret=os.environ.get("REDDIT_CLIENT_SECRET"),
)

# ASGI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()
# Jinja2 template loader pointed at the relative "templates" directory.
templates = Jinja2Templates(directory="templates")
# Module-level MeaningCalculator instance.
# NOTE(review): the POST /cbow handler in this file constructs its own
# MeaningCalculator instead of using this one - confirm which is intended.
calculator = MeaningCalculator()


# Static assets live under <this module's directory>/vit_captioning/static
# and are exposed under the /static URL prefix.
static_dir = Path(__file__).parent.joinpath("vit_captioning", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Landing page at `/`
@app.get("/", response_class=HTMLResponse)
async def landing():
    """Return the landing page HTML.

    The file is located through ``static_dir`` (which is anchored to this
    module's directory) instead of the process working directory, so the
    page is found no matter where the server is started from. The text is
    decoded explicitly as UTF-8 rather than with the platform's locale
    encoding.

    Returns:
        str: contents of ``landing.html``.

    Raises:
        FileNotFoundError: if ``landing.html`` is missing from ``static_dir``.
    """
    return (static_dir / "landing.html").read_text(encoding="utf-8")

@app.get("/health")
def health_check():
    """Health probe: always answers with a fixed ``{"status": "ok"}`` body."""
    payload = {"status": "ok"}
    return payload

# Captioning page at `/captioning`
@app.get("/captioning", response_class=HTMLResponse)
async def captioning():
    """Return the captioning app's ``index.html``.

    The file is located through ``static_dir`` (anchored to this module's
    directory) instead of the process working directory, and is decoded
    explicitly as UTF-8 rather than with the platform's locale encoding.

    Returns:
        str: contents of ``captioning/index.html`` under ``static_dir``.

    Raises:
        FileNotFoundError: if the page is missing from ``static_dir``.
    """
    return (static_dir / "captioning" / "index.html").read_text(encoding="utf-8")

@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery():
    """Return the Content Distillery page HTML.

    The file is decoded explicitly as UTF-8 so the result does not depend
    on the platform's locale encoding.

    NOTE(review): a second GET handler for this same path is declared
    further down in this file; confirm which of the two is meant to serve
    the page.

    Returns:
        str: contents of ``content_distillery/static/content_distillery.html``
        (path is relative to the process working directory).
    """
    page = Path("content_distillery/static/content_distillery.html")
    return page.read_text(encoding="utf-8")

# Caption generation for the captioning app.
# A single CaptionGenerator is constructed here at import time and is used
# by the POST /generate handler in this file. Keep that route's path
# consistent with the fetch() call in the front-end JS.
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False
)

@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Generate captions for an uploaded image.

    The upload is spooled to a uniquely named temporary file, handed to
    ``caption_generator.generate_caption`` by path, and removed afterwards.

    The client-supplied filename is never used to build the path: it is
    attacker-controlled and may contain ``..`` segments or be absolute
    (``os.path.join("/tmp", "/etc/x")`` yields ``/etc/x``). Only a purely
    alphanumeric extension is carried over as the temp file's suffix.

    Args:
        file: the uploaded file from the multipart form.

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the file.
    """
    import tempfile  # local import keeps this handler self-contained

    # `filename` may be None; keep the extension only when it is harmless.
    suffix = Path(file.filename or "").suffix
    if not suffix[1:].isalnum():
        suffix = ""

    # delete=False: the file must outlive the `with` block so it can be
    # reopened by path; it is removed explicitly in `finally`.
    spooled = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp_file = spooled.name
    try:
        with spooled as buffer:
            shutil.copyfileobj(file.file, buffer)
        return caption_generator.generate_caption(temp_file)
    finally:
        os.unlink(temp_file)

@app.get("/cbow", response_class=HTMLResponse)
async def cbow_form(request: Request):
    """Render the ``cbow.html`` template with only the request in its context."""
    context = {"request": request}
    return templates.TemplateResponse("cbow.html", context)

@app.post("/cbow")
async def cbow(request: Request, expression: str = Form(...)):
    """Evaluate a submitted expression and re-render ``cbow.html``.

    The form value is lower-cased, passed to
    ``MeaningCalculator.evaluate_expression``, and the lower-cased
    expression together with the results is handed to the template.
    """
    lowered = expression.lower()
    # NOTE(review): a new MeaningCalculator is built on every request even
    # though a module-level `calculator` instance exists in this file -
    # confirm whether that instance can be reused here.
    engine = MeaningCalculator()
    results = engine.evaluate_expression(expression=lowered)
    context = {
        "request": request,
        "expression": lowered,
        "results": results,
    }
    return templates.TemplateResponse("cbow.html", context)

@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery_page():
    """Return the text of ``contentdistillery.html`` decoded as UTF-8.

    NOTE(review): GET /contentdistillery is already registered earlier in
    this file, and routes are matched in registration order, so this
    handler is probably never reached - confirm and remove one of the two.
    """
    page = Path("contentdistillery.html")
    return page.read_text(encoding="utf-8")

@app.post("/contentdistillery", response_class=PlainTextResponse)
async def generate_summary_from_post(post: str = Form(...)):
    """Pass the ``post`` form field to ``generate_summary`` and return its result."""
    summary = generate_summary(post)
    return summary

# Maps each accepted `source` value to (subreddit name, listing method name).
_SAMPLE_SOURCES = {
    "reddit_romance": ("relationships", "top"),
    "reddit_aita": ("AmItheAsshole", "hot"),
    "reddit_careers": ("careerguidance", "hot"),
    "reddit_cars": ("cars", "hot"),
    "reddit_whatcarshouldibuy": ("whatcarshouldibuy", "top"),
    "reddit_nosleep": ("nosleep", "top"),
    "reddit_maliciouscompliance": ("maliciouscompliance", "hot"),
    "reddit_talesfromtechsupport": ("talesfromtechsupport", "top"),
    "reddit_decidingtobebetter": ("decidingtobebetter", "hot"),
    "reddit_askphilosophy": ("askphilosophy", "top"),
}


@app.get("/get_sample", response_class=PlainTextResponse)
def get_sample(source: str = Query(...)):
    """Return the text of a randomly chosen post for the requested source.

    Up to 10 submissions are requested from the subreddit listing that
    ``source`` maps to; submissions whose ``selftext`` is empty after
    stripping whitespace are skipped and one of the remaining texts is
    picked at random.

    Args:
        source: one of the keys of ``_SAMPLE_SOURCES``.

    Returns:
        str: the stripped post text, or a plain message when the source is
        not recognised, no post has text, or any exception is raised while
        fetching (the exception text is included in the message).
    """
    try:
        target = _SAMPLE_SOURCES.get(source)
        if target is None:
            return "Unsupported source."

        subreddit_name, listing_name = target
        fetch_listing = getattr(reddit.subreddit(subreddit_name), listing_name)
        submissions = fetch_listing(limit=10)

        posts = []
        for submission in submissions:
            body = submission.selftext.strip()
            if body:
                posts.append(body)

        if not posts:
            return "No suitable post found."
        return random.choice(posts)

    except Exception as e:
        return f"Error fetching Reddit post: {e!s}"