Spaces:
Sleeping
Sleeping
# app.py | |
import gradio as gr | |
import torch | |
import pandas as pd | |
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration | |
from datasets import Dataset | |
import yfinance as yf | |
import numpy as np | |
# Function to fetch and preprocess ICICI Bank data | |
def fetch_and_preprocess_data(ticker="ICICIBANK.BO", start="2020-01-01", end="2023-01-01"):
    """Download daily price history for *ticker* and add moving-average columns.

    Args:
        ticker: Yahoo Finance symbol; defaults to ICICI Bank's BSE listing.
        start: Start date string passed through to ``yf.download``.
        end: End date string passed through to ``yf.download``.

    Returns:
        The downloaded DataFrame with two extra columns, ``MA_50`` and ``MA_200``
        (50- and 200-row rolling means of ``Close``; the first window-1 rows are
        NaN). On any failure, or when no rows come back, the error is printed and
        an empty DataFrame is returned instead of raising.
    """
    try:
        data = yf.download(ticker, start=start, end=end)
        if data.empty:
            raise ValueError("No data found for the given symbol.")
        # Newer yfinance releases return (field, ticker) MultiIndex columns even
        # for a single symbol; flatten to plain field names so that
        # data['Close'] is a Series rather than a one-column DataFrame.
        if isinstance(data.columns, pd.MultiIndex):
            data.columns = data.columns.get_level_values(0)
        # Calculate technical indicators (rolling windows count trading rows).
        data['MA_50'] = data['Close'].rolling(window=50).mean()
        data['MA_200'] = data['Close'].rolling(window=200).mean()
        return data
    except Exception as e:
        # Best-effort fetch: callers check `.empty` rather than catching.
        print(f"Error fetching data: {e}")
        return pd.DataFrame()  # Return an empty DataFrame if fetching fails
# Function to create and save a custom index for the retriever | |
def _build_passages(data):
    """Return id/text/title column lists with one RAG passage per row of *data*."""
    ids, texts, titles = [], [], []
    for i, (date, row) in enumerate(data.iterrows()):
        ids.append(str(i))
        # MA values are NaN for the first rows of the series and render as "nan".
        texts.append(
            f"Date: {date}, Close: {row['Close']:.2f}, "
            f"MA_50: {row['MA_50']:.2f}, MA_200: {row['MA_200']:.2f}"
        )
        titles.append(f"ICICI Bank Data {i}")
    return {"id": ids, "text": texts, "title": titles}


def _add_dpr_embeddings(dataset, ctx_encoder_name, batch_size=16):
    """Return *dataset* with an ``embeddings`` column of DPR context vectors."""
    from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast

    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(ctx_encoder_name)
    ctx_encoder = DPRContextEncoder.from_pretrained(ctx_encoder_name)
    ctx_encoder.eval()

    def embed(batch):
        # Title and text are encoded as a pair, as in the Transformers RAG
        # custom-knowledge example.
        encoded = ctx_tokenizer(
            batch["title"],
            batch["text"],
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        with torch.no_grad():
            vectors = ctx_encoder(
                encoded["input_ids"], attention_mask=encoded["attention_mask"]
            ).pooler_output
        return {"embeddings": vectors.cpu().numpy()}

    return dataset.map(embed, batched=True, batch_size=batch_size)


def create_custom_index(
    dataset_path="icici_bank_dataset",
    index_path="icici_bank_index",
    ctx_encoder_name="facebook/dpr-ctx_encoder-single-nq-base",
):
    """Build and save the passage dataset and FAISS index used by the retriever.

    Args:
        dataset_path: Directory the passages dataset is saved to.
        index_path: File the FAISS index over the ``embeddings`` column is saved to.
        ctx_encoder_name: DPR context encoder used to embed passages.
            NOTE(review): assumed to pair with rag-sequence-base's DPR question
            encoder (single-nq); confirm if the RAG checkpoint changes.

    Returns:
        ``(dataset_path, index_path)``, suitable for
        ``RagRetriever.from_pretrained(index_name="custom", ...)``.

    Raises:
        ValueError: if no price data could be fetched.
    """
    data = fetch_and_preprocess_data()
    if data.empty:
        raise ValueError("No data available to create the index.")

    dataset = Dataset.from_dict(_build_passages(data))
    # FAISS indexes numeric vectors, not strings: embed every passage first.
    dataset = _add_dpr_embeddings(dataset, ctx_encoder_name)

    # Save before attaching the index; datasets refuses to save_to_disk a
    # dataset that still has an index attached.
    dataset.save_to_disk(dataset_path)
    dataset.add_faiss_index(column="embeddings")
    dataset.get_index("embeddings").save(index_path)
    return dataset_path, index_path
# Load the pretrained RAG base model and tokenizer (facebook/rag-sequence-base is not fine-tuned on this data)
import traceback

RAG_MODEL_NAME = "facebook/rag-sequence-base"

tokenizer = RagTokenizer.from_pretrained(RAG_MODEL_NAME)
try:
    # Build the passage dataset + FAISS index from the price data, then point
    # the retriever at those on-disk artifacts via the "custom" index type.
    dataset_path, index_path = create_custom_index()
    retriever = RagRetriever.from_pretrained(
        RAG_MODEL_NAME,
        index_name="custom",
        passages_path=dataset_path,
        index_path=index_path,
    )
    model = RagSequenceForGeneration.from_pretrained(RAG_MODEL_NAME, retriever=retriever)
except Exception as e:
    # Keep the UI running with a clear error message instead of crashing at
    # import, but print the full traceback: analyze_trading_data tells the user
    # to "check the logs", and a one-line message alone is not enough to debug.
    print(f"Error initializing model or retriever: {e}")
    traceback.print_exc()
    retriever = None
    model = None
# Function to analyze trading data using the RAG model | |
def analyze_trading_data(question):
    """Answer a free-text *question* about ICICI Bank with the RAG model.

    The question is combined with a short context built from the latest
    close and moving averages. The combined text is tokenized, and the
    generator's first output sequence is decoded.

    Returns:
        The generated answer string, or an ``"Error: ..."`` string when the
        model is unavailable, the question is blank, or no data was fetched.
    """
    if model is None or retriever is None:
        return "Error: Model or retriever is not initialized. Please check the logs."
    # Avoid a network fetch and a full generation pass for an empty submission.
    if question is None or not str(question).strip():
        return "Error: Please enter a question."

    data = fetch_and_preprocess_data()
    if data.empty:
        return "Error: No data available for analysis."

    # MA values print as "nan" if fewer than 50/200 rows were fetched.
    context = (
        f"ICICI Bank stock data:\n"
        f"Latest Close Price: {data['Close'].iloc[-1]:.2f}\n"
        f"50-Day Moving Average: {data['MA_50'].iloc[-1]:.2f}\n"
        f"200-Day Moving Average: {data['MA_200'].iloc[-1]:.2f}\n"
    )
    input_text = f"Question: {question}\nContext: {context}"

    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
    # Forward the attention mask explicitly instead of letting generate() infer it.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio interface | |
# Single text-in / text-out UI wired to analyze_trading_data, with example prompts.
iface = gr.Interface(
    fn=analyze_trading_data,
    inputs="text",
    outputs="text",
    title="ICICI Bank Trading Analysis",
    description="Ask any question about ICICI Bank's trading data and get a detailed analysis.",
    examples=[
        "What is the current trend of ICICI Bank stock?",
        "Is the 50-day moving average above the 200-day moving average?",
        "What is the latest closing price of ICICI Bank?",
    ],
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, reload tooling) does not start a server as a side effect.
if __name__ == "__main__":
    iface.launch()