# Hugging Face Space status when this file was captured: Runtime error
import uvicorn | |
from fastapi import FastAPI, Form, Request | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.templating import Jinja2Templates | |
import re | |
from inference import inferenceAPI | |
from inference_t5 import inferenceAPI_t5 | |
# ------ INFERENCE MODEL ------------------------------------------------------ | |
# appel de la fonction inference, adaptee pour une entree txt | |
def summarize(text: str):
    """Summarize ``text`` with the model selected through ``choisir_modele``.

    The selection is read from the ``choisir_modele.var`` function attribute.

    Args:
        text: Input text to summarize.

    Returns:
        For ``"lstm"``: the items returned by ``inferenceAPI`` joined with
        single spaces. For ``"fineTunedT5"``: the output of
        ``inferenceAPI_t5`` with every ``"<extra_id_0> "`` occurrence removed.
        ``None`` for any other stored value.

    Raises:
        RuntimeError: If no model has been selected yet.
    """
    # choisir_modele.var only exists once choisir_modele() has accepted a
    # choice; reading it directly would fail with an opaque AttributeError.
    selected = getattr(choisir_modele, "var", None)
    if selected is None:
        raise RuntimeError(
            "No model selected: call choisir_modele() before summarize()."
        )

    if selected == "lstm":
        return " ".join(inferenceAPI(text))
    if selected == "fineTunedT5":
        raw_summary = inferenceAPI_t5(text)
        return re.sub(r"<extra_id_0> ", "", raw_summary)
    return None
# ---------------------------------------------------------------------------- | |
def choisir_modele(choixModele):
    """Record which model later calls should use.

    The choice is stored on the function object itself as
    ``choisir_modele.var``. Only ``"lstm"`` and ``"fineTunedT5"`` are
    accepted; any other value leaves the stored choice untouched.
    """
    print("ON A RECUP LE CHOIX MODELE")
    for candidate in ("lstm", "fineTunedT5"):
        if choixModele == candidate:
            choisir_modele.var = candidate
            break
# -------- API --------------------------------------------------------------- | |
app = FastAPI()

# Expose the templates directory as static files so the CSS can be sent to
# the browser.
app.mount(
    "/templates",
    StaticFiles(directory="templates"),
    name="templates",
)

# Jinja2 template loader used by the handlers below to render pages.
templates = Jinja2Templates(directory="templates")
# NOTE(review): no route registration was visible for this handler; the path
# is taken from the ``current_route`` value it renders -- confirm.
@app.get("/")
async def index(request: Request):
    """Render ``index.html.jinja`` with ``current_route`` set to ``"/"``.

    Args:
        request: Incoming request, passed through to the template context.

    Returns:
        The rendered template response.
    """
    context = {"request": request, "current_route": "/"}
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route registration was visible for this handler; the path
# is taken from the ``current_route`` value it renders -- confirm.
@app.get("/model")
async def get_model(request: Request):
    """Render ``index.html.jinja`` with ``current_route`` set to ``"/model"``.

    Args:
        request: Incoming request, passed through to the template context.

    Returns:
        The rendered template response.
    """
    context = {"request": request, "current_route": "/model"}
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route registration was visible for this handler; the path
# is taken from the ``current_route`` value it renders -- confirm.
@app.get("/predict")
async def get_prediction(request: Request):
    """Render ``index.html.jinja`` with ``current_route`` set to ``"/predict"``.

    Args:
        request: Incoming request, passed through to the template context.

    Returns:
        The rendered template response.
    """
    context = {"request": request, "current_route": "/predict"}
    return templates.TemplateResponse("index.html.jinja", context)
# NOTE(review): no route registration was visible for this handler. POST is
# implied by the form field; the "/model" path is an assumption (same path as
# the model page) -- confirm against the form action in the template.
@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    """Handle the model-selection form.

    Args:
        request: Incoming request, passed through to the template context.
        choixModel: Submitted model name; ``None`` when the field is absent.

    Returns:
        The rendered ``index.html.jinja`` response. When no model was
        submitted, the context carries an error message under ``"text"``;
        otherwise the choice is handed to ``choisir_modele``.
    """
    print(choixModel)

    # Guard clause: an empty or missing field re-renders the page with an error.
    if not choixModel:
        erreur_modele = "Merci de saisir un modèle."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )

    choisir_modele(choixModel)
    print("C'est bon on utilise le modèle demandé")
    return templates.TemplateResponse("index.html.jinja", {"request": request})
# Return the text and its summary, or an error message when the form was
# submitted empty.
# NOTE(review): no route registration was visible for this handler. POST is
# implied by the form field; the "/predict" path is an assumption (same path
# as the prediction page) -- confirm against the form action in the template.
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Handle the text-submission form and render the summary.

    Args:
        request: Incoming request, passed through to the template context.
        text: Submitted text; ``None`` when the field is absent.

    Returns:
        The rendered ``index.html.jinja`` response. When no text was
        submitted, the context carries an error message under ``"text"``;
        otherwise it carries the submitted ``"text"`` and the ``"summary"``
        produced by ``summarize``.
    """
    # Guard clause: an empty or missing field re-renders the page with an error.
    if not text:
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )

    summary = summarize(text)
    context = {"request": request, "text": text, "summary": summary}
    return templates.TemplateResponse("index.html.jinja", context)
# ------------------------------------------------------------------------------------ | |
# Start the server and reload it on every saved modification.
if __name__ == "__main__":
    uvicorn.run(
        "api:app",
        port=8000,
        reload=True,
    )