KrSharangrav committed on
Commit
be89ae1
·
1 Parent(s): f16063a

changes made further into the model with chatbot interacting with the dataset

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. chatbot.py +37 -94
app.py CHANGED
@@ -1,15 +1,15 @@
1
  import streamlit as st
2
  import pandas as pd
3
  from db import insert_data_if_empty, get_mongo_client
4
- from chatbot import chatbot_response # Import the updated chatbot functionality
5
 
6
- # Ensure that historical data is in the database
7
  insert_data_if_empty()
8
 
9
- # Connect to MongoDB
10
  collection = get_mongo_client()
11
 
12
- st.subheader("💬 Chatbot with Sentiment & Topic Analysis")
13
  user_prompt = st.text_area("Ask me something:")
14
 
15
  if st.button("Get AI Response"):
 
1
  import streamlit as st
2
  import pandas as pd
3
  from db import insert_data_if_empty, get_mongo_client
4
+ from chatbot import chatbot_response
5
 
6
+ # Insert historical data into MongoDB if not already present
7
  insert_data_if_empty()
8
 
9
+ # Connect to MongoDB (available for further extension or analysis)
10
  collection = get_mongo_client()
11
 
12
+ st.subheader("💬 Chatbot with Sentiment, Topic Analysis, and Dataset Insights")
13
  user_prompt = st.text_area("Ask me something:")
14
 
15
  if st.button("Get AI Response"):
chatbot.py CHANGED
@@ -1,103 +1,23 @@
1
  import os
2
  import streamlit as st
3
  import google.generativeai as genai
4
- from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
5
- from datasets import load_dataset
6
 
7
- # --- Monkey Patch for Accelerator ---
8
- try:
9
- import accelerate
10
- from accelerate import Accelerator
11
- import inspect
12
- # If the Accelerator.__init__ does not accept "dispatch_batches", remove it from kwargs.
13
- if 'dispatch_batches' not in inspect.signature(Accelerator.__init__).parameters:
14
- old_init = Accelerator.__init__
15
- def new_init(self, *args, **kwargs):
16
- if 'dispatch_batches' in kwargs:
17
- kwargs.pop('dispatch_batches')
18
- old_init(self, *args, **kwargs)
19
- Accelerator.__init__ = new_init
20
- except Exception as e:
21
- st.error(f"Error patching Accelerator: {e}")
22
-
23
- # --- Configure Gemini API ---
24
  GEMINI_API_KEY = os.getenv("gemini_api")
25
  if GEMINI_API_KEY:
26
  genai.configure(api_key=GEMINI_API_KEY)
27
  else:
28
  st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")
29
 
