Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,779 Bytes
2b3d156 79a071f 2b3d156 79a071f 2b3d156 84f04ee 405db35 79a071f 2b3d156 7b2eca8 79a071f d32cb1d 79a071f 7b2eca8 79a071f 2b3d156 623450c 5f629ed 31cc6ff 54c9e50 2b3d156 79a071f 3aa3d4c 2b3d156 79a071f 2b3d156 405db35 2b3d156 5f629ed 79a071f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# app/main.py
from fastapi import FastAPI, UploadFile, File, Request, Form, Query
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from cbow_logic import MeaningCalculator
from ppo_logic import generate_summary
import numpy as np
import json
import shutil
from pathlib import Path
import uvicorn
import os
import praw
import random
from vit_captioning.generate import CaptionGenerator
from cbow_logic import MeaningCalculator
# Reddit API client used by /get_sample; credentials are read from the
# REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET environment variables.
reddit = praw.Reddit(
    user_agent="script:ContentDistilleryBot:v0.1 (by u/ClementHa)",
    client_id=os.getenv("REDDIT_CLIENT_ID"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET"),
)
app = FastAPI()
templates = Jinja2Templates(directory="templates")
# Module-level MeaningCalculator instance.
calculator = MeaningCalculator()

# Serve the captioning app's static assets under /static. The directory is
# resolved relative to this module rather than the working directory.
static_dir = Path(__file__).parent.joinpath("vit_captioning", "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Landing page at `/`.
@app.get("/", response_class=HTMLResponse)
async def landing():
    """Serve the static landing page HTML.

    The file is located relative to this module's directory (not the process
    working directory) and decoded explicitly as UTF-8, so the response does
    not depend on where the server was launched from or on the host locale.
    """
    page = Path(__file__).parent / "vit_captioning" / "static" / "landing.html"
    return page.read_text(encoding="utf-8")
@app.get("/health")
def health_check():
    """Liveness probe: unconditionally report the service as up."""
    payload = dict(status="ok")
    return payload
# Captioning page at `/captioning`.
@app.get("/captioning", response_class=HTMLResponse)
async def captioning():
    """Serve the captioning app's static index page.

    The file is located relative to this module's directory (not the process
    working directory) and decoded explicitly as UTF-8.
    """
    page = (
        Path(__file__).parent
        / "vit_captioning"
        / "static"
        / "captioning"
        / "index.html"
    )
    return page.read_text(encoding="utf-8")
@app.get("/contentdistillery", response_class=HTMLResponse)
async def contentdistillery():
    """Serve the Content Distillery static HTML page.

    The file is located relative to this module's directory (not the process
    working directory) and decoded explicitly as UTF-8.
    """
    page = (
        Path(__file__).parent
        / "content_distillery"
        / "static"
        / "content_distillery.html"
    )
    return page.read_text(encoding="utf-8")
# Caption generation endpoint for the captioning app.
# Keep the path consistent with your JS fetch()!
caption_generator = CaptionGenerator(
    model_type="CLIPEncoder",
    checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
    quantized=False,
    runAsContainer=False,
)


@app.post("/generate")
async def generate(file: UploadFile = File(...)):
    """Caption an uploaded image file.

    The upload is copied to a private temporary file, because
    ``caption_generator.generate_caption`` is called with a path. That file is
    removed once captioning finishes, whether or not it succeeds.

    Returns:
        Whatever ``caption_generator.generate_caption`` returns for the file.
    """
    import tempfile

    # Keep only the extension of the client-supplied name. Joining the raw
    # filename onto /tmp allowed path traversal ("../x" or an absolute path),
    # and two uploads with the same name could overwrite each other.
    suffix = Path(file.filename or "").suffix
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as buffer:
        shutil.copyfileobj(file.file, buffer)
        temp_path = buffer.name
    try:
        return caption_generator.generate_caption(temp_path)
    finally:
        # Previously every upload was left behind in /tmp.
        os.remove(temp_path)
@app.get("/cbow", response_class=HTMLResponse)
async def cbow_form(request: Request):
    """Render cbow.html with only the request in the template context."""
    context = {"request": request}
    return templates.TemplateResponse("cbow.html", context)
@app.post("/cbow")
async def cbow(request: Request, expression: str = Form(...)):
    """Evaluate a submitted CBOW expression and re-render cbow.html.

    The expression is lower-cased before evaluation. The lower-cased text and
    the calculator's results are passed to the template alongside the request.
    """
    expression = expression.lower()
    # Reuse the module-level ``calculator`` instead of constructing a new
    # MeaningCalculator per request; the shared instance was otherwise never
    # used. NOTE(review): assumes evaluate_expression keeps no per-call state
    # on the instance — confirm in cbow_logic.
    results = calculator.evaluate_expression(expression=expression)
    return templates.TemplateResponse(
        "cbow.html",
        {
            "request": request,
            "expression": expression,
            "results": results,
        },
    )
# NOTE: this function used to be registered as a second
# ``@app.get("/contentdistillery")`` handler. Starlette matches routes in
# registration order, and ``contentdistillery`` already serves that exact
# path and method, so this registration could never be reached. It only
# added a duplicate GET operation for the path to the OpenAPI schema. The
# route decorator was removed; the function is kept for any direct callers.
async def contentdistillery_page():
    """Return the text of ``contentdistillery.html`` (working-directory relative), decoded as UTF-8."""
    return Path("contentdistillery.html").read_text(encoding="utf-8")
@app.post("/contentdistillery", response_class=PlainTextResponse)
async def generate_summary_from_post(post: str = Form(...)):
    """Summarize the submitted post text and return the summary as plain text."""
    summary = generate_summary(post)
    return summary
# Supported ``source`` values for /get_sample, each mapped to the subreddit
# to read and the listing ("hot" or "top") to take candidate posts from.
_SAMPLE_SOURCES = {
    "reddit_romance": ("relationships", "top"),
    "reddit_aita": ("AmItheAsshole", "hot"),
    "reddit_careers": ("careerguidance", "hot"),
    "reddit_cars": ("cars", "hot"),
    "reddit_whatcarshouldibuy": ("whatcarshouldibuy", "top"),
    "reddit_nosleep": ("nosleep", "top"),
    "reddit_maliciouscompliance": ("maliciouscompliance", "hot"),
    "reddit_talesfromtechsupport": ("talesfromtechsupport", "top"),
    "reddit_decidingtobebetter": ("decidingtobebetter", "hot"),
    "reddit_askphilosophy": ("askphilosophy", "top"),
}
# Number of submissions fetched per request before one is picked at random.
_SAMPLE_LIMIT = 10


@app.get("/get_sample", response_class=PlainTextResponse)
def get_sample(source: str = Query(...)):
    """Return the self-text of a random Reddit post for ``source``.

    Fetches ``_SAMPLE_LIMIT`` submissions from the subreddit/listing configured
    in ``_SAMPLE_SOURCES`` and picks one at random among those whose stripped
    self-text is non-empty.

    Every outcome is returned as plain text rather than raised:
    "Unsupported source." for unknown sources, "No suitable post found." when
    no candidate has text, and "Error fetching Reddit post: ..." on failures.
    """
    target = _SAMPLE_SOURCES.get(source)
    if target is None:
        return "Unsupported source."
    subreddit_name, listing = target
    try:
        subreddit = reddit.subreddit(subreddit_name)
        # ``listing`` is "hot" or "top", so this calls subreddit.hot(...) or
        # subreddit.top(...) exactly as the former if/elif chain did.
        submissions = getattr(subreddit, listing)(limit=_SAMPLE_LIMIT)
        # Strip each body once and drop the blank ones.
        bodies = (s.selftext.strip() for s in submissions)
        posts = [body for body in bodies if body]
    except Exception as e:
        # Best-effort by design: report the failure as text instead of raising.
        return f"Error fetching Reddit post: {str(e)}"
    if posts:
        return random.choice(posts)
    return "No suitable post found."
|