Spaces:
Sleeping
Sleeping
File size: 2,198 Bytes
3bc7eb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import logging
from datetime import datetime, timezone
from os import getenv, path
from typing import List

import firebase_admin
import uvicorn
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from firebase_admin import credentials, firestore
from pydantic import BaseModel, Field

from model.argq import ArgqClassifier
# Root logging: INFO level, one "time - level - message" line per record.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# The FastAPI application; route handlers are attached with @app.post(...).
app = FastAPI(version="0.0.1", title="ArgQ Backend")

# Fully permissive CORS: any origin, method and header, credentials allowed.
# NOTE(review): FastAPI's CORS documentation says allow_credentials=True should
# not be combined with "*" wildcards; confirm whether credentialed
# cross-origin requests are really needed before tightening this.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Import-time initialisation: importing this module reads the Firebase key
# file, connects the Firestore client and loads the classifier.
logging.info("Starting application")
# Service-account key, resolved relative to this file's directory
# ("../credentials/" = a credentials/ folder in the parent directory).
cred_file_path = path.join(path.dirname(__file__), "../credentials/firebase-adminsdk.json")
cred = credentials.Certificate(cred_file_path)
# Initialise the default Firebase app. firestore.client() below is called with
# no app argument, so it presumably relies on this default app - keep the order.
firebase_admin.initialize_app(cred)
# Module-level Firestore client; the feedback endpoint writes through it.
db = firestore.client()
# Build the classifier once so every request handler reuses the same instance.
logging.info("Loading model..")
model = ArgqClassifier()
logging.info("Model loaded")
class Tweet(BaseModel):
    """Request body holding the raw text that the model should classify."""

    text: str  # the tweet content, passed to the classifier as-is
class TextWithAspects(BaseModel):
    """Request body for per-aspect classification of a tweet.

    ``aspects`` names the aspects to score and defaults to the six listed
    below. Items are validated as strings: each aspect is passed to the model
    and used as a key in the response dict. An untyped ``list`` would let
    unhashable items (e.g. nested lists) through validation, and they would
    then fail inside the handler with a server error instead of a 422.
    """

    tweet: Tweet
    aspects: List[str] = [
        "quality",
        "clarity",
        "organization",
        "credibility",
        "emotional_polarity",
        "emotional_intensity",
    ]
class FeedbackItem(BaseModel):
    """Free-text user feedback accepted by the ``/argq/feedback`` endpoint.

    ``timestamp`` defaults to the current time as a timezone-aware UTC
    datetime, so its ISO-8601 form carries an explicit ``+00:00`` offset.
    """

    text: str
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # datetime whose isoformat() has no offset (ambiguous once stored);
    # use an aware UTC timestamp instead.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@app.post("/argq/classify")
async def get_text_classification(tweet: Tweet):
    """Classify one tweet's text with the shared ArgQ model.

    Returns:
        A JSON object ``{"classification": <awaited model output>}``.
    """
    result = await model.classify_text(tweet.text)
    return dict(classification=result)
@app.post("/argq/classify/aspects")
async def get_text_classification_by_aspects(request: TextWithAspects):
    """Classify a tweet separately for each requested aspect.

    The model is awaited once per entry of ``request.aspects``, in order. If
    an aspect is repeated, the later result overwrites the earlier one.

    Returns:
        ``{"classification": {aspect: <model output>, ...}}``.
    """
    text = request.tweet.text
    per_aspect = {}
    for aspect in request.aspects:
        per_aspect[aspect] = await model.classify_text_by_aspect(text, aspect)
    return {"classification": per_aspect}
@app.post("/argq/feedback")
async def post_feedback(item: FeedbackItem):
    """Store a feedback item in the Firestore ``feedback`` collection.

    The timestamp is converted to an ISO-8601 string before storage, and the
    stored payload is echoed back to the caller.

    Returns:
        ``{"status": "success", "feedback_received": <stored dict>}``.
    """
    feedback_data = item.dict()
    feedback_data['timestamp'] = feedback_data['timestamp'].isoformat()
    # Reference to a new document; no ID is given, so Firestore assigns one.
    doc_ref = db.collection('feedback').document()
    # set() on the client from firestore.client() is a synchronous, blocking
    # network write. Calling it directly inside this async handler would
    # stall the event loop, and with it every concurrent request, for the
    # whole round-trip, so it runs in the worker threadpool instead.
    await run_in_threadpool(doc_ref.set, feedback_data)
    return {"status": "success", "feedback_received": feedback_data}
if __name__ == "__main__":
    # Serve on all interfaces; the port comes from $PORT (default 8000).
    # reload=True is deliberately not passed. uvicorn supports reload only
    # when the application is given as an import string ("module:app"); with
    # an app object it logs a warning and exits instead of serving.
    # For auto-reload during development, use the CLI:
    #   uvicorn <module>:app --reload
    uvicorn.run(app, host="0.0.0.0", port=int(getenv("PORT", 8000)))