import os
import re

import streamlit as st
import google.generativeai as genai
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

from db import get_entry_by_index  # Helper to fetch a document by index

# Configure the Gemini API key from the environment.
GEMINI_API_KEY = os.getenv("gemini_api")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    st.error("⚠️ Google API key is missing! Set it in Hugging Face Secrets.")

# Load the pre-trained sentiment analysis model.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-sentiment"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
except Exception as e:
    sentiment_pipeline = None  # Guard so later calls fail gracefully instead of raising NameError.
    st.error(f"❌ Error loading sentiment model: {e}")

# Load the topic extraction model (zero-shot classification).
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    topic_pipeline = None  # Guard so later calls fail gracefully instead of raising NameError.
    st.error(f"❌ Error loading topic extraction model: {e}")

# Predefined topic labels for classification.
TOPIC_LABELS = [
    "Technology", "Politics", "Business", "Sports", "Entertainment",
    "Health", "Science", "Education", "Finance", "Travel", "Food"
]

# Analyze sentiment using the pre-trained model. The Cardiff NLP model emits
# LABEL_0/LABEL_1/LABEL_2, which map to Negative/Neutral/Positive.
def analyze_sentiment(text):
    if sentiment_pipeline is None:
        return "Sentiment model unavailable", None
    try:
        sentiment_result = sentiment_pipeline(text)[0]
        label = sentiment_result["label"]
        score = sentiment_result["score"]
        sentiment_mapping = {
            "LABEL_0": "Negative",
            "LABEL_1": "Neutral",
            "LABEL_2": "Positive",
        }
        return sentiment_mapping.get(label, "Unknown"), score
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None

# Extract the most likely topic using zero-shot classification against TOPIC_LABELS.
def extract_topic(text):
    if topic_pipeline is None:
        return "Topic model unavailable", None
    try:
        topic_result = topic_pipeline(text, TOPIC_LABELS)
        top_topic = topic_result["labels"][0]
        confidence = topic_result["scores"][0]
        return top_topic, confidence
    except Exception as e:
        return f"Error extracting topic: {e}", None

# Helper to detect whether the user asks for a specific entry.
# Matches patterns like "data entry 1" or "entry 2" (case-insensitive).
def get_entry_index(prompt):
    match = re.search(r'(?:data entry|entry)\s*(\d+)', prompt.lower())
    if match:
        # Convert the 1-based entry number to a 0-based index.
        return int(match.group(1)) - 1
    return None

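# Examples (deterministic, given the regex above):
#   get_entry_index("Show me data entry 3")  ->  2
#   get_entry_index("entry 10, please")      ->  9
#   get_entry_index("hello there")           ->  None
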
# Helper to filter the generative response down to the required sections.
# The response is expected to contain:
#   "Let's break down this tweet-like MongoDB entry:" followed by the analysis,
#   then "Conclusion:" followed by the conclusion text.
# Everything before the breakdown marker is dropped, and the "Conclusion:"
# header itself is removed.
def filter_ai_response(ai_text):
    breakdown_marker = "Let's break down this tweet-like MongoDB entry:"
    conclusion_marker = "Conclusion:"
    if breakdown_marker in ai_text and conclusion_marker in ai_text:
        # Keep only the text after the breakdown marker, then split at the conclusion.
        parts = ai_text.split(breakdown_marker, 1)[1]
        breakdown_part, conclusion_part = parts.split(conclusion_marker, 1)
        # Rebuild the output: breakdown section plus conclusion content, without the header.
        return breakdown_marker + "\n" + breakdown_part.strip() + "\n" + conclusion_part.strip()
    # If the markers aren't found, return the original text unchanged.
    return ai_text

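# Example (marker text abbreviated here for readability; the real markers are above):
#   Input:  "Preamble... Let's break down this tweet-like MongoDB entry: A Conclusion: B"
#   Output: "Let's break down this tweet-like MongoDB entry:\nA\nB"
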
# Main function to generate the AI response along with sentiment and topic analysis.
# If the prompt asks for a specific entry, fetch its "text" from MongoDB and build
# a structured prompt around it; otherwise analyze the user prompt directly.
def chatbot_response(user_prompt):
    if not user_prompt:
        return None, None, None, None, None
    try:
        entry_index = get_entry_index(user_prompt)
        if entry_index is not None:
            entry = get_entry_by_index(entry_index)
            if entry is None:
                return "❌ No entry found for the requested index.", None, None, None, None
            entry_text = entry.get("text", "No text available.")
            # Build a prompt instructing the Gemini model to respond in a structured format.
            combined_prompt = (
                f"Provide analysis for the following MongoDB entry:\n\n"
                f"{entry_text}\n\n"
                "Please respond in the following format:\n"
                "Let's break down this tweet-like MongoDB entry:\n[Your detailed analysis here]\n"
                "Conclusion:\n[Your conclusion here]"
            )
            # Run sentiment and topic analysis on the entry's text.
            sentiment_label, sentiment_confidence = analyze_sentiment(entry_text)
            topic_label, topic_confidence = extract_topic(entry_text)
        else:
            # Not an entry query: analyze the user prompt directly.
            combined_prompt = user_prompt
            sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
            topic_label, topic_confidence = extract_topic(user_prompt)
        # Generate the AI response using Gemini.
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response = model_gen.generate_content(combined_prompt)
        # Keep only the required sections of the generative response.
        filtered_response = filter_ai_response(ai_response.text)
        return filtered_response, sentiment_label, sentiment_confidence, topic_label, topic_confidence
    except Exception as e:
        return f"❌ Error: {e}", None, None, None, None