Spaces:
Runtime error
Runtime error
File size: 2,898 Bytes
2c35026 9cd8995 2c35026 ef05d9e 2c35026 3c03f61 9cd8995 ef05d9e 9cd8995 8dba466 3c03f61 2c35026 ef05d9e 3c03f61 ef05d9e 3c03f61 ef05d9e 2c35026 8dba466 2c35026 3c03f61 ef05d9e 3c03f61 ef05d9e 3c03f61 ef05d9e 3c03f61 ef05d9e 3c03f61 ef05d9e 3c03f61 ef05d9e 2c35026 ef05d9e 2c35026 8dba466 2c35026 8dba466 2c35026 8dba466 2c35026 70d598e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from fastapi import FastAPI, Form, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5
# ------ INFERENCE MODEL --------------------------------------------------------------
def summarize(text: str):
    """Summarize *text* with the model currently selected via ``choisir_modele``.

    The active model is read from the ``choisir_modele.var`` function
    attribute, which ``choisir_modele`` sets to ``"lstm"`` or ``"fineTunedT5"``.

    Args:
        text: Raw input text submitted by the user.

    Returns:
        For ``"lstm"``: the strings produced by ``inference_lstm`` joined with
        single spaces. For ``"fineTunedT5"``: the value returned by
        ``inference_t5`` (presumably the summary string — TODO confirm).

    Raises:
        ValueError: if no model has been selected yet, or an unknown one is set.
    """
    # getattr: the attribute only exists once a model has been chosen.
    model = getattr(choisir_modele, "var", None)
    if model == "lstm":
        return " ".join(inference_lstm(text))
    if model == "fineTunedT5":
        # Return the T5 output instead of discarding it in a local variable.
        return inference_t5(text)
    raise ValueError(
        f"No valid model selected (got {model!r}); choose one via /model first."
    )
# ----------------------------------------------------------------------------------
def choisir_modele(choixModele):
    """Select the model that ``summarize`` will use.

    The choice is stored on this function object as ``choisir_modele.var`` so
    it persists across requests. An unrecognised value leaves any previous
    selection unchanged.

    Args:
        choixModele: Model identifier from the form; expected to be
            ``"lstm"`` or ``"fineTunedT5"``.

    Returns:
        True if the identifier was recognised and selected, False otherwise.
        (The original returned None; callers ignoring the result are unaffected.)
    """
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele in ("lstm", "fineTunedT5"):
        choisir_modele.var = choixModele
        return True
    # Report the rejection; a bare string expression here would be a no-op.
    print("le modele n'est pas defini")
    return False
# -------- API ---------------------------------------------------------------------
# Jinja2 renderer for the HTML pages in the "templates" directory.
templates = Jinja2Templates(directory="templates")

# The ASGI application. The same "templates" directory is also exposed as
# static files under /templates so the browser can download the CSS.
app = FastAPI()
app.mount("/templates", app=StaticFiles(directory="templates"), name="templates")
@app.get("/predict")
@app.get("/model")
@app.get("/")
async def index(request: Request):
    """Render the main page with an empty form.

    One handler serves GET on "/", "/model" and "/predict". It is presumably
    registered on the form-POST paths so that a page reload there still shows
    the form. Decorators apply bottom-up, so "/" is registered first, matching
    the original route order. All three routes keep the route name "index".
    """
    return templates.TemplateResponse("index.html.jinja", {"request": request})
@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    """Handle the model-selection form and re-render the main page.

    When the form field is empty, the page is rendered with an error message
    in the "text" slot. Otherwise the choice is forwarded to ``choisir_modele``.
    """
    print(choixModel)
    context = {"request": request}
    if choixModel:
        choisir_modele(choixModel)
        print("C'est bon on utilise le modèle demandé")
    else:
        context["text"] = "Merci de saisir un modèle."
    return templates.TemplateResponse("index.html.jinja", context)
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Summarize the submitted text and render it together with its summary.

    An empty submission re-renders the page with an error message in the
    "text" slot instead of running the model.
    """
    if not text:
        error = "Merci de saisir votre texte."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": error}
        )
    # summarize() is a plain synchronous call into model inference. Calling it
    # directly inside this async endpoint would block the event loop (and all
    # other requests) for the whole inference, so run it in the threadpool.
    summary = await run_in_threadpool(summarize, text)
    return templates.TemplateResponse(
        "index.html.jinja", {"request": request, "text": text, "summary": summary}
    )
# ------------------------------------------------------------------------------------
# lancer le serveur et le recharger a chaque modification sauvegardee
# if __name__ == "__main__":
# uvicorn.run("api:app", port=8000, reload=True)
|