Spaces:
Sleeping
Sleeping
import logging
from datetime import datetime, timezone
from os import getenv, path

import firebase_admin
import uvicorn
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from firebase_admin import credentials, firestore
from pydantic import BaseModel, Field

from model.argq import ArgqClassifier
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = FastAPI(title="ArgQ Backend", version="0.0.1")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette echo the caller's Origin for credentialed requests, i.e. any website
# may make credentialed cross-origin calls. Nothing visible here relies on
# cookies or auth headers; consider restricting origins or dropping credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logging.info("Starting application")

# Firebase service-account key, resolved relative to this file:
# <this file's directory>/../credentials/firebase-adminsdk.json
cred_file_path = path.join(path.dirname(__file__), "../credentials/firebase-adminsdk.json")
cred = credentials.Certificate(cred_file_path)
# initialize_app() raises ValueError when the default app already exists (e.g.
# if this module is imported a second time in the same process), so reuse an
# existing default app instead of crashing at import time.
try:
    firebase_admin.get_app()
except ValueError:
    firebase_admin.initialize_app(cred)
# Firestore client shared by the request handlers.
db = firestore.client()

logging.info("Loading model..")
# Classifier instance shared by all request handlers; loaded once at import.
model = ArgqClassifier()
logging.info("Model loaded")
class Tweet(BaseModel):
    """Input payload holding the text to be classified."""

    text: str
class TextWithAspects(BaseModel):
    """Text to classify together with the aspect names to classify it by.

    When ``aspects`` is not supplied, the six aspects listed below are used.
    """

    tweet: Tweet
    # NOTE(review): mutable list default; this relies on pydantic copying
    # field defaults per instance. Confirm for the pydantic version in use.
    aspects: list = [
        "quality",
        "clarity",
        "organization",
        "credibility",
        "emotional_polarity",
        "emotional_intensity",
    ]
class FeedbackItem(BaseModel):
    """A piece of user feedback and the time it was submitted.

    If ``timestamp`` is omitted, it defaults to the current time as a
    timezone-aware UTC datetime, so ``timestamp.isoformat()`` carries an
    explicit ``+00:00`` offset.
    """

    text: str
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # datetime whose isoformat() has no offset, making stored times ambiguous.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
async def get_text_classification(tweet: Tweet):
    """Classify ``tweet.text`` with the shared model.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered with ``app``.

    Returns:
        ``{"classification": <result of model.classify_text>}``.
    """
    text = tweet.text
    result = await model.classify_text(text)
    return {"classification": result}
async def get_text_classification_by_aspects(request: TextWithAspects):
    """Classify the request's text once for each requested aspect.

    Aspects are awaited one after another in the order given. If an aspect
    name repeats, it is classified again and the later result is kept.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered with ``app``.

    Returns:
        ``{"classification": {aspect: result, ...}}``.
    """
    text = request.tweet.text
    per_aspect = {}
    for aspect in request.aspects:
        per_aspect[aspect] = await model.classify_text_by_aspect(text, aspect)
    return {"classification": per_aspect}
async def post_feedback(item: FeedbackItem):
    """Store a feedback entry as a new document in the ``feedback`` collection.

    NOTE(review): no route decorator is visible for this handler; confirm it
    is registered with ``app``.

    Args:
        item: Feedback text and its timestamp.

    Returns:
        ``{"status": "success", "feedback_received": <stored data>}``, where the
        stored data has its timestamp serialized with ``isoformat()``.
    """
    feedback_data = item.dict()
    # Persist the timestamp as an ISO-8601 string rather than a datetime.
    feedback_data['timestamp'] = feedback_data['timestamp'].isoformat()
    # document() without an ID creates a reference to a new, auto-named document.
    doc_ref = db.collection('feedback').document()
    # DocumentReference.set() is a blocking network call; run it in a worker
    # thread so it does not stall the event loop serving other requests.
    await run_in_threadpool(doc_ref.set, feedback_data)
    return {"status": "success", "feedback_received": feedback_data}
if __name__ == "__main__":
    # Serve the already-constructed app object. The original passed
    # reload=True, which uvicorn.run() rejects for an app object (it logs a
    # warning and exits with status 1). Making reload work would require an
    # import string, and uvicorn's reloader would then re-execute this module
    # (model load, Firebase init) in a child process. For auto-reload during
    # development, use the CLI instead:
    # `uvicorn <module>:app --reload`.
    uvicorn.run(app, host="0.0.0.0", port=int(getenv("PORT", 8000)))