File size: 1,695 Bytes
0380be0
 
2ff7669
 
 
896731f
2ff7669
 
896731f
0380be0
2ff7669
 
0380be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ff7669
6c6f69c
2ff7669
 
0380be0
6c6f69c
2ff7669
6c6f69c
c14b26c
8f698da
c480d50
b8c76b7
80acb4b
1ac9e6f
80acb4b
2ff7669
9edaeee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Dense
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
from typing import List

class InputData(BaseModel):
    data: List[float]  # Lista de caracter铆sticas num茅ricas (flotantes)

app = FastAPI()

# Funci贸n para construir el modelo manualmente
def build_model():
    model = Sequential(
        [
            InputLayer(
                input_shape=(2,), name="dense_2_input"
            ),  # Ajusta el tama帽o de entrada seg煤n tu modelo
            Dense(16, activation="relu", name="dense_2"),
            Dense(1, activation="sigmoid", name="dense_3"),
        ]
    )
    model.load_weights(
        "model.h5"
    )  # Aseg煤rate de que los nombres de las capas coincidan para que los pesos se carguen correctamente
    model.compile(
        loss="mean_squared_error", optimizer="adam", metrics=["binary_accuracy"]
    )
    return model


model = build_model()  # Construir el modelo al iniciar la aplicaci贸n


# Ruta de predicci贸n
@app.post("/predict/")
async def predict(data: InputData):
    print(f"Data: {data}")
    global model
    try:
        # Convertir la lista de entrada a un array de NumPy para la predicci贸n
        input_data = np.array(data.data).reshape(1, -1)  # Asumiendo que la entrada debe ser de forma (1, num_features)
        print(input_data)
        prediction = model.predict(input_data).round()
        #return {"prediction": prediction.tolist()}
        prediction = 9
        print(prediction)
        return {"prediction": prediction}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))