30
- # Path to save/load the fine-tuned model
31
- FINE_TUNED_MODEL_DIR = "fine-tuned-sentiment-model"
32
-
33
- # --- Fine-tune the Sentiment Model ---
34
- def fine_tune_model():
35
- st.info("Fine-tuning sentiment model. This may take a while...")
36
-
37
- # Load the dataset from the local CSV file.
38
- try:
39
- dataset = load_dataset('csv', data_files={'train': 'sentiment140.csv'}, encoding='ISO-8859-1')
40
- except Exception as e:
41
- st.error(f"❌ Error loading dataset: {e}")
42
- return None, None
43
-
44
- # Convert sentiment labels: sentiment140 labels are 0 (Negative), 2 (Neutral), 4 (Positive).
45
- def convert_labels(example):
46
- mapping = {0: 0, 2: 1, 4: 2}
47
- example["label"] = mapping[int(example["target"])]
48
- return example
49
-
50
- dataset = dataset.map(convert_labels)
51
-
52
- base_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
53
- tokenizer = AutoTokenizer.from_pretrained(base_model_name)
54
- model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=3)
55
-
56
- # Tokenize the dataset; assuming the CSV has a column named "text"
57
- def tokenize_function(examples):
58
- return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
59
- tokenized_dataset = dataset.map(tokenize_function, batched=True)
60
-
61
- training_args = TrainingArguments(
62
- output_dir="./results",
63
- num_train_epochs=1, # For demonstration, we train for 1 epoch.
64
- per_device_train_batch_size=8,
65
- logging_steps=10,
66
- save_steps=50,
67
- evaluation_strategy="no",
68
- learning_rate=2e-5,
69
- weight_decay=0.01,
70
- logging_dir='./logs',
71
- disable_tqdm=False
72
- )
73
-
74
- trainer = Trainer(
75
- model=model,
76
- args=training_args,
77
- train_dataset=tokenized_dataset["train"]
78
- )
79
-
80
- trainer.train()
81
-
82
- model.save_pretrained(FINE_TUNED_MODEL_DIR)
83
- tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
84
- st.success("✅ Fine-tuning complete and model saved.")
85
- return model, tokenizer
86
-
87
- # Load or fine-tune the sentiment model
88
- if not os.path.exists(FINE_TUNED_MODEL_DIR):
89
- model, tokenizer = fine_tune_model()
90
- if model is None or tokenizer is None:
91
- st.error("❌ Failed to fine-tune the sentiment analysis model.")
92
- else:
93
- tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_DIR)
94
- model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)
95
-
96
- # Create sentiment analysis pipeline from the fine-tuned model
97
  try:
 
 
98
  sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
99
  except Exception as e:
100
- st.error(f"❌ Error loading sentiment pipeline: {e}")
101
 
102
  # Load Topic Extraction Model
103
  try:
@@ -111,7 +31,7 @@ TOPIC_LABELS = [
111
  "Health", "Science", "Education", "Finance", "Travel", "Food"
112
  ]
113
 
114
- # Function to analyze sentiment
115
  def analyze_sentiment(text):
116
  try:
117
  sentiment_result = sentiment_pipeline(text)[0]
@@ -126,7 +46,7 @@ def analyze_sentiment(text):
126
  except Exception as e:
127
  return f"Error analyzing sentiment: {e}", None
128
 
129
- # Function to extract topic
130
  def extract_topic(text):
131
  try:
132
  topic_result = topic_pipeline(text, TOPIC_LABELS)
@@ -136,22 +56,45 @@ def extract_topic(text):
136
  except Exception as e:
137
  return f"Error extracting topic: {e}", None
138
 
139
- # Function to generate AI response along with sentiment and topic analysis
 
140
  def chatbot_response(user_prompt):
141
  if not user_prompt:
142
  return None, None, None, None, None
143
 
144
  try:
145
- # Generate AI Response using Gemini
146
  model_gen = genai.GenerativeModel("gemini-1.5-pro")
147
  ai_response = model_gen.generate_content(user_prompt)
148
 
149
- # Sentiment Analysis
150
  sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
151
 
152
- # Topic Extraction
153
  topic_label, topic_confidence = extract_topic(user_prompt)
154
 
155
- return ai_response.text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  except Exception as e:
157
  return f"❌ Error: {e}", None, None, None, None
 
1
  import os
2
  import streamlit as st
3
  import google.generativeai as genai
4
+ from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
 
5
 
6
# Configure the Gemini client from the `gemini_api` environment variable.
# A missing key is reported in the Streamlit UI; nothing is raised here.
GEMINI_API_KEY = os.getenv("gemini_api")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")

# Load pre-trained sentiment analysis model
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"

# Bind the name before attempting the load so that a failed load leaves
# `sentiment_pipeline` as None instead of leaving it undefined (which would
# surface later as a confusing NameError wherever it is used).
sentiment_pipeline = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    # Loading failures are shown in the UI so the rest of the module can still be imported.
    st.error(f"❌ Error loading sentiment model: {e}")
21
 
22
  # Load Topic Extraction Model
23
  try:
 
31
  "Health", "Science", "Education", "Finance", "Travel", "Food"
32
  ]
33
 
