Spaces:
Runtime error
Runtime error
File size: 5,097 Bytes
7315e4e bde4352 3c3e49f bde4352 7315e4e 3c3e49f 7315e4e bde4352 7315e4e bde4352 7315e4e 3c3e49f bde4352 7315e4e bde4352 3c3e49f bde4352 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f 7315e4e 3c3e49f bde4352 3c3e49f bde4352 7315e4e bde4352 7315e4e bde4352 7315e4e bde4352 3c3e49f 7315e4e bde4352 7315e4e 3c3e49f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import re
import uvicorn
from fastapi import FastAPI, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from src.inference import inferenceAPI
from src.inference_t5 import inferenceAPI_t5
def summarize(text: str) -> str:
    """
    Returns the summary of an input text.

    The model used is the one currently recorded on the
    ``global_choose_model.var`` function attribute.

    Parameter
    ---------
    text : str
        A text to summarize.

    Returns
    -------
    :str
        The summary of the input text, or a short notice when no
        model has been chosen.
    """
    # The attribute only exists once global_choose_model() has stored a
    # choice; default to "" so an early call returns the notice below
    # instead of raising AttributeError.
    chosen = getattr(global_choose_model, "var", "")

    if chosen == "lstm":
        # The LSTM inference result is joined with single spaces, then a
        # leading/trailing "1" and the <start>/<end> markers are removed.
        joined = " ".join(inferenceAPI(text))
        return re.sub("^1|1$|<start>|<end>", "", joined)

    if chosen == "fineTunedT5":
        # Remove the "<extra_id_0> " marker from the T5 output.
        return re.sub("<extra_id_0> ", "", inferenceAPI_t5(text))

    # No (or an unrecognised) model recorded: always return a string
    # rather than falling through to an implicit None.
    return "You have not chosen a model."
def global_choose_model(model_choice):
    """Record the chosen model so that other functions can read it.

    The choice is stored on the function itself, as the attribute
    ``global_choose_model.var``. The aim is to access the value outside
    of this function without passing it around.

    Parameter
    ---------
    model_choice : str
        ``"lstm"`` or ``"fineTunedT5"`` to select that model, or
        ``" --- "`` to clear the selection (stored as ``""``).
        Any other value leaves the stored choice unchanged.
    """
    if model_choice in ("lstm", "fineTunedT5"):
        global_choose_model.var = model_choice
    elif model_choice == " --- ":
        global_choose_model.var = ""


# Start with no model selected, so that reading ``global_choose_model.var``
# before the first call does not raise AttributeError.
global_choose_model.var = ""
# Module-level values shared by the endpoints.
# Each entry pairs a "model" identifier (the values global_choose_model
# compares against) with a "name" (presumably the label shown to the user).
model_list = [
    {"model": identifier, "name": label}
    for identifier, label in (
        (" --- ", " --- "),
        ("lstm", "LSTM"),
        ("fineTunedT5", "Fine-tuned T5"),
    )
]

# Value handed to the templates as "selected_model"; " --- " is the placeholder entry.
selected_model = " --- "

# Empty by default; not read by any function visible in this file.
model_choice = ""
# -------- API ---------------------------------------------------------------
# Application object on which every route below is registered.
app = FastAPI()

# Template loader used by the endpoints to render pages from "templates".
templates = Jinja2Templates(directory="templates")

# The same directory is also served as static files under /templates
# (this is how the css is sent).
app.mount(
    path="/templates",
    app=StaticFiles(directory="templates"),
    name="templates",
)
@app.get("/")
async def index(request: Request):
    """Endpoint for the index page of the app.

    Renders ``index.html.jinja`` with the current route, the list of
    models and the module-level selected model.
    """
    page_context = {
        "request": request,
        "current_route": "/",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.get("/model")
async def get_model(request: Request):
    """Endpoint for the model page of the app.

    Renders ``index.html.jinja`` with ``current_route`` set to
    ``"/model"``, the list of models and the module-level selected model.
    """
    page_context = {
        "request": request,
        "current_route": "/model",
        "model_list": model_list,
        "selected_model": selected_model,
    }
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.get("/predict")
async def get_prediction(request: Request):
    """Endpoint for the predict page of the app.

    Renders ``index.html.jinja`` with ``current_route`` set to
    ``"/predict"``. Only the request and the route are passed to the
    template here.
    """
    page_context = {"request": request, "current_route": "/predict"}
    return templates.TemplateResponse("index.html.jinja", page_context)
@app.post("/model")
async def choose_model(request: Request, model_choice: str = Form(None)):
    """Retrieve the model chosen by the user.

    If no model was submitted, the page is rendered with an error
    message. Otherwise the choice is remembered at module level and
    passed to ``global_choose_model``, which connects the user's choice
    to the model used for summarizing.

    Parameters
    ----------
    request : Request
        The incoming request, forwarded to the template.
    model_choice : str
        The ``model_choice`` form field; ``None`` when it was not sent.
    """
    # Without this declaration the assignment below would only bind a
    # local name, and the other endpoints would keep rendering the
    # initial " --- " as the selected model.
    global selected_model

    if not model_choice:
        # Nothing submitted: keep the stored selection untouched and
        # render the error with the (empty) submitted value.
        return templates.TemplateResponse(
            "index.html.jinja",
            {
                "request": request,
                "text": "Please select a model.",
                "model_list": model_list,
                "selected_model": model_choice,
            },
        )

    selected_model = model_choice
    global_choose_model(model_choice)
    return templates.TemplateResponse(
        "index.html.jinja",
        {
            "request": request,
            "model_list": model_list,
            "selected_model": selected_model,
        },
    )
@app.post("/predict")
async def prediction(request: Request, text: str = Form(None)):
    """Retrieve the text entered by the user and summarize it.

    When no text was submitted the page is rendered with an error
    message in ``text``; otherwise the text is passed to ``summarize``
    and both the text and its summary are rendered.
    """
    if text:
        generated = summarize(text)
        page_context = {
            "request": request,
            "text": text,
            "summary": generated,
            "model_list": model_list,
            "selected_model": selected_model,
        }
    else:
        # Missing or empty form field: show the prompt instead of a summary.
        page_context = {
            "request": request,
            "text": "Please enter your text.",
            "model_list": model_list,
            "selected_model": selected_model,
        }
    return templates.TemplateResponse("index.html.jinja", page_context)
# ------------------------------------------------------------------------------------
# When executed as a script, start the server on port 8000 with reload
# enabled so that it restarts each time a change is saved.
# NOTE(review): "api:app" names the `app` object of a module called `api`;
# presumably this file is api.py - confirm.
if __name__ == "__main__":
    server_options = {"port": 8000, "reload": True}
    uvicorn.run("api:app", **server_options)
|