File size: 2,898 Bytes
2c35026
 
 
 
9cd8995
 
2c35026
 
ef05d9e
2c35026
 
3c03f61
9cd8995
ef05d9e
9cd8995
8dba466
3c03f61
2c35026
 
 
ef05d9e
 
3c03f61
 
ef05d9e
 
3c03f61
ef05d9e
 
 
2c35026
 
 
 
 
 
 
8dba466
2c35026
 
 
 
3c03f61
ef05d9e
 
 
 
3c03f61
ef05d9e
 
 
 
 
 
3c03f61
ef05d9e
 
 
 
3c03f61
ef05d9e
3c03f61
ef05d9e
 
3c03f61
ef05d9e
 
2c35026
ef05d9e
2c35026
8dba466
2c35026
 
8dba466
 
 
2c35026
 
 
 
8dba466
 
2c35026
 
 
 
70d598e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates

from src.inference_lstm import inference_lstm
from src.inference_t5 import inference_t5


# ------ INFERENCE MODEL --------------------------------------------------------------
# Run inference on a plain-text input with the model chosen via choisir_modele().
def summarize(text: str):
    """Summarize ``text`` with the model currently selected by ``choisir_modele``.

    The selection is read from the function attribute ``choisir_modele.var``,
    which ``choisir_modele`` sets to either "lstm" or "fineTunedT5".

    Args:
        text: Raw input text to summarize.

    Returns:
        For "lstm": the output of ``inference_lstm`` joined with single spaces
        (presumably a sequence of tokens/words — confirm against
        src.inference_lstm). For "fineTunedT5": the output of ``inference_t5``
        returned as-is.

    Raises:
        ValueError: if no model has been selected yet, or the stored selection
            is not one of the supported names.
    """
    # The attribute only exists once choisir_modele() has been called, so read
    # it with a default instead of letting an AttributeError escape.
    model = getattr(choisir_modele, "var", None)
    if model == "lstm":
        return " ".join(inference_lstm(text))
    if model == "fineTunedT5":
        # Return the T5 result; it must not be discarded into a local variable.
        return inference_t5(text)
    raise ValueError(
        f"No summarization model selected (got {model!r}); "
        "choose one via POST /model first."
    )


# ----------------------------------------------------------------------------------


def choisir_modele(choixModele):
    """Record which summarization model ``summarize`` should use.

    The choice is stored on the function object itself as
    ``choisir_modele.var``. This is module-level state, so every request served
    by this process sees the same selection.

    Args:
        choixModele: Model name. "lstm" and "fineTunedT5" are accepted; any
            other value leaves the previous selection unchanged and prints a
            warning.
    """
    print("ON A RECUP LE CHOIX MODELE")
    if choixModele in ("lstm", "fineTunedT5"):
        choisir_modele.var = choixModele
    else:
        # Emit the warning; a bare string expression here would be a no-op.
        print("le modele n'est pas defini")


# -------- API ---------------------------------------------------------------------
# Jinja2 renderer for the HTML pages stored in ./templates.
templates = Jinja2Templates(directory="templates")

app = FastAPI()
# Serve the same directory as static files under /templates so the browser can
# fetch the CSS used by the pages.
# NOTE(review): this exposes every file in ./templates, including the raw
# .jinja template sources — consider a dedicated static/ directory.
app.mount("/templates", StaticFiles(directory="templates"), name="templates")


# All three GET routes render the same blank page, so one handler serves them.
# Decorators are applied bottom-up: routes register in the order "/",
# "/model", "/predict", matching the original three separate definitions.
@app.get("/predict")
@app.get("/model")
@app.get("/")
async def index(request: Request):
    """Render the empty index page.

    Also bound to ``/model`` and ``/predict``, whose forms are handled by POST
    routes, so a plain GET on those URLs (presumably e.g. a browser reload
    after submitting a form — confirm) renders the page instead of returning
    405 Method Not Allowed.

    Args:
        request: Incoming request, required by the template response.

    Returns:
        The rendered ``index.html.jinja`` template with no extra context.
    """
    return templates.TemplateResponse("index.html.jinja", {"request": request})


@app.post("/model")
async def choix_model(request: Request, choixModel: str = Form(None)):
    """Handle the model-selection form and record the chosen model.

    Args:
        request: Incoming request, forwarded to the template.
        choixModel: Form field holding the model name; ``None`` or empty when
            the form was submitted without a choice.

    Returns:
        The index page. When no model or an unsupported model was submitted,
        the page receives an error message in ``text`` and the current
        selection is left unchanged.
    """
    print(choixModel)
    if not choixModel:
        erreur_modele = "Merci de saisir un modèle."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    # Only names summarize() can handle are accepted. choisir_modele() ignores
    # anything else, so without this check an invalid value would be silently
    # dropped while this handler reported success.
    if choixModel not in ("lstm", "fineTunedT5"):
        erreur_modele = "Modèle inconnu, merci de choisir un modèle valide."
        return templates.TemplateResponse(
            "index.html.jinja", {"request": request, "text": erreur_modele}
        )
    choisir_modele(choixModel)
    print("C'est bon on utilise le modèle demandé")
    return templates.TemplateResponse("index.html.jinja", {"request": request})


# Render the submitted text with its summary, or an error message when the form
# was submitted empty.
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Summarize the submitted text and render it back on the index page.

    Args:
        request: Incoming request, forwarded to the template.
        text: Form field with the text to summarize; ``None`` or empty when
            the form was submitted blank.

    Returns:
        The index page. Its context holds ``text`` and ``summary`` on success;
        on an empty submission, ``text`` holds the error message and there is
        no ``summary``.
    """
    context = {"request": request}
    if text:
        context["text"] = text
        context["summary"] = summarize(text)
    else:
        # The error message is shown in the same slot as the input text.
        context["text"] = "Merci de saisir votre texte."
    return templates.TemplateResponse("index.html.jinja", context)


# ------------------------------------------------------------------------------------


# lancer le serveur et le recharger à chaque modification sauvegardée (nécessite `import uvicorn`)
# if __name__ == "__main__":
#     uvicorn.run("api:app", port=8000, reload=True)