File size: 3,216 Bytes
31ac823
832da98
074307b
 
832da98
 
 
 
 
 
 
 
 
67781cf
074307b
 
 
 
 
 
31ac823
2e2dfb4
31ac823
 
8b5f94a
31ac823
 
 
7057428
67781cf
 
2e2dfb4
5b5151c
 
2e79a3d
7057428
5b5151c
2e79a3d
832da98
 
 
 
 
 
 
 
 
 
 
074307b
832da98
 
074307b
832da98
 
 
074307b
 
 
 
832da98
 
 
074307b
832da98
8b5f94a
67781cf
 
 
 
 
 
2e2dfb4
5b5151c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from datetime import datetime, date, timedelta
import plotly.graph_objects as go
import pandas as pd
from config import index_options, time_intervals, START_DATE, END_DATE, FORECAST_PERIOD
from data_fetcher import get_stocks_from_index
from stock_analysis import get_stock_graph_and_info

def validate_date(date_string):
    """Parse a ``YYYY-MM-DD`` string into a :class:`datetime.date`.

    Surrounding whitespace is ignored, so a value typed into a textbox with a
    stray leading/trailing space is still accepted.

    Args:
        date_string: The text to parse. ``None`` or any non-string value is
            treated as invalid instead of raising.

    Returns:
        The parsed ``date``, or ``None`` if the input is missing or does not
        match the ``%Y-%m-%d`` format (including impossible dates such as
        2023-02-29).
    """
    if not isinstance(date_string, str):
        # strptime raises TypeError (not ValueError) for None / non-str input,
        # which would otherwise escape the caller's "invalid date" handling.
        return None
    try:
        return datetime.strptime(date_string.strip(), "%Y-%m-%d").date()
    except ValueError:
        return None

def create_error_plot(error_message):
    """Build a blank Plotly figure that shows *error_message* in its centre.

    The annotation is positioned in paper coordinates (0..1 across the
    plotting area) and both axes are hidden. The text is therefore centred
    regardless of how Plotly autoranges an otherwise empty figure, and no
    meaningless ticks or grid lines are drawn behind it.

    Args:
        error_message: Text to display.

    Returns:
        A ``plotly.graph_objects.Figure`` titled "Error".
    """
    fig = go.Figure()
    fig.add_annotation(
        x=0.5,
        y=0.5,
        xref="paper",
        yref="paper",
        text=error_message,
        showarrow=False,
        font_size=20,
    )
    fig.update_layout(
        title="Error",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
    )
    return fig

with gr.Blocks() as demo:
    # --- Selection controls -------------------------------------------------
    # The stock dropdown (d2) starts with no choices; it is repopulated when
    # an index is picked in d1.
    d1 = gr.Dropdown(
        index_options,
        interactive=True,
        label='Please select Index...',
        info='Will be adding more indices later on',
    )
    d2 = gr.Dropdown(
        interactive=True,
        label='Please Select Stock from your selected index',
    )
    d3 = gr.Dropdown(
        time_intervals,
        interactive=True,
        label='Select Time Interval',
        value='1d',
    )
    d4 = gr.Radio(
        ['Line Graph', 'Candlestick Graph'],
        interactive=True,
        label='Select Graph Type',
        value='Line Graph',
    )
    d5 = gr.Dropdown(
        ['ARIMA', 'Prophet', 'LSTM'],
        interactive=True,
        label='Select Forecasting Method',
        value='ARIMA',
    )

    # --- Date range: free text, pre-filled from the configured defaults -----
    date_start = gr.Textbox(
        value=START_DATE.strftime("%Y-%m-%d"),
        label="Start Date (YYYY-MM-DD)",
    )
    date_end = gr.Textbox(
        value=END_DATE.strftime("%Y-%m-%d"),
        label="End Date (YYYY-MM-DD)",
    )

    # --- Outputs ------------------------------------------------------------
    out_graph = gr.Plot()
    out_fundamentals = gr.DataFrame()

    # The order of `inputs` must match the unpacking in process_inputs.
    inputs = [d1, d2, d3, d4, d5, date_start, date_end]
    outputs = [out_graph, out_fundamentals]

    def update_stock_options(index):
        """Refresh the stock dropdown for the newly selected *index*.

        Args:
            index: Value chosen in the index dropdown; may be ``None`` or
                empty if nothing is selected.

        Returns:
            A ``gr.Dropdown`` update carrying the new choices. The current
            value is explicitly cleared. Otherwise a ticker belonging to the
            previously selected index would stay selected and be processed
            under the new index.
        """
        if not index:
            # No index selected: empty the list instead of querying the
            # data fetcher with a missing index.
            return gr.Dropdown(choices=[], value=None)
        stocks = get_stocks_from_index(index)
        return gr.Dropdown(choices=stocks, value=None)

    def process_inputs(*args):
        """Validate the UI inputs and render the stock chart and fundamentals.

        Positional arguments arrive in the order of ``inputs``: index, stock,
        interval, graph type, forecasting method, start-date text and
        end-date text.

        Returns:
            A ``(figure, fundamentals)`` pair. On every failure path
            (missing selection, bad dates, or an error while building the
            chart) the figure is an error plot and the fundamentals table is
            an empty ``pd.DataFrame``. The output table is therefore always
            cleared the same way.
        """
        idx, stock, interval, graph_type, forecast_method, start_date, end_date = args

        def fail(message):
            # Single shape for every failure result.
            return create_error_plot(message), pd.DataFrame()

        # Several controls trigger this handler, and any of them can fire
        # before the user has picked both an index and a stock.
        if not idx or not stock:
            return fail("Please select an index and a stock.")

        start = validate_date(start_date)
        end = validate_date(end_date)
        if start is None or end is None:
            return fail("Invalid date format. Please use YYYY-MM-DD.")
        if start > end:
            return fail("Start date must be before end date.")

        try:
            fig, fundamentals = get_stock_graph_and_info(
                idx, stock, interval, graph_type, forecast_method, start, end
            )
            if not isinstance(fig, go.Figure):
                raise ValueError("Expected a plotly Figure object")
            if not isinstance(fundamentals, pd.DataFrame):
                raise ValueError("Expected a pandas DataFrame for fundamentals")
        except Exception as e:
            # UI boundary: show any failure to the user instead of crashing
            # the event handler.
            return fail(f"An error occurred: {e}")
        return fig, fundamentals

    # Choosing an index repopulates the stock list. Every other control
    # re-runs process_inputs over the full input set to redraw the outputs.
    d1.change(update_stock_options, d1, d2)
    for _trigger in (d2, d3, d4, d5, date_start, date_end):
        _trigger.change(process_inputs, inputs, outputs)

def _main():
    """Start the Gradio server for the ``demo`` app defined in this module."""
    demo.launch()


if __name__ == "__main__":
    _main()