RizwanSajad commited on
Commit
a2bec5b
·
verified ·
1 Parent(s): 0535188

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -6
app.py CHANGED
@@ -2,13 +2,10 @@
2
  import gradio as gr
3
  import torch
4
  import pandas as pd
5
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
 
6
  import yfinance as yf
7
-
8
- # Load the fine-tuned RAG model and tokenizer
9
- tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
10
- retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base", index_name="custom")
11
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
12
 
13
  # Function to fetch and preprocess ICICI Bank data
14
  def fetch_and_preprocess_data():
@@ -22,6 +19,44 @@ def fetch_and_preprocess_data():
22
 
23
  return data
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # Function to analyze trading data using the RAG model
26
  def analyze_trading_data(question):
27
  # Fetch and preprocess data
 
2
  import gradio as gr
3
  import torch
4
  import pandas as pd
5
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, RagConfig
6
+ from datasets import Dataset
7
  import yfinance as yf
8
+ import numpy as np
 
 
 
 
9
 
10
  # Function to fetch and preprocess ICICI Bank data
11
  def fetch_and_preprocess_data():
 
19
 
20
  return data
21
 
22
+ # Function to create and save a custom index for the retriever
23
+ def create_custom_index():
24
+ # Fetch and preprocess data
25
+ data = fetch_and_preprocess_data()
26
+
27
+ # Create a dataset for the retriever
28
+ dataset = Dataset.from_dict({
29
+ "id": [str(i) for i in range(len(data))],
30
+ "text": data.apply(lambda row: f"Date: {row.name}, Close: {row['Close']:.2f}, MA_50: {row['MA_50']:.2f}, MA_200: {row['MA_200']:.2f}", axis=1).tolist(),
31
+ "title": [f"ICICI Bank Data {i}" for i in range(len(data))]
32
+ })
33
+
34
+ # Save the dataset and index
35
+ dataset_path = "icici_bank_dataset"
36
+ index_path = "icici_bank_index"
37
+ dataset.save_to_disk(dataset_path)
38
+ dataset.add_faiss_index("text")
39
+ dataset.get_index("text").save(index_path)
40
+
41
+ return dataset_path, index_path
42
+
43
+ # Load the fine-tuned RAG model and tokenizer
44
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
45
+
46
+ # Create and save the custom index
47
+ dataset_path, index_path = create_custom_index()
48
+
49
+ # Load the retriever with the custom index
50
+ retriever = RagRetriever.from_pretrained(
51
+ "facebook/rag-sequence-base",
52
+ index_name="custom",
53
+ passages_path=dataset_path,
54
+ index_path=index_path
55
+ )
56
+
57
+ # Load the RAG model
58
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
59
+
60
  # Function to analyze trading data using the RAG model
61
  def analyze_trading_data(question):
62
  # Fetch and preprocess data