34
+ # Function to analyze sentiment using the pre-trained model
35
  def analyze_sentiment(text):
36
  try:
37
  sentiment_result = sentiment_pipeline(text)[0]
 
46
  except Exception as e:
47
  return f"Error analyzing sentiment: {e}", None
48
 
49
+ # Function to extract topic using zero-shot classification
50
  def extract_topic(text):
51
  try:
52
  topic_result = topic_pipeline(text, TOPIC_LABELS)
 
56
  except Exception as e:
57
  return f"Error extracting topic: {e}", None
58
 
59
# Substrings that mark a prompt as asking about the stored dataset.
_DATASET_KEYWORDS = ("sentiment140", "dataset", "historical", "mongodb", "stored data")

# Display names for the numeric codes found in the `target` field.
_TARGET_LABELS = {0: "Negative", 2: "Neutral", 4: "Positive"}


def _target_label(raw_target):
    """Return the display name for a `target` value, or the value itself if unknown."""
    code = raw_target
    if isinstance(raw_target, str):
        # Values may have been stored as text (e.g. "4"); normalise them so
        # they still match the integer-keyed label map.
        try:
            code = int(raw_target)
        except ValueError:
            return raw_target
    return _TARGET_LABELS.get(code, raw_target)


def _dataset_summary(collection):
    """Build the per-sentiment record counts text from a MongoDB collection."""
    # NOTE: the name `pipeline` is deliberately avoided here because the module
    # imports `pipeline` from transformers.
    stages = [
        {"$group": {"_id": "$target", "count": {"$sum": 1}}},
        # $group does not guarantee output order; sort for a stable summary.
        {"$sort": {"_id": 1}},
    ]
    lines = []
    total = 0
    for group in collection.aggregate(stages):
        total += group["count"]
        lines.append(f"{_target_label(group['_id'])}: {group['count']}")
    lines.append(f"Total records: {total}")
    return "\n".join(lines)


def _dataset_information():
    """Return the dataset statistics text, or an 'Unavailable' note if the lookup fails."""
    try:
        # Imported lazily, as in the original code, so this module does not
        # depend on `db` at import time.
        from db import get_mongo_client
        return _dataset_summary(get_mongo_client())
    except Exception as e:
        # The failure modes of the database helper are not visible here, so
        # catch broadly and keep the already-generated AI answer usable.
        return f"Unavailable: {e}"


# Function to generate AI response along with sentiment and topic analysis.
# Also, if the query relates to the dataset, fetch statistics from MongoDB.
def chatbot_response(user_prompt):
    """Answer a user prompt and analyse its sentiment and topic.

    Args:
        user_prompt: Text entered by the user.

    Returns:
        A 5-tuple ``(ai_response_text, sentiment_label, sentiment_confidence,
        topic_label, topic_confidence)``.

        * For an empty or whitespace-only prompt every element is ``None``.
        * If the prompt mentions one of the dataset keywords, per-sentiment
          record counts from MongoDB are appended to the response text under a
          "Dataset Information" heading; if the counts cannot be fetched, an
          "Unavailable" note is appended instead.
        * If generation or analysis raises, the first element is an error
          message starting with "❌ Error:" and the remaining four are ``None``.
    """
    if not user_prompt or (isinstance(user_prompt, str) and not user_prompt.strip()):
        return None, None, None, None, None

    try:
        # Generate AI response using Gemini
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response_text = model_gen.generate_content(user_prompt).text

        # Sentiment and topic are computed on the user's prompt, not on the reply.
        sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
        topic_label, topic_confidence = extract_topic(user_prompt)

        # If the prompt seems related to the dataset, append MongoDB statistics.
        prompt_lower = user_prompt.lower()
        if any(keyword in prompt_lower for keyword in _DATASET_KEYWORDS):
            ai_response_text += "\n\nDataset Information:\n" + _dataset_information()

        return ai_response_text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
    except Exception as e:
        return f"❌ Error: {e}", None, None, None, None