KrSharangrav committed · Commit f5b718b · Parent(s): b83a640

fine tuning changes

Files changed:
- app.py (+20 −28)
- chatbot.py (+87 −21)
- db.py (+5 −6)
app.py
CHANGED
@@ -1,23 +1,33 @@
 import streamlit as st
 import pandas as pd
 from db import insert_data_if_empty, get_mongo_client
-from chatbot import chatbot_response  # …
+from chatbot import chatbot_response  # Updated chatbot functionality using the fine-tuned model
 
-# 1. Ensure …
+# 1. Ensure historical data is loaded into MongoDB
 insert_data_if_empty()
 
-# 2. MongoDB …
+# 2. Connect to MongoDB collection (for potential historical data display)
 collection = get_mongo_client()
 
-
-
-# …
+# Optional: Display historical data from the dataset (uncomment if needed)
+# st.title("📊 Historical Data and Chatbot Analysis")
+# st.subheader("Historical Data from MongoDB")
+# data = list(collection.find({}, {"_id": 0}).limit(5))
+# if data:
+#     st.write(pd.DataFrame(data))
+# else:
+#     st.warning("No data found in MongoDB. Please try refreshing.")
+#
+# if st.button("Show Complete Data"):
+#     all_data = list(collection.find({}, {"_id": 0}))
+#     st.write(pd.DataFrame(all_data))
+
+# 3. Chatbot interface
+st.subheader("💬 Chatbot with Fine-Tuned Sentiment & Topic Analysis")
 user_prompt = st.text_area("Ask me something:")
 
 if st.button("Get AI Response"):
-    # Generate real-time AI response, sentiment, and topic extraction
     ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence = chatbot_response(user_prompt)
-
     if ai_response:
         st.write("### AI Response:")
         st.write(ai_response)
@@ -25,25 +35,7 @@ if st.button("Get AI Response"):
         st.write("### Sentiment Analysis:")
         st.write(f"**Sentiment:** {sentiment_label} ({sentiment_confidence:.2f} confidence)")
 
-        st.write("### …
+        st.write("### Topic Extraction:")
         st.write(f"**Detected Category:** {topic_label} ({topic_confidence:.2f} confidence)")
-
-        # 3. Historical Insight: Compare with historical data from sentiment140.csv
-        historical_data = list(collection.find({}, {"_id": 0}))
-        if historical_data:
-            df = pd.DataFrame(historical_data)
-            # Assume the CSV has a 'sentiment' column with numeric labels:
-            # 0: Negative, 2: Neutral, 4: Positive.
-            if 'sentiment' in df.columns:
-                sentiment_map = {0: "Negative", 2: "Neutral", 4: "Positive"}
-                # Ensure the sentiment column is numeric and map it to readable labels
-                df['sentiment_label'] = df['sentiment'].astype(int).map(sentiment_map)
-                matching_count = df[df['sentiment_label'] == sentiment_label].shape[0]
-                st.write("### Historical Insights:")
-                st.info(f"There are {matching_count} tweets in our dataset with a {sentiment_label} sentiment.")
-            else:
-                st.warning("Historical data does not contain a sentiment field.")
-        else:
-            st.warning("No historical data available.")
 else:
-    st.warning("…
+    st.warning("Please enter some text for analysis.")
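Side note on the dropped "Historical Insights" block: the old app.py pulled every document out of MongoDB into a DataFrame just to count rows matching the detected sentiment. If that feature is restored, the count can be pushed to the server instead. A minimal sketch, assuming the collection still stores the raw sentiment140 numeric labels (0 = Negative, 2 = Neutral, 4 = Positive); the helper name is hypothetical, not part of this commit:

    # Hypothetical helper: count matching tweets server-side instead of in pandas.
    LABEL_TO_CODE = {"Negative": 0, "Neutral": 2, "Positive": 4}  # sentiment140 scheme

    def historical_match_count(collection, sentiment_label):
        code = LABEL_TO_CODE.get(sentiment_label)
        if code is None:
            return 0
        # count_documents filters and counts inside MongoDB, so no full fetch
        return collection.count_documents({"sentiment": code})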
chatbot.py
CHANGED
@@ -1,28 +1,94 @@
 import os
 import streamlit as st
 import google.generativeai as genai
-from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
-
-# Disable HF transfer mechanism to avoid rate limiting errors
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
+from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
+from datasets import load_dataset
 
 # 🔑 Fetch API key from Hugging Face Secrets
 GEMINI_API_KEY = os.getenv("gemini_api")
-
 if GEMINI_API_KEY:
     genai.configure(api_key=GEMINI_API_KEY)
 else:
     st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")
 
-# …
-…
+# Define path for the fine-tuned model
+FINE_TUNED_MODEL_DIR = "fine-tuned-sentiment-model"
+
+# Function to fine-tune sentiment analysis model using sentiment140.csv
+def fine_tune_model():
+    st.info("Fine-tuning sentiment model. This may take a while...")
+
+    # Load the dataset from the local CSV file.
+    # Ensure that 'sentiment140.csv' is in your working directory.
+    try:
+        dataset = load_dataset('csv', data_files={'train': 'sentiment140.csv'}, encoding='ISO-8859-1')
+    except Exception as e:
+        st.error(f"❌ Error loading dataset: {e}")
+        return None, None
+
+    # Convert sentiment labels: sentiment140 labels are 0 (Negative), 2 (Neutral), 4 (Positive).
+    # We map them to 0, 1, 2 respectively.
+    def convert_labels(example):
+        mapping = {0: 0, 2: 1, 4: 2}
+        example["label"] = mapping[int(example["target"])]
+        return example
+
+    dataset = dataset.map(convert_labels)
+
+    # Base model name
+    base_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
+
+    # Initialize tokenizer and model
+    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
+    model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=3)
+
+    # Tokenize the dataset; assuming the CSV has a column named "text"
+    def tokenize_function(examples):
+        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
+    tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+    # Set training arguments (for demo purposes, we use 1 epoch; adjust as needed)
+    training_args = TrainingArguments(
+        output_dir="./results",
+        num_train_epochs=1,
+        per_device_train_batch_size=8,
+        logging_steps=10,
+        save_steps=50,
+        evaluation_strategy="no",
+        learning_rate=2e-5,
+        weight_decay=0.01,
+        logging_dir='./logs',
+        disable_tqdm=False
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=tokenized_dataset["train"]
+    )
+
+    trainer.train()
+
+    # Save the fine-tuned model and tokenizer
+    model.save_pretrained(FINE_TUNED_MODEL_DIR)
+    tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
+    st.success("✅ Fine-tuning complete and model saved.")
+    return model, tokenizer
+
+# Load (or fine-tune) the sentiment analysis model and tokenizer
+if not os.path.exists(FINE_TUNED_MODEL_DIR):
+    model, tokenizer = fine_tune_model()
+    if model is None or tokenizer is None:
+        st.error("❌ Failed to fine-tune the sentiment analysis model.")
+else:
+    tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_DIR)
+    model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)
 
-# …
+# Initialize sentiment analysis pipeline using the fine-tuned model
 try:
-    …
-    sentiment_pipeline = pipeline("sentiment-analysis", model=MODEL_NAME, tokenizer=tokenizer)
+    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
 except Exception as e:
-    st.error(f"❌ Error loading sentiment …
+    st.error(f"❌ Error loading sentiment pipeline: {e}")
 
 # Load Topic Extraction Model
 try:
@@ -40,10 +106,10 @@ TOPIC_LABELS = [
 def analyze_sentiment(text):
     try:
         sentiment_result = sentiment_pipeline(text)[0]
-        label = sentiment_result['label']  # …
-        score = sentiment_result['score']  # …
+        label = sentiment_result['label']  # e.g., "LABEL_0", "LABEL_1", "LABEL_2"
+        score = sentiment_result['score']  # Confidence score
 
-        # …
+        # Map model labels to human-readable format
         sentiment_mapping = {
             "LABEL_0": "Negative",
             "LABEL_1": "Neutral",
@@ -57,8 +123,8 @@ def analyze_sentiment(text):
 def extract_topic(text):
     try:
         topic_result = topic_pipeline(text, TOPIC_LABELS)
-        top_topic = topic_result["labels"][0]  # …
-        confidence = topic_result["scores"][0]
+        top_topic = topic_result["labels"][0]  # Highest confidence topic
+        confidence = topic_result["scores"][0]
         return top_topic, confidence
     except Exception as e:
         return f"Error extracting topic: {e}", None
@@ -69,14 +135,14 @@ def chatbot_response(user_prompt):
         return None, None, None, None, None
 
     try:
-        # AI Response
-        …
-        ai_response = …
+        # Generate AI Response using Gemini
+        model_gen = genai.GenerativeModel("gemini-1.5-pro")
+        ai_response = model_gen.generate_content(user_prompt)
 
-        # Sentiment Analysis
+        # Run Sentiment Analysis
         sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
 
-        # Topic Extraction
+        # Run Topic Extraction
         topic_label, topic_confidence = extract_topic(user_prompt)
 
         return ai_response.text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
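One caveat worth noting on fine_tune_model(): sentiment140 contains roughly 1.6 million rows, so a full training pass launched from a Streamlit request will take hours on typical Spaces hardware. A minimal sketch of one mitigation, subsampling the training split right after load_dataset(...); the 10,000-row budget is an illustrative assumption, not part of this commit:

    # Hypothetical: shrink the training split so the demo fine-tunes in minutes.
    SAMPLE_SIZE = 10_000  # illustrative budget; tune to the available hardware

    dataset["train"] = (
        dataset["train"]
        .shuffle(seed=42)            # reproducible random subsample
        .select(range(SAMPLE_SIZE))  # keep SAMPLE_SIZE shuffled rows
    )

The os.path.exists check on FINE_TUNED_MODEL_DIR above then keeps training from re-running on every Space restart, since the saved model is reloaded from disk.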
db.py
CHANGED
@@ -3,27 +3,26 @@ import requests
 import io
 from pymongo import MongoClient
 
-# …
+# Function to establish MongoDB connection and return the collection
 def get_mongo_client():
     client = MongoClient("mongodb+srv://groupA:[email protected]/?retryWrites=true&w=majority&appName=SentimentCluster")
     db = client["sentiment_db"]
     return db["tweets"]
 
-# …
+# Function to insert the dataset into MongoDB if the collection is empty
 def insert_data_if_empty():
     collection = get_mongo_client()
 
     if collection.count_documents({}) == 0:
-        print("🟢 No data found. Inserting dataset...")
-
+        print("🟢 No data found in MongoDB. Inserting dataset...")
         csv_url = "https://huggingface.co/spaces/sharangrav24/SentimentAnalysis/resolve/main/sentiment140.csv"
 
         try:
             response = requests.get(csv_url)
-            response.raise_for_status()  # Ensure request was successful
+            response.raise_for_status()  # Ensure the request was successful
             df = pd.read_csv(io.StringIO(response.text), encoding="ISO-8859-1")
 
-            # Insert into MongoDB
+            # Insert dataset records into MongoDB
             collection.insert_many(df.to_dict("records"))
             print("✅ Data Inserted into MongoDB!")
         except Exception as e:
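For completeness, a quick smoke test for the seeding logic above (a hedged sketch assuming it runs in the same environment as db.py; nothing here is part of the commit):

    # Hypothetical check: confirm the collection was seeded and peek at one record.
    from db import get_mongo_client

    collection = get_mongo_client()
    print("documents:", collection.count_documents({}))   # expect > 0 after seeding
    print("sample:", collection.find_one({}, {"_id": 0})) # one record, _id stripped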