Spaces:
Sleeping
Sleeping
KrSharangrav
commited on
Commit
Β·
f16063a
1
Parent(s):
0df2e14
more changes to all 3 py files
Browse files- app.py +6 -22
- chatbot.py +28 -21
- db.py +4 -8
app.py
CHANGED
@@ -1,29 +1,15 @@
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from db import insert_data_if_empty, get_mongo_client
|
4 |
-
from chatbot import chatbot_response #
|
5 |
|
6 |
-
#
|
7 |
insert_data_if_empty()
|
8 |
|
9 |
-
#
|
10 |
collection = get_mongo_client()
|
11 |
|
12 |
-
|
13 |
-
# st.title("π Historical Data and Chatbot Analysis")
|
14 |
-
# st.subheader("Historical Data from MongoDB")
|
15 |
-
# data = list(collection.find({}, {"_id": 0}).limit(5))
|
16 |
-
# if data:
|
17 |
-
# st.write(pd.DataFrame(data))
|
18 |
-
# else:
|
19 |
-
# st.warning("No data found in MongoDB. Please try refreshing.")
|
20 |
-
#
|
21 |
-
# if st.button("Show Complete Data"):
|
22 |
-
# all_data = list(collection.find({}, {"_id": 0}))
|
23 |
-
# st.write(pd.DataFrame(all_data))
|
24 |
-
|
25 |
-
# 3. Chatbot interface
|
26 |
-
st.subheader("π¬ Chatbot with Fine-Tuned Sentiment & Topic Analysis")
|
27 |
user_prompt = st.text_area("Ask me something:")
|
28 |
|
29 |
if st.button("Get AI Response"):
|
@@ -31,11 +17,9 @@ if st.button("Get AI Response"):
|
|
31 |
if ai_response:
|
32 |
st.write("### AI Response:")
|
33 |
st.write(ai_response)
|
34 |
-
|
35 |
st.write("### Sentiment Analysis:")
|
36 |
st.write(f"**Sentiment:** {sentiment_label} ({sentiment_confidence:.2f} confidence)")
|
37 |
-
|
38 |
-
st.write("### Topic Extraction:")
|
39 |
st.write(f"**Detected Category:** {topic_label} ({topic_confidence:.2f} confidence)")
|
40 |
else:
|
41 |
-
st.warning("Please enter
|
|
|
1 |
import streamlit as st
|
2 |
import pandas as pd
|
3 |
from db import insert_data_if_empty, get_mongo_client
|
4 |
+
from chatbot import chatbot_response # Import the updated chatbot functionality
|
5 |
|
6 |
+
# Ensure that historical data is in the database
|
7 |
insert_data_if_empty()
|
8 |
|
9 |
+
# Connect to MongoDB
|
10 |
collection = get_mongo_client()
|
11 |
|
12 |
+
st.subheader("π¬ Chatbot with Sentiment & Topic Analysis")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
user_prompt = st.text_area("Ask me something:")
|
14 |
|
15 |
if st.button("Get AI Response"):
|
|
|
17 |
if ai_response:
|
18 |
st.write("### AI Response:")
|
19 |
st.write(ai_response)
|
|
|
20 |
st.write("### Sentiment Analysis:")
|
21 |
st.write(f"**Sentiment:** {sentiment_label} ({sentiment_confidence:.2f} confidence)")
|
22 |
+
st.write("### Category Extraction:")
|
|
|
23 |
st.write(f"**Detected Category:** {topic_label} ({topic_confidence:.2f} confidence)")
|
24 |
else:
|
25 |
+
st.warning("β οΈ Please enter a question or text for analysis.")
|
chatbot.py
CHANGED
@@ -4,22 +4,37 @@ import google.generativeai as genai
|
|
4 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
GEMINI_API_KEY = os.getenv("gemini_api")
|
9 |
if GEMINI_API_KEY:
|
10 |
genai.configure(api_key=GEMINI_API_KEY)
|
11 |
else:
|
12 |
st.error("β οΈ Google API key is missing! Set it in Hugging Face Secrets.")
|
13 |
|
14 |
-
#
|
15 |
FINE_TUNED_MODEL_DIR = "fine-tuned-sentiment-model"
|
16 |
|
17 |
-
#
|
18 |
def fine_tune_model():
|
19 |
st.info("Fine-tuning sentiment model. This may take a while...")
|
20 |
|
21 |
# Load the dataset from the local CSV file.
|
22 |
-
# Ensure that 'sentiment140.csv' is in your working directory.
|
23 |
try:
|
24 |
dataset = load_dataset('csv', data_files={'train': 'sentiment140.csv'}, encoding='ISO-8859-1')
|
25 |
except Exception as e:
|
@@ -27,7 +42,6 @@ def fine_tune_model():
|
|
27 |
return None, None
|
28 |
|
29 |
# Convert sentiment labels: sentiment140 labels are 0 (Negative), 2 (Neutral), 4 (Positive).
|
30 |
-
# We map them to 0,1,2 respectively.
|
31 |
def convert_labels(example):
|
32 |
mapping = {0: 0, 2: 1, 4: 2}
|
33 |
example["label"] = mapping[int(example["target"])]
|
@@ -35,10 +49,7 @@ def fine_tune_model():
|
|
35 |
|
36 |
dataset = dataset.map(convert_labels)
|
37 |
|
38 |
-
# Base model name
|
39 |
base_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
40 |
-
|
41 |
-
# Initialize tokenizer and model
|
42 |
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
43 |
model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=3)
|
44 |
|
@@ -47,10 +58,9 @@ def fine_tune_model():
|
|
47 |
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
|
48 |
tokenized_dataset = dataset.map(tokenize_function, batched=True)
|
49 |
|
50 |
-
# Set training arguments (for demo purposes, we use 1 epoch; adjust as needed)
|
51 |
training_args = TrainingArguments(
|
52 |
output_dir="./results",
|
53 |
-
num_train_epochs=1,
|
54 |
per_device_train_batch_size=8,
|
55 |
logging_steps=10,
|
56 |
save_steps=50,
|
@@ -69,13 +79,12 @@ def fine_tune_model():
|
|
69 |
|
70 |
trainer.train()
|
71 |
|
72 |
-
# Save the fine-tuned model and tokenizer
|
73 |
model.save_pretrained(FINE_TUNED_MODEL_DIR)
|
74 |
tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
|
75 |
st.success("β
Fine-tuning complete and model saved.")
|
76 |
return model, tokenizer
|
77 |
|
78 |
-
# Load
|
79 |
if not os.path.exists(FINE_TUNED_MODEL_DIR):
|
80 |
model, tokenizer = fine_tune_model()
|
81 |
if model is None or tokenizer is None:
|
@@ -84,7 +93,7 @@ else:
|
|
84 |
tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_DIR)
|
85 |
model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)
|
86 |
|
87 |
-
#
|
88 |
try:
|
89 |
sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
|
90 |
except Exception as e:
|
@@ -106,10 +115,8 @@ TOPIC_LABELS = [
|
|
106 |
def analyze_sentiment(text):
|
107 |
try:
|
108 |
sentiment_result = sentiment_pipeline(text)[0]
|
109 |
-
label = sentiment_result['label']
|
110 |
-
score = sentiment_result['score']
|
111 |
-
|
112 |
-
# Map model labels to human-readable format
|
113 |
sentiment_mapping = {
|
114 |
"LABEL_0": "Negative",
|
115 |
"LABEL_1": "Neutral",
|
@@ -123,13 +130,13 @@ def analyze_sentiment(text):
|
|
123 |
def extract_topic(text):
|
124 |
try:
|
125 |
topic_result = topic_pipeline(text, TOPIC_LABELS)
|
126 |
-
top_topic = topic_result["labels"][0]
|
127 |
confidence = topic_result["scores"][0]
|
128 |
return top_topic, confidence
|
129 |
except Exception as e:
|
130 |
return f"Error extracting topic: {e}", None
|
131 |
|
132 |
-
# Function to generate AI response
|
133 |
def chatbot_response(user_prompt):
|
134 |
if not user_prompt:
|
135 |
return None, None, None, None, None
|
@@ -139,10 +146,10 @@ def chatbot_response(user_prompt):
|
|
139 |
model_gen = genai.GenerativeModel("gemini-1.5-pro")
|
140 |
ai_response = model_gen.generate_content(user_prompt)
|
141 |
|
142 |
-
#
|
143 |
sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
|
144 |
|
145 |
-
#
|
146 |
topic_label, topic_confidence = extract_topic(user_prompt)
|
147 |
|
148 |
return ai_response.text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
|
|
|
4 |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
|
5 |
from datasets import load_dataset
|
6 |
|
7 |
+
# --- Monkey Patch for Accelerator ---
|
8 |
+
try:
|
9 |
+
import accelerate
|
10 |
+
from accelerate import Accelerator
|
11 |
+
import inspect
|
12 |
+
# If the Accelerator.__init__ does not accept "dispatch_batches", remove it from kwargs.
|
13 |
+
if 'dispatch_batches' not in inspect.signature(Accelerator.__init__).parameters:
|
14 |
+
old_init = Accelerator.__init__
|
15 |
+
def new_init(self, *args, **kwargs):
|
16 |
+
if 'dispatch_batches' in kwargs:
|
17 |
+
kwargs.pop('dispatch_batches')
|
18 |
+
old_init(self, *args, **kwargs)
|
19 |
+
Accelerator.__init__ = new_init
|
20 |
+
except Exception as e:
|
21 |
+
st.error(f"Error patching Accelerator: {e}")
|
22 |
+
|
23 |
+
# --- Configure Gemini API ---
|
24 |
GEMINI_API_KEY = os.getenv("gemini_api")
|
25 |
if GEMINI_API_KEY:
|
26 |
genai.configure(api_key=GEMINI_API_KEY)
|
27 |
else:
|
28 |
st.error("β οΈ Google API key is missing! Set it in Hugging Face Secrets.")
|
29 |
|
30 |
+
# Path to save/load the fine-tuned model
|
31 |
FINE_TUNED_MODEL_DIR = "fine-tuned-sentiment-model"
|
32 |
|
33 |
+
# --- Fine-tune the Sentiment Model ---
|
34 |
def fine_tune_model():
|
35 |
st.info("Fine-tuning sentiment model. This may take a while...")
|
36 |
|
37 |
# Load the dataset from the local CSV file.
|
|
|
38 |
try:
|
39 |
dataset = load_dataset('csv', data_files={'train': 'sentiment140.csv'}, encoding='ISO-8859-1')
|
40 |
except Exception as e:
|
|
|
42 |
return None, None
|
43 |
|
44 |
# Convert sentiment labels: sentiment140 labels are 0 (Negative), 2 (Neutral), 4 (Positive).
|
|
|
45 |
def convert_labels(example):
|
46 |
mapping = {0: 0, 2: 1, 4: 2}
|
47 |
example["label"] = mapping[int(example["target"])]
|
|
|
49 |
|
50 |
dataset = dataset.map(convert_labels)
|
51 |
|
|
|
52 |
base_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
|
|
|
|
|
53 |
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
54 |
model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=3)
|
55 |
|
|
|
58 |
return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
|
59 |
tokenized_dataset = dataset.map(tokenize_function, batched=True)
|
60 |
|
|
|
61 |
training_args = TrainingArguments(
|
62 |
output_dir="./results",
|
63 |
+
num_train_epochs=1, # For demonstration, we train for 1 epoch.
|
64 |
per_device_train_batch_size=8,
|
65 |
logging_steps=10,
|
66 |
save_steps=50,
|
|
|
79 |
|
80 |
trainer.train()
|
81 |
|
|
|
82 |
model.save_pretrained(FINE_TUNED_MODEL_DIR)
|
83 |
tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
|
84 |
st.success("β
Fine-tuning complete and model saved.")
|
85 |
return model, tokenizer
|
86 |
|
87 |
+
# Load or fine-tune the sentiment model
|
88 |
if not os.path.exists(FINE_TUNED_MODEL_DIR):
|
89 |
model, tokenizer = fine_tune_model()
|
90 |
if model is None or tokenizer is None:
|
|
|
93 |
tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_DIR)
|
94 |
model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)
|
95 |
|
96 |
+
# Create sentiment analysis pipeline from the fine-tuned model
|
97 |
try:
|
98 |
sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
|
99 |
except Exception as e:
|
|
|
115 |
def analyze_sentiment(text):
|
116 |
try:
|
117 |
sentiment_result = sentiment_pipeline(text)[0]
|
118 |
+
label = sentiment_result['label']
|
119 |
+
score = sentiment_result['score']
|
|
|
|
|
120 |
sentiment_mapping = {
|
121 |
"LABEL_0": "Negative",
|
122 |
"LABEL_1": "Neutral",
|
|
|
130 |
def extract_topic(text):
|
131 |
try:
|
132 |
topic_result = topic_pipeline(text, TOPIC_LABELS)
|
133 |
+
top_topic = topic_result["labels"][0]
|
134 |
confidence = topic_result["scores"][0]
|
135 |
return top_topic, confidence
|
136 |
except Exception as e:
|
137 |
return f"Error extracting topic: {e}", None
|
138 |
|
139 |
+
# Function to generate AI response along with sentiment and topic analysis
|
140 |
def chatbot_response(user_prompt):
|
141 |
if not user_prompt:
|
142 |
return None, None, None, None, None
|
|
|
146 |
model_gen = genai.GenerativeModel("gemini-1.5-pro")
|
147 |
ai_response = model_gen.generate_content(user_prompt)
|
148 |
|
149 |
+
# Sentiment Analysis
|
150 |
sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
|
151 |
|
152 |
+
# Topic Extraction
|
153 |
topic_label, topic_confidence = extract_topic(user_prompt)
|
154 |
|
155 |
return ai_response.text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
|
db.py
CHANGED
@@ -3,26 +3,22 @@ import requests
|
|
3 |
import io
|
4 |
from pymongo import MongoClient
|
5 |
|
6 |
-
# Function to
|
7 |
def get_mongo_client():
|
8 |
client = MongoClient("mongodb+srv://groupA:[email protected]/?retryWrites=true&w=majority&appName=SentimentCluster")
|
9 |
db = client["sentiment_db"]
|
10 |
return db["tweets"]
|
11 |
|
12 |
-
# Function to insert
|
13 |
def insert_data_if_empty():
|
14 |
collection = get_mongo_client()
|
15 |
-
|
16 |
if collection.count_documents({}) == 0:
|
17 |
-
print("π’ No data found
|
18 |
csv_url = "https://huggingface.co/spaces/sharangrav24/SentimentAnalysis/resolve/main/sentiment140.csv"
|
19 |
-
|
20 |
try:
|
21 |
response = requests.get(csv_url)
|
22 |
-
response.raise_for_status()
|
23 |
df = pd.read_csv(io.StringIO(response.text), encoding="ISO-8859-1")
|
24 |
-
|
25 |
-
# Insert dataset records into MongoDB
|
26 |
collection.insert_many(df.to_dict("records"))
|
27 |
print("β
Data Inserted into MongoDB!")
|
28 |
except Exception as e:
|
|
|
3 |
import io
|
4 |
from pymongo import MongoClient
|
5 |
|
6 |
+
# Function to connect to MongoDB
|
7 |
def get_mongo_client():
|
8 |
client = MongoClient("mongodb+srv://groupA:[email protected]/?retryWrites=true&w=majority&appName=SentimentCluster")
|
9 |
db = client["sentiment_db"]
|
10 |
return db["tweets"]
|
11 |
|
12 |
+
# Function to insert data if the collection is empty
|
13 |
def insert_data_if_empty():
|
14 |
collection = get_mongo_client()
|
|
|
15 |
if collection.count_documents({}) == 0:
|
16 |
+
print("π’ No data found. Inserting dataset...")
|
17 |
csv_url = "https://huggingface.co/spaces/sharangrav24/SentimentAnalysis/resolve/main/sentiment140.csv"
|
|
|
18 |
try:
|
19 |
response = requests.get(csv_url)
|
20 |
+
response.raise_for_status()
|
21 |
df = pd.read_csv(io.StringIO(response.text), encoding="ISO-8859-1")
|
|
|
|
|
22 |
collection.insert_many(df.to_dict("records"))
|
23 |
print("β
Data Inserted into MongoDB!")
|
24 |
except Exception as e:
|