RizwanSajad commited on
Commit
96b2c8f
·
verified ·
1 Parent(s): 8e804f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -1,35 +1,36 @@
1
- # app.py
2
  import gradio as gr
3
  import torch
4
  import pandas as pd
5
- from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
  import yfinance as yf
 
 
 
 
 
7
 
8
  # Function to fetch and preprocess ICICI Bank data
9
  def fetch_and_preprocess_data():
10
  try:
11
- # Fetch ICICI Bank data using yfinance
12
- ticker = "ICICIBANK.BO" # Use BSE symbol
13
  data = yf.download(ticker, start="2020-01-01", end="2023-01-01")
14
 
15
  if data.empty:
16
  raise ValueError("No data found for the given symbol.")
17
 
18
- # Calculate technical indicators
19
  data['MA_50'] = data['Close'].rolling(window=50).mean()
20
  data['MA_200'] = data['Close'].rolling(window=200).mean()
21
 
22
  return data
23
  except Exception as e:
24
  print(f"Error fetching data: {e}")
25
- return pd.DataFrame() # Return an empty DataFrame if fetching fails
26
 
27
- # Load the fine-tuned RAG model and tokenizer
28
  try:
29
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
30
  print("Tokenizer loaded successfully.")
31
 
32
- # Use a pre-built index (e.g., wiki_dpr) instead of creating a custom index
33
  retriever = RagRetriever.from_pretrained(
34
  "facebook/rag-sequence-base",
35
  index_name="wiki_dpr",
@@ -38,45 +39,44 @@ try:
38
  )
39
  print("Retriever loaded successfully.")
40
 
41
- # Load the RAG model
42
- model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever)
43
  print("Model loaded successfully.")
44
  except Exception as e:
45
- print(f"Error initializing model or retriever: {e}")
46
  retriever = None
47
  model = None
48
 
49
- # Function to analyze trading data using the RAG model
50
  def analyze_trading_data(question):
51
  if model is None or retriever is None:
52
  return "Error: Model or retriever is not initialized. Please check the logs."
53
-
54
  # Fetch and preprocess data
55
  data = fetch_and_preprocess_data()
56
-
57
  if data.empty:
58
  return "Error: No data available for analysis."
59
-
60
- # Prepare context for the RAG model
61
  context = (
62
  f"ICICI Bank stock data:\n"
63
  f"Latest Close Price: {data['Close'].iloc[-1]:.2f}\n"
64
  f"50-Day Moving Average: {data['MA_50'].iloc[-1]:.2f}\n"
65
  f"200-Day Moving Average: {data['MA_200'].iloc[-1]:.2f}\n"
66
  )
67
-
68
  # Combine question and context
69
  input_text = f"Question: {question}\nContext: {context}"
70
-
71
- # Tokenize the input
72
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
73
-
74
- # Generate the answer using the RAG model
75
  outputs = model.generate(inputs['input_ids'])
76
 
77
- # Decode the output to get the answer
78
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
-
80
  return answer
81
 
82
  # Gradio interface
@@ -94,4 +94,5 @@ iface = gr.Interface(
94
  )
95
 
96
  # Launch the app
97
- iface.launch()
 
 
 
1
  import gradio as gr
2
  import torch
3
  import pandas as pd
 
4
  import yfinance as yf
5
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
+
7
+ # Check if GPU is available
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ print(f"Using device: {device}")
10
 
11
  # Function to fetch and preprocess ICICI Bank data
12
  def fetch_and_preprocess_data():
13
  try:
14
+ ticker = "ICICIBANK.BO" # ICICI Bank BSE Symbol
 
15
  data = yf.download(ticker, start="2020-01-01", end="2023-01-01")
16
 
17
  if data.empty:
18
  raise ValueError("No data found for the given symbol.")
19
 
20
+ # Calculate Moving Averages
21
  data['MA_50'] = data['Close'].rolling(window=50).mean()
22
  data['MA_200'] = data['Close'].rolling(window=200).mean()
23
 
24
  return data
25
  except Exception as e:
26
  print(f"Error fetching data: {e}")
27
+ return pd.DataFrame() # Return empty DataFrame if fetching fails
28
 
29
+ # Load the RAG model and tokenizer with error handling
30
  try:
31
  tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")
32
  print("Tokenizer loaded successfully.")
33
 
 
34
  retriever = RagRetriever.from_pretrained(
35
  "facebook/rag-sequence-base",
36
  index_name="wiki_dpr",
 
39
  )
40
  print("Retriever loaded successfully.")
41
 
42
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=retriever).to(device)
 
43
  print("Model loaded successfully.")
44
  except Exception as e:
45
+ print(f"Error loading model or retriever: {e}")
46
  retriever = None
47
  model = None
48
 
49
+ # Function to analyze trading data
50
  def analyze_trading_data(question):
51
  if model is None or retriever is None:
52
  return "Error: Model or retriever is not initialized. Please check the logs."
53
+
54
  # Fetch and preprocess data
55
  data = fetch_and_preprocess_data()
56
+
57
  if data.empty:
58
  return "Error: No data available for analysis."
59
+
60
+ # Prepare context for RAG model
61
  context = (
62
  f"ICICI Bank stock data:\n"
63
  f"Latest Close Price: {data['Close'].iloc[-1]:.2f}\n"
64
  f"50-Day Moving Average: {data['MA_50'].iloc[-1]:.2f}\n"
65
  f"200-Day Moving Average: {data['MA_200'].iloc[-1]:.2f}\n"
66
  )
67
+
68
  # Combine question and context
69
  input_text = f"Question: {question}\nContext: {context}"
70
+
71
+ # Tokenize input
72
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(device)
73
+
74
+ # Generate answer using the model
75
  outputs = model.generate(inputs['input_ids'])
76
 
77
+ # Decode output
78
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
79
+
80
  return answer
81
 
82
  # Gradio interface
 
94
  )
95
 
96
  # Launch the app
97
+ if __name__ == "__main__":
98
+ iface.launch()