File size: 3,263 Bytes
46bc116
 
663d47c
46bc116
c221934
46bc116
 
663d47c
 
c221934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bc116
a2bec5b
c221934
8e804f7
 
 
 
c221934
 
8e804f7
 
 
c221934
4a05d26
8e804f7
c221934
 
4a05d26
c221934
 
 
 
a2bec5b
663d47c
 
c221934
 
 
663d47c
 
 
c221934
 
 
663d47c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bc116
663d47c
 
46bc116
663d47c
46bc116
663d47c
46bc116
663d47c
46bc116
 
663d47c
 
 
 
 
 
 
46bc116
 
663d47c
46bc116
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# app.py
import gradio as gr
import torch
import pandas as pd
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
import yfinance as yf

# Function to fetch and preprocess ICICI Bank data
def fetch_and_preprocess_data():
    """Download ICICI Bank daily prices and add moving-average columns.

    Fetches the BSE-listed symbol ``ICICIBANK.BO`` for 2020-01-01 to
    2023-01-01 via ``yf.download`` and appends two columns:

    * ``MA_50``  -- 50-row rolling mean of ``Close``
    * ``MA_200`` -- 200-row rolling mean of ``Close``

    Returns:
        The price DataFrame with the two extra columns, or an empty
        ``pd.DataFrame`` when the download fails or yields no rows.
        Errors are printed rather than raised (best-effort behaviour).
    """
    try:
        # Fetch ICICI Bank data using yfinance
        ticker = "ICICIBANK.BO"  # Use BSE symbol
        data = yf.download(ticker, start="2020-01-01", end="2023-01-01")

        if data.empty:
            raise ValueError("No data found for the given symbol.")

        # Some yfinance releases return MultiIndex columns of the form
        # (field, ticker) even for a single symbol. In that case
        # data['Close'] is a one-column DataFrame and callers that format
        # data['Close'].iloc[-1] as a float fail with TypeError. Only one
        # ticker is requested, so keeping just the field level is lossless.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        # Calculate technical indicators
        data['MA_50'] = data['Close'].rolling(window=50).mean()
        data['MA_200'] = data['Close'].rolling(window=200).mean()

        return data
    except Exception as e:
        # Deliberate catch-all boundary: any failure degrades to "no data".
        print(f"Error fetching data: {e}")
        return pd.DataFrame()  # Return an empty DataFrame if fetching fails

# Load the pretrained RAG tokenizer, retriever and model once at import time.
# On any failure all three globals are set to None so later code can test
# them without hitting an undefined or half-initialised name.
try:
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
    print("Tokenizer loaded successfully.")

    # Use a pre-built index (e.g., wiki_dpr) instead of creating a custom index
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-sequence-base",
        index_name="wiki_dpr",
        passages_path=None,
        index_path=None
    )
    print("Retriever loaded successfully.")

    # Load the RAG model
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error initializing model or retriever: {e}")
    # Reset every component, including the tokenizer, which may not have
    # been assigned at all if its own load was the step that failed.
    tokenizer = None
    retriever = None
    model = None

# Function to analyze trading data using the RAG model
def analyze_trading_data(question):
    """Answer a question about ICICI Bank stock with the module-level RAG model.

    Args:
        question: Text placed after ``"Question: "`` in the model prompt.

    Returns:
        The decoded first generated sequence, or an error-message string
        when the model/retriever is unavailable or no price data could be
        fetched.
    """
    # Bail out early when start-up loading did not produce both components.
    if any(component is None for component in (model, retriever)):
        return "Error: Model or retriever is not initialized. Please check the logs."

    prices = fetch_and_preprocess_data()
    if prices.empty:
        return "Error: No data available for analysis."

    # Latest row's values for the close and both moving averages.
    latest_close = prices['Close'].iloc[-1]
    latest_ma_50 = prices['MA_50'].iloc[-1]
    latest_ma_200 = prices['MA_200'].iloc[-1]

    summary_lines = [
        "ICICI Bank stock data:",
        f"Latest Close Price: {latest_close:.2f}",
        f"50-Day Moving Average: {latest_ma_50:.2f}",
        f"200-Day Moving Average: {latest_ma_200:.2f}",
    ]
    # Every line, including the last, is newline-terminated.
    context = "".join(line + "\n" for line in summary_lines)

    prompt = "Question: {}\nContext: {}".format(question, context)

    # Tokenize, generate from the input ids only, and decode the first result.
    encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    generated = model.generate(encoded['input_ids'])
    return tokenizer.decode(generated[0], skip_special_tokens=True)

# Gradio interface: one text input, one text output, backed by
# analyze_trading_data. The example questions are passed as `examples`.
_EXAMPLE_QUESTIONS = [
    "What is the current trend of ICICI Bank stock?",
    "Is the 50-day moving average above the 200-day moving average?",
    "What is the latest closing price of ICICI Bank?",
]

iface = gr.Interface(
    title="ICICI Bank Trading Analysis",
    description="Ask any question about ICICI Bank's trading data and get a detailed analysis.",
    fn=analyze_trading_data,
    inputs="text",
    outputs="text",
    examples=_EXAMPLE_QUESTIONS,
)

# Launch the app
iface.launch()