import os
import re
import streamlit as st
import google.generativeai as genai
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from db import get_entry_by_index, get_dataset_summary
# Configure Gemini API key
GEMINI_API_KEY = os.getenv("gemini_api")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")
# Load pre-trained sentiment analysis model
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    st.error(f"❌ Error loading sentiment model: {e}")
# Load topic extraction model
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    st.error(f"❌ Error loading topic extraction model: {e}")
# Predefined topic labels for classification
TOPIC_LABELS = [
    "Technology 💻", "Politics 🏛️", "Business 💼", "Sports ⚽", "Entertainment 🎭", "Health 🩺",
    "Science 🔬", "Education 📚", "Finance 💰", "Travel ✈️", "Food 🍔", "Environment 🌱", "Culture 🌍",
    "History 🏺", "Art 🎨", "Literature 📖", "Automotive 🚗", "Law ⚖️", "Music 🎵", "Movies 🎬"
]
def analyze_sentiment(text):
    try:
        result = sentiment_pipeline(text)[0]
        label = result['label']
        score = result['score']
        sentiment_mapping = {
            "LABEL_0": "😞 Negative",
            "LABEL_1": "😐 Neutral",
            "LABEL_2": "😊 Positive"
        }
        return sentiment_mapping.get(label, "Unknown"), score
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None
def extract_topic(text):
    try:
        result = topic_pipeline(text, TOPIC_LABELS)
        top_topic = result["labels"][0]
        confidence = result["scores"][0]
        return top_topic, confidence
    except Exception as e:
        return f"Error extracting topic: {e}", None
# Helper: Extract entry index from prompt (e.g., "data entry 1" yields index 0)
def extract_entry_index(prompt):
    match = re.search(r'(data entry|entry)\s+(\d+)', prompt, re.IGNORECASE)
    if match:
        index = int(match.group(2)) - 1  # convert to 0-based index
        return index
    return None
# Helper: Detect if the query is asking for a specific dataset entry.
def is_entry_query(prompt):
    index = extract_entry_index(prompt)
    if index is not None:
        return True, index
    return False, None
# Helper: Detect if the query is a basic dataset question.
def is_basic_dataset_question(prompt):
    lower = prompt.lower()
    keywords = ["dataset summary", "total tweets", "sentiment distribution", "overall dataset", "data overview", "data summary"]
    return any(keyword in lower for keyword in keywords)
def chatbot_response(user_prompt):
    if not user_prompt:
        return None, None, None, None, None
    try:
        # If the query is a basic dataset question, fetch summary from MongoDB.
        if is_basic_dataset_question(user_prompt):
            summary = get_dataset_summary()
            ai_response = "Dataset Summary:\n" + summary
            sentiment_label, sentiment_confidence = analyze_sentiment(summary)
            topic_label, topic_confidence = extract_topic(summary)
            return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence
        # If the query is about a specific entry in the dataset...
        entry_query, index = is_entry_query(user_prompt)
        if entry_query:
            entry = get_entry_by_index(index)
            if entry is None:
                return "❌ No entry found for the requested index.", None, None, None, None
            # Retrieve fields from the document
            entry_text = entry.get("text", "No text available.")
            entry_user = entry.get("user", "Unknown")
            entry_date = entry.get("date", "Unknown")
            # Build a static response message with new lines for each field.
            ai_response = (
                "Let's break down this MongoDB entry:\n\n"
                f"**Tweet:** {entry_text}\n\n"
                f"**User:** {entry_user}\n\n"
                f"**Date:** {entry_date}"
            )
            sentiment_label, sentiment_confidence = analyze_sentiment(entry_text)
            topic_label, topic_confidence = extract_topic(entry_text)
            return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence
        # For other queries, use the generative model.
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response_obj = model_gen.generate_content(user_prompt)
        ai_response = ai_response_obj.text
        sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
        topic_label, topic_confidence = extract_topic(user_prompt)
        return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence
    except Exception as e:
        return f"❌ Error: {e}", None, None, None, None
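
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a Streamlit front end might call chatbot_response();
# the page title and widget labels below are assumptions, not taken from this file.
if __name__ == "__main__":
    st.title("Tweet Sentiment & Topic Chatbot")
    user_prompt = st.text_input("Ask about the dataset or anything else:")
    if st.button("Send") and user_prompt:
        response, sentiment, sentiment_conf, topic, topic_conf = chatbot_response(user_prompt)
        st.markdown(response)
        if sentiment_conf is not None:
            st.write(f"Sentiment: {sentiment} ({sentiment_conf:.2f})")
        if topic_conf is not None:
            st.write(f"Topic: {topic} ({topic_conf:.2f})")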