File size: 4,155 Bytes
46bc116
 
663d47c
46bc116
c221934
a2bec5b
46bc116
a2bec5b
46bc116
663d47c
 
c221934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bc116
a2bec5b
 
 
 
 
c221934
 
 
a2bec5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c221934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2bec5b
663d47c
 
c221934
 
 
663d47c
 
 
c221934
 
 
663d47c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46bc116
663d47c
 
46bc116
663d47c
46bc116
663d47c
46bc116
663d47c
46bc116
 
663d47c
 
 
 
 
 
 
46bc116
 
663d47c
46bc116
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# app.py
import gradio as gr
import torch
import pandas as pd
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import Dataset
import yfinance as yf
import numpy as np

def fetch_and_preprocess_data(ticker="ICICIBANK.BO", start="2020-01-01", end="2023-01-01"):
    """Download daily price history for *ticker* and add moving-average columns.

    Args:
        ticker: Yahoo Finance symbol. The default is ICICI Bank's BSE listing.
        start: Start date passed to ``yf.download``.
        end: End date passed to ``yf.download``.

    Returns:
        The downloaded DataFrame plus two extra columns: ``MA_50`` and
        ``MA_200``. These are the 50- and 200-row rolling means of ``Close``,
        so the first 49 / 199 rows are NaN. On any failure the error is
        printed and an empty DataFrame is returned; callers must check
        ``.empty``.
    """
    try:
        data = yf.download(ticker, start=start, end=end)

        if data.empty:
            raise ValueError("No data found for the given symbol.")

        # Newer yfinance versions can return (field, ticker) MultiIndex columns
        # even for a single symbol. Flatten them so data["Close"] is a Series
        # and the MA columns below are ordinary single-level columns.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        # Technical indicators.
        close = data["Close"]
        data["MA_50"] = close.rolling(window=50).mean()
        data["MA_200"] = close.rolling(window=200).mean()

        return data
    except Exception as e:
        # Deliberate best-effort: any download/parse failure degrades to an
        # empty frame, which every caller in this file checks for.
        print(f"Error fetching data: {e}")
        return pd.DataFrame()

