# Trading_App / app.py
# RizwanSajad's picture
# Update app.py
# c221934 verified
# raw
# history blame
# 4.16 kB
# app.py
import gradio as gr
import torch
import pandas as pd
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import Dataset
import yfinance as yf
import numpy as np
# Function to fetch and preprocess ICICI Bank data
def fetch_and_preprocess_data(ticker="ICICIBANK.BO", start="2020-01-01", end="2023-01-01"):
    """Download price history and add 50- and 200-row moving averages.

    Args:
        ticker: Yahoo Finance symbol to download. Defaults to the BSE
            symbol for ICICI Bank.
        start: Start date string passed to ``yf.download``.
        end: End date string passed to ``yf.download``.

    Returns:
        The downloaded DataFrame with ``MA_50`` and ``MA_200`` columns
        added, or an empty DataFrame when the download fails or yields
        no rows (the error is printed, not raised).
    """
    try:
        data = yf.download(ticker, start=start, end=end)
        if data.empty:
            raise ValueError("No data found for the given symbol.")

        # Some yfinance releases return (field, ticker) MultiIndex columns
        # even for a single symbol. In that case data['Close'] is a
        # DataFrame rather than a Series, which breaks the column
        # assignments below and the ':.2f' formatting done by callers.
        # Keep only the field level so the columns are plain names.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)

        # Technical indicators: rolling means of the closing price.
        # The first window-1 rows of each are NaN.
        close = data['Close']
        data['MA_50'] = close.rolling(window=50).mean()
        data['MA_200'] = close.rolling(window=200).mean()
        return data
    except Exception as e:
        # Deliberate best-effort boundary: callers test `.empty` instead
        # of handling exceptions.
        print(f"Error fetching data: {e}")
        return pd.DataFrame()  # Return an empty DataFrame if fetching fails
# Function to create and save a custom index for the retriever
def create_custom_index(ctx_encoder_name="facebook/dpr-ctx_encoder-multiset-base"):
    """Build the passage dataset and its FAISS index and save both to disk.

    Each row of the price data becomes one passage (``id``, ``title``,
    ``text``). Every passage is embedded with a DPR context encoder and the
    resulting ``embeddings`` column is what gets indexed: FAISS works on
    dense vectors, so the raw ``text`` column cannot be indexed directly.

    Args:
        ctx_encoder_name: Hugging Face checkpoint of the DPR context
            encoder used to embed the passages.
            NOTE(review): the default follows the Transformers
            "use your own knowledge dataset" RAG example; confirm it is
            the right pairing for the question encoder inside
            "facebook/rag-sequence-base".

    Returns:
        Tuple ``(dataset_path, index_path)`` of the saved dataset
        directory and the saved FAISS index file.

    Raises:
        ValueError: If no price data is available.
    """
    # Local import: only this function needs the DPR classes, and
    # transformers is already a dependency of this file.
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

    # Fetch and preprocess data
    data = fetch_and_preprocess_data()
    if data.empty:
        raise ValueError("No data available to create the index.")

    # Create a dataset for the retriever (one passage per row)
    row_count = len(data)
    dataset = Dataset.from_dict({
        "id": [str(i) for i in range(row_count)],
        "text": data.apply(
            lambda row: f"Date: {row.name}, Close: {row['Close']:.2f}, MA_50: {row['MA_50']:.2f}, MA_200: {row['MA_200']:.2f}",
            axis=1,
        ).tolist(),
        "title": [f"ICICI Bank Data {i}" for i in range(row_count)],
    })

    # Embed every passage so there is a vector column to index.
    ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_encoder_name)
    ctx_encoder.eval()

    def _embed(batch):
        encoded = ctx_tokenizer(
            batch["title"],
            batch["text"],
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        with torch.no_grad():
            pooled = ctx_encoder(**encoded, return_dict=True).pooler_output
        return {"embeddings": pooled.cpu().numpy()}

    dataset = dataset.map(_embed, batched=True, batch_size=16)

    # Save the dataset and index. The dataset is written before the index
    # is attached (same order as before); the index goes to its own file.
    dataset_path = "icici_bank_dataset"
    index_path = "icici_bank_index"
    dataset.save_to_disk(dataset_path)
    # NOTE(review): this uses the default FAISS metric. DPR embeddings are
    # usually compared by inner product; consider passing metric_type.
    dataset.add_faiss_index("embeddings")
    dataset.get_index("embeddings").save(index_path)
    return dataset_path, index_path
# Load the RAG tokenizer, then try to build the retriever and model on top
# of the custom index. If any step fails, the error is printed and both
# `retriever` and `model` are left as None so the app can still start.
_RAG_CHECKPOINT = "facebook/rag-sequence-base"

tokenizer = RagTokenizer.from_pretrained(_RAG_CHECKPOINT)

retriever = None
model = None
try:
    # Build the passages dataset and FAISS index on disk
    dataset_path, index_path = create_custom_index()

    # Retriever backed by the custom index created above
    retriever = RagRetriever.from_pretrained(
        _RAG_CHECKPOINT,
        index_name="custom",
        passages_path=dataset_path,
        index_path=index_path,
    )

    # Generator model wired to that retriever
    model = RagSequenceForGeneration.from_pretrained(
        _RAG_CHECKPOINT,
        retriever=retriever,
    )
except Exception as e:
    print(f"Error initializing model or retriever: {e}")
    # Reset both: the retriever may already be set if only the model failed.
    retriever = None
    model = None
# Function to analyze trading data using the RAG model
def analyze_trading_data(question):
    """Answer a question about the stock data using the RAG model.

    Args:
        question: The user's question as text.

    Returns:
        The generated answer, or a string starting with ``"Error:"`` when
        the question is blank, the model or retriever failed to
        initialise, or no price data could be fetched.
    """
    # Reject blank input first so no download or generation is attempted
    # for a question that carries no content.
    if question is None or not str(question).strip():
        return "Error: Please enter a question."

    if model is None or retriever is None:
        return "Error: Model or retriever is not initialized. Please check the logs."

    # Fetch and preprocess data
    data = fetch_and_preprocess_data()
    if data.empty:
        return "Error: No data available for analysis."

    # Prepare context for the RAG model from the most recent row
    context = (
        "ICICI Bank stock data:\n"
        f"Latest Close Price: {data['Close'].iloc[-1]:.2f}\n"
        f"50-Day Moving Average: {data['MA_50'].iloc[-1]:.2f}\n"
        f"200-Day Moving Average: {data['MA_200'].iloc[-1]:.2f}\n"
    )

    # Combine question and context
    input_text = f"Question: {question}\nContext: {context}"

    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

    # Generate the answer using the RAG model
    outputs = model.generate(inputs['input_ids'])

    # Decode the first generated sequence to get the answer
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer
# Gradio interface: one text box in, one text box out, backed by
# analyze_trading_data, with a few example questions.
_example_questions = [
    "What is the current trend of ICICI Bank stock?",
    "Is the 50-day moving average above the 200-day moving average?",
    "What is the latest closing price of ICICI Bank?",
]

_interface_options = {
    "fn": analyze_trading_data,
    "inputs": "text",
    "outputs": "text",
    "title": "ICICI Bank Trading Analysis",
    "description": "Ask any question about ICICI Bank's trading data and get a detailed analysis.",
    "examples": _example_questions,
}

iface = gr.Interface(**_interface_options)

# Launch the app
iface.launch()