MFBDA's picture
Update app.py
6ac2735 verified
raw
history blame
1.61 kB
import gradio as gr
from transformers import TimeSeriesTransformerForPrediction, TimeSeriesTransformerConfig
import pandas as pd
import numpy as np
# Carregar configuração do modelo
config = TimeSeriesTransformerConfig.from_pretrained("google/timesfm-2.0-500m-pytorch")
# Definir parâmetros obrigatórios
config.prediction_length = 3 # Períodos futuros a prever
config.context_length = 12 # Períodos históricos usados (ex: 12 meses)
# Carregar modelo com a configuração ajustada
model = TimeSeriesTransformerForPrediction.from_pretrained(
"google/timesfm-2.0-500m-pytorch",
config=config,
torch_dtype="auto"
)
def prever_vendas(historico):
# Converter entrada em lista de números
historico = [float(x) for x in historico.split(",")]
# Garantir que o histórico tem o tamanho do context_length
if len(historico) != config.context_length:
raise ValueError(f"Histórico deve ter {config.context_length} valores (context_length).")
# Preparar dados
data = pd.Series(historico)
# Gerar previsão
forecast = model.predict(data, prediction_length=config.prediction_length)
return np.round(forecast.mean, 2).tolist()
# Interface Gradio
iface = gr.Interface(
fn=prever_vendas,
inputs=gr.Textbox(label=f"Histórico de Vendas ({config.context_length} meses, separados por vírgulas)"),
outputs=gr.Textbox(label=f"Previsão para os Próximos {config.prediction_length} Meses"),
examples=[
["140,155,160,145,150,165,170,160,175,160,155,170"], # 12 meses (context_length=12)
]
)
iface.launch()