Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import google.generativeai as genai | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer | |
import pandas as pd | |
from db import get_mongo_client | |
# Gemini credentials are read from the "gemini_api" environment variable.
GEMINI_API_KEY = os.getenv("gemini_api")
if not GEMINI_API_KEY:
    # No key available: report it in the Streamlit UI and leave genai unconfigured.
    st.error("β οΈ Google API key is missing! Set it in Hugging Face Secrets.")
else:
    genai.configure(api_key=GEMINI_API_KEY)
# Load the pre-trained sentiment analysis model.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
# Bind the names before attempting the load so that a failure leaves them
# defined (as None) rather than missing; an unbound module-level name would
# surface later as a NameError far away from the real cause.
tokenizer = None
model = None
sentiment_pipeline = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    # Best-effort load: show the failure in the Streamlit UI and keep the
    # module importable.
    st.error(f"β Error loading sentiment model: {e}")
# Load the zero-shot classification model used for topic extraction.
# Bind the name first so a failed load leaves it defined (as None) instead of
# unbound, which would otherwise show up later as a NameError.
topic_pipeline = None
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    # Best-effort load: show the failure in the Streamlit UI and keep the
    # module importable.
    st.error(f"β Error loading topic extraction model: {e}")
# Candidate labels offered to the zero-shot topic classifier.
TOPIC_LABELS = [
    "Technology",
    "Politics",
    "Business",
    "Sports",
    "Entertainment",
    "Health",
    "Science",
    "Education",
    "Finance",
    "Travel",
    "Food",
]
# Sentiment analysis backed by the module-level ``sentiment_pipeline``.
def analyze_sentiment(text):
    """Classify the sentiment of *text*.

    Returns a ``(label, score)`` tuple. The pipeline's raw labels
    ``LABEL_0`` / ``LABEL_1`` / ``LABEL_2`` are reported as "Negative" /
    "Neutral" / "Positive"; any other raw label is reported as "Unknown".

    Never raises: on any exception the tuple is
    ``("Error analyzing sentiment: <exception>", None)``.
    """
    readable_names = {
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
    }
    try:
        # Only the first prediction returned by the pipeline is used.
        top_prediction = sentiment_pipeline(text)[0]
        raw_label = top_prediction['label']
        confidence = top_prediction['score']
        return readable_names.get(raw_label, "Unknown"), confidence
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None
# Topic extraction backed by the module-level ``topic_pipeline``.
def extract_topic(text):
    """Pick the best-matching entry of ``TOPIC_LABELS`` for *text*.

    Returns ``(label, score)`` taken from the first entries of the
    pipeline result's ``"labels"`` and ``"scores"`` lists.

    Never raises: on any exception the tuple is
    ``("Error extracting topic: <exception>", None)``.
    """
    try:
        ranking = topic_pipeline(text, TOPIC_LABELS)
        return ranking["labels"][0], ranking["scores"][0]
    except Exception as e:
        return f"Error extracting topic: {e}", None
# Keyword check that routes a prompt to the stored dataset.
def is_dataset_query(text):
    """Return True when *text* contains a dataset-related keyword.

    The match is a case-insensitive substring test, so a keyword embedded in
    a longer word (for example "data" inside "metadata") also counts.
    """
    lowered = text.lower()
    return any(
        marker in lowered
        for marker in ("dataset", "data", "historical", "csv", "stored")
    )
def _sentiment140_label(target):
    """Translate one raw ``target`` value into a sentiment name.

    Values that cannot be converted to ``int`` (``None``, NaN, non-numeric
    strings) and integers other than 0, 2 or 4 are reported as "Unknown".
    """
    # sentiment140.csv encoding: 0 -> Negative, 2 -> Neutral, 4 -> Positive.
    names = {0: "Negative", 2: "Neutral", 4: "Positive"}
    try:
        return names.get(int(target), "Unknown")
    except (TypeError, ValueError, OverflowError):
        # A missing or malformed value must not abort the whole summary.
        return "Unknown"


# Function to retrieve insights from the dataset stored in MongoDB
def get_dataset_insights():
    """Summarise the sentiment distribution of the stored dataset.

    Reads every document (without ``_id``) from the object returned by
    ``get_mongo_client()`` -- despite the name it is used here as a
    collection, via ``.find`` -- and counts the sentiment names derived from
    the ``target`` field.

    Returns:
        A human-readable string: the distribution, a notice that the dataset
        is empty, a notice that no ``target`` field exists, or
        ``"Error retrieving dataset insights: <exception>"`` if anything
        raises. The function itself never raises.
    """
    try:
        collection = get_mongo_client()
        data = list(collection.find({}, {"_id": 0}))
        if not data:
            return "The dataset in MongoDB is empty."
        df = pd.DataFrame(data)
        if "target" not in df.columns:
            return "The dataset does not have a 'target' field."
        # Documents lacking "target" appear as missing values in the column;
        # the helper maps them to "Unknown" instead of raising.
        summary = df["target"].apply(_sentiment140_label).value_counts().to_dict()
        summary_str = ", ".join(f"{k}: {v}" for k, v in summary.items())
        return f"The dataset sentiment distribution is: {summary_str}."
    except Exception as e:
        return f"Error retrieving dataset insights: {e}"
# Entry point: build the reply plus sentiment and topic analysis for a prompt.
def chatbot_response(user_prompt):
    """Answer *user_prompt* and analyse its sentiment and topic.

    Returns a 5-tuple
    ``(reply, sentiment_label, sentiment_confidence, topic_label, topic_confidence)``:

    * falsy prompt -> five ``None`` values;
    * dataset-related prompt (per ``is_dataset_query``) -> the reply is the
      text from ``get_dataset_insights()``;
    * anything else -> the reply is the text generated by the
      "gemini-1.5-pro" model. If anything in that path raises, the reply is
      an error message and the four analysis slots are ``None``.
    """
    if not user_prompt:
        return (None,) * 5

    if not is_dataset_query(user_prompt):
        try:
            # Generation runs first; a failure here skips the analysis calls.
            generator = genai.GenerativeModel("gemini-1.5-pro")
            reply = generator.generate_content(user_prompt)
            mood, mood_score = analyze_sentiment(user_prompt)
            subject, subject_score = extract_topic(user_prompt)
            return reply.text, mood, mood_score, subject, subject_score
        except Exception as e:
            return f"β Error: {e}", None, None, None, None

    # Dataset question: answer from the stored data instead of Gemini.
    insights = get_dataset_insights()
    mood, mood_score = analyze_sentiment(user_prompt)
    subject, subject_score = extract_topic(user_prompt)
    return insights, mood, mood_score, subject, subject_score