# Hugging Face Space by MFBDA — last commit d1d52f3 ("Update app.py", verified)
import gradio as gr
from transformers import TimeSeriesTransformerForPrediction, TimeSeriesTransformerConfig
import torch
import numpy as np
# Load the base configuration, then override the forecasting window.
# NOTE(review): "google/timesfm-2.0-500m-pytorch" appears to be a TimesFM
# checkpoint rather than a TimeSeriesTransformer one. Its weights presumably
# do not map onto this architecture, in which case from_pretrained below
# leaves the model effectively randomly initialised (watch for the "newly
# initialized" warning). Consider transformers' TimesFM model class, or a
# TimeSeriesTransformer checkpoint trained on comparable data — TODO confirm.
MODEL_ID = "google/timesfm-2.0-500m-pytorch"

config = TimeSeriesTransformerConfig.from_pretrained(MODEL_ID)
config.prediction_length = 3
config.context_length = 20
# Lags index *behind* the context window: the model expects past_values of
# length context_length + max(lags_sequence) (here 20 + 12 = 32), so lags do
# not need to be smaller than context_length.
config.lags_sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# TimeSeriesTransformerConfig derives feature_size from lags_sequence in its
# constructor; overriding lags_sequence afterwards leaves feature_size stale
# and the input embedding width would not match the lagged inputs. Recompute
# it the same way the constructor does.
config.feature_size = (
    config.input_size * len(config.lags_sequence) + config._number_of_features
)

# Load the model with the adjusted configuration.
model = TimeSeriesTransformerForPrediction.from_pretrained(
    MODEL_ID,
    config=config,
    torch_dtype="auto",
)
def prever_vendas(historico):
    """Forecast the next ``prediction_length`` values from a sales history.

    Args:
        historico: Comma-separated numbers, oldest first. At least
            ``context_length`` values are required. At most
            ``context_length + max(lags_sequence)`` of the most recent values
            are used; older ones are ignored.

    Returns:
        numpy array with ``prediction_length`` point forecasts (the mean of
        the model's sampled trajectories), rounded to 2 decimals.

    Raises:
        gr.Error: if the text contains non-numeric entries or too few values
            (gr.Error is what Gradio displays to the user).
    """
    cfg = model.config
    context_length = cfg.context_length
    # The model reads lagged values from before the context window, so every
    # "past_*" input must span context_length + max(lags_sequence) steps.
    past_length = context_length + max(cfg.lags_sequence)

    # Parse the comma-separated input, skipping empty entries.
    try:
        valores = [float(x) for x in historico.split(",") if x.strip()]
    except ValueError as err:
        raise gr.Error(f"Valor inválido no histórico: {err}") from err
    if len(valores) < context_length:
        raise gr.Error(f"Histórico deve ter pelo menos {context_length} valores.")

    # Keep only the history the model can use, then left-pad the missing lag
    # history with zeros and mark those positions as unobserved.
    valores = valores[-past_length:]
    n_pad = past_length - len(valores)

    # Match the model's dtype/device: torch_dtype="auto" may load non-float32.
    tensor_kwargs = {"dtype": model.dtype, "device": model.device}
    # NOTE(review): assumes a univariate series (input_size == 1) — confirm.
    past_values = torch.tensor([0.0] * n_pad + valores, **tensor_kwargs).unsqueeze(0)
    past_observed_mask = torch.ones_like(past_values)
    past_observed_mask[:, :n_pad] = 0

    # No calendar features can be derived from a bare list of numbers, so feed
    # zeros with whatever width the config declares (possibly 0).
    num_time_features = cfg.num_time_features
    past_time_features = torch.zeros(1, past_length, num_time_features, **tensor_kwargs)
    future_time_features = torch.zeros(
        1, cfg.prediction_length, num_time_features, **tensor_kwargs
    )

    # The plain forward pass only returns distribution parameters; generate()
    # samples future trajectories autoregressively.
    with torch.no_grad():
        outputs = model.generate(
            past_values=past_values,
            past_time_features=past_time_features,
            past_observed_mask=past_observed_mask,
            future_time_features=future_time_features,
        )
    # sequences: (batch, num_parallel_samples, prediction_length) — average the
    # samples into a point forecast for the single input series.
    forecast = outputs.sequences.mean(dim=1)[0]
    return np.round(forecast.float().cpu().tolist(), 2)
# Gradio interface: a comma-separated history in, the forecast out.
iface = gr.Interface(
    fn=prever_vendas,
    inputs=gr.Textbox(label=f"Histórico de Vendas ({config.context_length} meses, separados por vírgulas)"),
    outputs=gr.Textbox(label=f"Previsão para os Próximos {config.prediction_length} Meses"),
    examples=[
        ["140,155,160,145,150,165,170,160,175,160,155,170,180,190,200,210,220,230,240,250"],  # 20 months
    ],
    cache_examples=False,  # Disable example caching to avoid errors
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or tooling) does not block on launch().
if __name__ == "__main__":
    iface.launch()