import os
import streamlit as st
import google.generativeai as genai
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

# Configure Gemini API key
GEMINI_API_KEY = os.getenv("gemini_api")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")

# Load pre-trained sentiment analysis model
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    st.error(f"❌ Error loading sentiment model: {e}")

# Load topic extraction model
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    st.error(f"❌ Error loading topic extraction model: {e}")

# Predefined topic labels for zero-shot classification
TOPIC_LABELS = [
    "Technology", "Politics", "Business", "Sports", "Entertainment",
    "Health", "Science", "Education", "Finance", "Travel", "Food"
]

# Function to analyze sentiment using the pre-trained model
def analyze_sentiment(text):
    try:
        sentiment_result = sentiment_pipeline(text)[0]
        label = sentiment_result['label']
        score = sentiment_result['score']
        # Map the model's raw labels to human-readable sentiment names
        sentiment_mapping = {
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral",
            "LABEL_2": "Positive"
        }
        return sentiment_mapping.get(label, "Unknown"), score
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None

# Function to extract topic using zero-shot classification
def extract_topic(text):
    try:
        topic_result = topic_pipeline(text, TOPIC_LABELS)
        # The pipeline returns labels sorted by score, so the first entry is the best match
        top_topic = topic_result["labels"][0]
        confidence = topic_result["scores"][0]
        return top_topic, confidence
    except Exception as e:
        return f"Error extracting topic: {e}", None

# Function to generate an AI response along with sentiment and topic analysis.
# If the query relates to the dataset, also fetch statistics from MongoDB.
def chatbot_response(user_prompt):
    if not user_prompt:
        return None, None, None, None, None
    try:
        # Generate AI response using Gemini
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response = model_gen.generate_content(user_prompt)
        # Perform sentiment analysis on the user prompt
        sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
        # Perform topic extraction on the user prompt
        topic_label, topic_confidence = extract_topic(user_prompt)
        # If the prompt seems related to the dataset, get MongoDB statistics.
        if any(keyword in user_prompt.lower() for keyword in ["sentiment140", "dataset", "historical", "mongodb", "stored data"]):
            from db import get_mongo_client
            collection = get_mongo_client()
            # Aggregate counts by the 'target' field (assumed to be in the CSV)
            pipeline = [
                {"$group": {"_id": "$target", "count": {"$sum": 1}}}
            ]
            results = list(collection.aggregate(pipeline))
            # Sentiment140 encodes 'target' as 0 = negative, 2 = neutral, 4 = positive
            sentiment_map = {0: "Negative", 2: "Neutral", 4: "Positive"}
            stats_str = ""
            total = 0
            for r in results:
                key = sentiment_map.get(r["_id"], r["_id"])
                count = r["count"]
                total += count
                stats_str += f"{key}: {count}\n"
            stats_str += f"Total records: {total}"
            ai_response_text = ai_response.text + "\n\nDataset Information:\n" + stats_str
        else:
            ai_response_text = ai_response.text
        return ai_response_text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
    except Exception as e:
        return f"❌ Error: {e}", None, None, None, None
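
# --- Usage sketch ---
# A minimal example of how chatbot_response() might be wired into a Streamlit page.
# This block is an illustrative assumption, not part of the original app; the widget
# labels and layout are placeholders. It only runs when the file is executed directly
# (e.g. via `streamlit run`), so importing this module elsewhere is unaffected.
if __name__ == "__main__":
    st.title("Gemini Chatbot with Sentiment & Topic Analysis")
    user_prompt = st.text_area("Enter your prompt:")
    if st.button("Analyze") and user_prompt:
        response, sent_label, sent_conf, topic, topic_conf = chatbot_response(user_prompt)
        st.write(response)
        if sent_conf is not None:
            st.write(f"Sentiment: {sent_label} (confidence: {sent_conf:.2f})")
        if topic_conf is not None:
            st.write(f"Topic: {topic} (confidence: {topic_conf:.2f})")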