TOPSInfosol commited on
Commit
5f8097d
·
verified ·
1 Parent(s): 96311cb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Iterable
3
+ import gradio as gr
4
+ from gradio.themes.base import Base
5
+ from gradio.themes.utils import colors, fonts, sizes
6
+ import time
7
+
8
+ import pandas as pd
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
+ import torch
13
+ from chronos import ChronosPipeline
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ class Seafoam(Base):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ primary_hue: colors.Color | str = colors.emerald,
22
+ secondary_hue: colors.Color | str = colors.blue,
23
+ neutral_hue: colors.Color | str = colors.blue,
24
+ spacing_size: sizes.Size | str = sizes.spacing_md,
25
+ radius_size: sizes.Size | str = sizes.radius_md,
26
+ text_size: sizes.Size | str = sizes.text_lg,
27
+ font: fonts.Font
28
+ | str
29
+ | Iterable[fonts.Font | str] = (
30
+ fonts.GoogleFont("Quicksand"),
31
+ "ui-sans-serif",
32
+ "sans-serif",
33
+ ),
34
+ font_mono: fonts.Font
35
+ | str
36
+ | Iterable[fonts.Font | str] = (
37
+ fonts.GoogleFont("IBM Plex Mono"),
38
+ "ui-monospace",
39
+ "monospace",
40
+ ),
41
+ ):
42
+ super().__init__(
43
+ primary_hue=primary_hue,
44
+ secondary_hue=secondary_hue,
45
+ neutral_hue=neutral_hue,
46
+ spacing_size=spacing_size,
47
+ radius_size=radius_size,
48
+ text_size=text_size,
49
+ font=font,
50
+ font_mono=font_mono,
51
+ )
52
+ super().set(
53
+ body_background_fill="repeating-linear-gradient(45deg, *primary_200, *primary_200 10px, *primary_50 10px, *primary_50 20px)",
54
+ body_background_fill_dark="repeating-linear-gradient(45deg, *primary_800, *primary_800 10px, *primary_900 10px, *primary_900 20px)",
55
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
56
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
57
+ button_primary_text_color="white",
58
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
59
+ slider_color="*secondary_300",
60
+ slider_color_dark="*secondary_600",
61
+ block_title_text_weight="600",
62
+ block_border_width="3px",
63
+ block_shadow="*shadow_drop_lg",
64
+ button_primary_shadow="*shadow_drop_lg",
65
+ button_large_padding="32px",
66
+ )
67
+
68
+ seafoam = Seafoam()
69
+
70
+
71
+ import numpy as np
72
+ import matplotlib.ticker as ticker
73
+
74
+ def process_data(csv_file):
75
+ try:
76
+ # Read the CSV file
77
+ df = pd.read_csv(csv_file.name)
78
+
79
+ df['date'] = pd.to_datetime(df['date'])
80
+ df['month'] = df['date'].dt.month
81
+ df['year'] = df['date'].dt.year
82
+
83
+ monthly_sales = df.groupby(['year', 'month'])['sold_qty'].sum().reset_index()
84
+ monthly_sales = monthly_sales.rename(columns={'year': 'year', 'month': 'month', 'sold_qty': 'y'})
85
+
86
+ pipeline = ChronosPipeline.from_pretrained(
87
+ "amazon/chronos-t5-base",
88
+ device_map="cpu",
89
+ torch_dtype=torch.float32,
90
+ )
91
+ context = torch.tensor(monthly_sales["y"])
92
+ prediction_length = 12
93
+ forecast = pipeline.predict(context, prediction_length)
94
+
95
+ # Prepare forecast data
96
+ forecast_index = range(len(monthly_sales), len(monthly_sales) + prediction_length)
97
+ low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
98
+
99
+ # Visualization
100
+ plt.figure(figsize=(30, 10))
101
+ plt.plot(monthly_sales["y"], color="royalblue", label="Historical Data", linewidth=2)
102
+ plt.plot(forecast_index, median, color="tomato", label="Median Forecast", linewidth=2)
103
+ plt.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
104
+ plt.title("Sales Forecasting Visualization", fontsize=16)
105
+ plt.xlabel("Months", fontsize=20)
106
+ plt.ylabel("Sold Qty", fontsize=20)
107
+
108
+ plt.xticks(fontsize=18)
109
+ plt.yticks(fontsize=18)
110
+
111
+ ax = plt.gca()
112
+ ax.xaxis.set_major_locator(ticker.MultipleLocator(3))
113
+ ax.yaxis.set_major_locator(ticker.MultipleLocator(5))
114
+ ax.grid(which='major', linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
115
+
116
+ plt.legend(fontsize=18)
117
+ plt.grid(linestyle='--', linewidth=1.2, color='gray', alpha=0.7)
118
+ plt.tight_layout()
119
+
120
+ return plt.gcf()
121
+
122
+ except Exception as e:
123
+ print(f"Error: {str(e)}")
124
+ return None
125
+
126
+ # Create Gradio interface
127
+ with gr.Blocks(theme=seafoam) as demo:
128
+ gr.Markdown("# Chronos Forecasting - Tops Infosolution Pvt. Ltd")
129
+ gr.Markdown("Upload a CSV file and click 'Forecast' to generate sales forecast for next 12 months .")
130
+
131
+ with gr.Row():
132
+ file_input = gr.File(label="Upload CSV File", file_types=[".csv"])
133
+
134
+ with gr.Row():
135
+ visualize_btn = gr.Button("Forecast", variant="primary")
136
+
137
+ with gr.Row():
138
+ plot_output = gr.Plot(label="Chronos Forecasting Visualization")
139
+
140
+ with gr.Row():
141
+ plot_output = gr.Plot(label="Chronos Forecasting Visualization")
142
+
143
+ visualize_btn.click(
144
+ fn=process_data,
145
+ inputs=[file_input],
146
+ outputs=[plot_output]
147
+ )
148
+
149
+ # Launch the app
150
+ if __name__ == "__main__":
151
+ demo.launch()