import os
import re
import streamlit as st
import google.generativeai as genai
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from db import get_entry_by_index, get_dataset_summary

# Configure Gemini API key
GEMINI_API_KEY = os.getenv("gemini_api")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")

# Load pre-trained sentiment analysis model
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    sentiment_pipeline = None  # keep the name defined so later calls fail gracefully
    st.error(f"❌ Error loading sentiment model: {e}")

# Load Topic Extraction Model
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    topic_pipeline = None  # keep the name defined so later calls fail gracefully
    st.error(f"❌ Error loading topic extraction model: {e}")

# Predefined topic labels for classification
TOPIC_LABELS = [
    "Technology", "Politics", "Business", "Sports", "Entertainment",
    "Health", "Science", "Education", "Finance", "Travel", "Food"
]

def analyze_sentiment(text):
    """Classify the sentiment of `text`; returns (label, confidence) or an error string."""
    try:
        result = sentiment_pipeline(text)[0]
        label = result['label']
        score = result['score']
        # The cardiffnlp model emits raw class names (LABEL_0/1/2);
        # map them to human-readable sentiment labels.
        sentiment_mapping = {
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral",
            "LABEL_2": "Positive"
        }
        return sentiment_mapping.get(label, "Unknown"), score
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None

def extract_topic(text):
    """Zero-shot classify `text` against TOPIC_LABELS; returns (top_topic, confidence)."""
    try:
        result = topic_pipeline(text, TOPIC_LABELS)
        top_topic = result["labels"][0]
        confidence = result["scores"][0]
        return top_topic, confidence
    except Exception as e:
        return f"Error extracting topic: {e}", None

# Helper: Extract entry index from prompt (e.g., "data entry 1" yields index 0)
def extract_entry_index(prompt):
    match = re.search(r'(data entry|entry)\s+(\d+)', prompt, re.IGNORECASE)
    if match:
        index = int(match.group(2)) - 1  # convert to 0-based index
        return index
    return None
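# Deterministic examples of the regex above:
#   extract_entry_index("show me data entry 3")  -> 2  (0-based)
#   extract_entry_index("entry 10, please")      -> 9
#   extract_entry_index("summarize the dataset") -> None (no match)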

# Helper: Detect if the query is asking for a specific dataset entry.
def is_entry_query(prompt):
    index = extract_entry_index(prompt)
    if index is not None:
        return True, index
    return False, None

# Helper: Detect if the query is a basic dataset question.
# Examples: "What is the dataset summary?", "Show me the sentiment distribution", etc.
def is_basic_dataset_question(prompt):
    lower = prompt.lower()
    keywords = ["dataset summary", "total tweets", "sentiment distribution", "overall dataset", "data overview", "data summary"]
    return any(keyword in lower for keyword in keywords)
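# Examples of the keyword match:
#   is_basic_dataset_question("Give me a data overview")     -> True
#   is_basic_dataset_question("How many total tweets?")      -> True  ("total tweets")
#   is_basic_dataset_question("Who won the game last night?") -> False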

def chatbot_response(user_prompt):
    if not user_prompt:
        return None, None, None, None, None

    try:
        # If the query is a basic dataset question, fetch summary from MongoDB.
        if is_basic_dataset_question(user_prompt):
            summary = get_dataset_summary()
            ai_response = "Dataset Summary:\n" + summary
            # Run analysis on the summary text
            sentiment_label, sentiment_confidence = analyze_sentiment(summary)
            topic_label, topic_confidence = extract_topic(summary)
            return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence

        # If the query is about a specific entry in the dataset...
        entry_query, index = is_entry_query(user_prompt)
        if entry_query:
            entry = get_entry_by_index(index)
            if entry is None:
                return "❌ No entry found for the requested index.", None, None, None, None
            # Retrieve fields from the document
            entry_text = entry.get("text", "No text available.")
            entry_user = entry.get("user", "Unknown")
            entry_date = entry.get("date", "Unknown")
            # Build a static response message with the required format
            ai_response = (
                "Let's break down this tweet-like MongoDB entry:\n\n"
                f"Tweet: {entry_text}\n"
                f"User: {entry_user}\n"
                f"Date: {entry_date}"
            )
            sentiment_label, sentiment_confidence = analyze_sentiment(entry_text)
            topic_label, topic_confidence = extract_topic(entry_text)
            return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence

        # For other queries, use the generative model (this branch may be slower).
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response_obj = model_gen.generate_content(user_prompt)
        ai_response = ai_response_obj.text
        sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
        topic_label, topic_confidence = extract_topic(user_prompt)
        return ai_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence

    except Exception as e:
        return f"❌ Error: {e}", None, None, None, None