def create_custom_index(
    dataset_path="icici_bank_dataset",
    index_path="icici_bank_index",
    ctx_encoder_name="facebook/dpr-ctx_encoder-single-nq-base",
    batch_size=16,
):
    """Build and save the passage dataset and FAISS index for ``RagRetriever``.

    Each row of the fetched market data becomes one passage. Steps:

    1. Embed every passage with a DPR context encoder into an
       ``embeddings`` column.
    2. Save the dataset to *dataset_path*.
    3. Save a FAISS index over ``embeddings`` to *index_path*.

    ``RagRetriever(index_name="custom")`` loads a dataset with ``title``,
    ``text`` and ``embeddings`` columns plus an index over ``embeddings``.
    The original code indexed the raw ``text`` strings instead, which FAISS
    cannot index.

    Args:
        dataset_path: Directory the passage dataset is saved to.
        index_path: File the FAISS index is saved to.
        ctx_encoder_name: DPR context-encoder checkpoint used for embeddings.
            The default is intended to pair with the single-nq DPR question
            encoder that ``facebook/rag-sequence-base`` is built from. Confirm
            this pairing if the RAG checkpoint changes.
        batch_size: Number of passages embedded per forward pass.

    Returns:
        ``(dataset_path, index_path)``.

    Raises:
        ValueError: If no market data could be fetched.
    """
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

    data = fetch_and_preprocess_data()
    if data.empty:
        raise ValueError("No data available to create the index.")

    # One passage per row: the index label (date) plus close and moving averages.
    texts = [
        f"Date: {date}, Close: {close:.2f}, MA_50: {ma_50:.2f}, MA_200: {ma_200:.2f}"
        for date, close, ma_50, ma_200 in zip(
            data.index, data["Close"], data["MA_50"], data["MA_200"]
        )
    ]
    dataset = Dataset.from_dict({
        "id": [str(i) for i in range(len(texts))],
        "text": texts,
        "title": [f"ICICI Bank Data {i}" for i in range(len(texts))],
    })

    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_encoder_name)
    ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
    ctx_encoder.eval()

    def _embed_batch(batch):
        """Return DPR pooled embeddings for a batch of (title, text) passages."""
        encoded = ctx_tokenizer(
            batch["title"],
            batch["text"],
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        with torch.no_grad():
            vectors = ctx_encoder(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
            ).pooler_output
        return {"embeddings": vectors.cpu().numpy()}

    dataset = dataset.map(_embed_batch, batched=True, batch_size=batch_size)

    # Save before attaching the index: `datasets` refuses to save_to_disk a
    # dataset that still has an index attached.
    dataset.save_to_disk(dataset_path)
    dataset.add_faiss_index("embeddings")
    dataset.get_index("embeddings").save(index_path)

    return dataset_path, index_path

# Load the RAG-Sequence checkpoint, its tokenizer, and a retriever backed by the
# custom ICICI Bank index. Nothing in this file fine-tunes the model; it is used
# as published. Any failure is printed and leaves the failed globals as None, so
# analyze_trading_data can report the problem instead of the app crashing at
# startup.
RAG_MODEL_NAME = "facebook/rag-sequence-base"

tokenizer = None
retriever = None
model = None

try:
    # Loaded inside the try so a download failure degrades like the rest of
    # the initialization instead of aborting the whole app.
    tokenizer = RagTokenizer.from_pretrained(RAG_MODEL_NAME)

    # Build the passages + FAISS index on disk, then point the retriever at them.
    dataset_path, index_path = create_custom_index()

    retriever = RagRetriever.from_pretrained(
        RAG_MODEL_NAME,
        index_name="custom",
        passages_path=dataset_path,
        index_path=index_path,
    )

    model = RagSequenceForGeneration.from_pretrained(RAG_MODEL_NAME, retriever=retriever)
except Exception as e:
    print(f"Error initializing model or retriever: {e}")
    retriever = None
    model = None

def analyze_trading_data(question):
    """Answer *question* about ICICI Bank stock with the RAG model.

    The latest close and the 50/200-day moving averages are freshly fetched
    and appended to the question as a context block before generation.

    Args:
        question: Free-text question from the Gradio text box (may be empty
            or None).

    Returns:
        The decoded model answer. Returns a string starting with ``"Error:"``
        when the model/retriever failed to initialize, the question is blank,
        no market data could be fetched, or tokenization/generation raises.
    """
    if model is None or retriever is None:
        return "Error: Model or retriever is not initialized. Please check the logs."

    # Generating from an empty question only yields noise; ask for input instead.
    if not question or not question.strip():
        return "Error: Please enter a question."

    data = fetch_and_preprocess_data()
    if data.empty:
        return "Error: No data available for analysis."

    context = (
        f"ICICI Bank stock data:\n"
        f"Latest Close Price: {data['Close'].iloc[-1]:.2f}\n"
        f"50-Day Moving Average: {data['MA_50'].iloc[-1]:.2f}\n"
        f"200-Day Moving Average: {data['MA_200'].iloc[-1]:.2f}\n"
    )
    input_text = f"Question: {question}\nContext: {context}"

    try:
        inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Report failures in the same "Error: ..." form as the other paths
        # rather than letting Gradio show an opaque traceback.
        print(f"Error generating answer: {e}")
        return f"Error: Failed to generate an answer ({e})."

    return answer

# Sample questions shown under the text box; clicking one fills it in.
_EXAMPLE_QUESTIONS = [
    "What is the current trend of ICICI Bank stock?",
    "Is the 50-day moving average above the 200-day moving average?",
    "What is the latest closing price of ICICI Bank?",
]

# Text-in / text-out UI wired to the RAG question-answering function.
iface = gr.Interface(
    title="ICICI Bank Trading Analysis",
    description=(
        "Ask any question about ICICI Bank's trading data "
        "and get a detailed analysis."
    ),
    fn=analyze_trading_data,
    inputs="text",
    outputs="text",
    examples=_EXAMPLE_QUESTIONS,
)

# Start serving the app.
iface.launch()