Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
import google.generativeai as genai | |
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments | |
from datasets import load_dataset | |
# --- Monkey Patch for Accelerator ---
# If the installed Accelerator.__init__ has no "dispatch_batches" parameter,
# replace it with a thin wrapper that discards that keyword before delegating
# to the original constructor. Any failure while patching (including a failed
# import) is reported through Streamlit instead of aborting the script.
# NOTE(review): presumably needed because another library still passes
# dispatch_batches to Accelerator — confirm against the pinned versions.
try:
    import accelerate
    from accelerate import Accelerator
    import inspect

    if 'dispatch_batches' not in inspect.signature(Accelerator.__init__).parameters:
        # Keep a handle on the real constructor so the wrapper can call it.
        old_init = Accelerator.__init__

        def new_init(self, *args, **kwargs):
            """Call the original ``__init__`` without ``dispatch_batches``."""
            kwargs.pop('dispatch_batches', None)
            old_init(self, *args, **kwargs)

        Accelerator.__init__ = new_init
except Exception as e:
    st.error(f"Error patching Accelerator: {e}")
# --- Configure Gemini API ---
# The key is read from the "gemini_api" environment variable; when it is
# missing or empty the app only shows an error and Gemini stays unconfigured.
GEMINI_API_KEY = os.environ.get("gemini_api")
if not GEMINI_API_KEY:
    # NOTE(review): the leading glyphs look like a mis-decoded emoji —
    # confirm the file's encoding before touching this string.
    st.error("β οΈ Google API key is missing! Set it in Hugging Face Secrets.")
else:
    genai.configure(api_key=GEMINI_API_KEY)

# Path to save/load the fine-tuned model
FINE_TUNED_MODEL_DIR = "fine-tuned-sentiment-model"
# --- Fine-tune the Sentiment Model ---
def fine_tune_model():
    """Fine-tune the base sentiment classifier on a local CSV and save it.

    Loads ``sentiment140.csv`` (read as ISO-8859-1) as the ``train`` split,
    derives a ``label`` column from ``target`` (0 -> 0, 2 -> 1, 4 -> 2),
    tokenizes the ``text`` column, trains for one epoch and writes both the
    model and the tokenizer to ``FINE_TUNED_MODEL_DIR``. Status messages are
    shown through Streamlit.

    Returns:
        ``(model, tokenizer)`` after a completed run, or ``(None, None)`` when
        the dataset cannot be loaded. Errors raised after loading are not
        caught here.
    """
    st.info("Fine-tuning sentiment model. This may take a while...")

    # Load the dataset from the local CSV file.
    try:
        raw_data = load_dataset(
            'csv',
            data_files={'train': 'sentiment140.csv'},
            encoding='ISO-8859-1',
        )
    except Exception as e:
        st.error(f"β Error loading dataset: {e}")
        return None, None

    # sentiment140 labels are 0 (Negative), 2 (Neutral), 4 (Positive);
    # the classifier expects contiguous class ids 0..2.
    target_to_class = {0: 0, 2: 1, 4: 2}

    def convert_labels(example):
        # An unexpected target value raises KeyError rather than being guessed.
        example["label"] = target_to_class[int(example["target"])]
        return example

    labelled_data = raw_data.map(convert_labels)

    base_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_name,
        num_labels=3,
    )

    # Tokenize the dataset; assumes the CSV has a column named "text".
    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=128,
        )

    tokenized_dataset = labelled_data.map(tokenize_function, batched=True)

    training_args = TrainingArguments(
        output_dir="./results",
        logging_dir='./logs',
        num_train_epochs=1,  # For demonstration, we train for 1 epoch.
        per_device_train_batch_size=8,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=50,
        evaluation_strategy="no",
        disable_tqdm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
    )
    trainer.train()

    # Persist both artifacts so later runs can load them instead of retraining.
    model.save_pretrained(FINE_TUNED_MODEL_DIR)
    tokenizer.save_pretrained(FINE_TUNED_MODEL_DIR)
    st.success("β Fine-tuning complete and model saved.")
    return model, tokenizer
# Load or fine-tune the sentiment model
# sentiment_pipeline is bound up front so the name always exists; it stays
# None whenever no fine-tuned model is available or the pipeline fails to build.
sentiment_pipeline = None
if not os.path.exists(FINE_TUNED_MODEL_DIR):
    model, tokenizer = fine_tune_model()
    if model is None or tokenizer is None:
        st.error("β Failed to fine-tune the sentiment analysis model.")
else:
    tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_DIR)

# Create sentiment analysis pipeline from the fine-tuned model
# Only build it when both pieces exist: pipeline(model=None, tokenizer=None)
# would silently fall back to the task's default checkpoint, whose label names
# do not match the LABEL_0/1/2 mapping used by analyze_sentiment.
if model is not None and tokenizer is not None:
    try:
        sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
    except Exception as e:
        st.error(f"β Error loading sentiment pipeline: {e}")
# Load Topic Extraction Model
# Bound to None first so the module-level name exists even when loading fails;
# the failure itself is reported through Streamlit.
topic_pipeline = None
try:
    topic_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
except Exception as e:
    st.error(f"β Error loading topic extraction model: {e}")
# Predefined topic labels for classification
# Passed as the candidate labels to the zero-shot topic pipeline.
TOPIC_LABELS = [
    "Technology",
    "Politics",
    "Business",
    "Sports",
    "Entertainment",
    "Health",
    "Science",
    "Education",
    "Finance",
    "Travel",
    "Food",
]
# Function to analyze sentiment
def analyze_sentiment(text):
    """Classify ``text`` with the module-level ``sentiment_pipeline``.

    Returns:
        ``(name, score)`` taken from the first pipeline result, where ``name``
        is "Negative", "Neutral" or "Positive" for the labels ``LABEL_0``,
        ``LABEL_1`` and ``LABEL_2`` ("Unknown" for any other label). If
        anything raises, the exception is not propagated; the function
        returns ``("Error analyzing sentiment: <exc>", None)`` instead.
    """
    label_names = {
        "LABEL_0": "Negative",
        "LABEL_1": "Neutral",
        "LABEL_2": "Positive",
    }
    try:
        best = sentiment_pipeline(text)[0]
        return label_names.get(best['label'], "Unknown"), best['score']
    except Exception as e:
        return f"Error analyzing sentiment: {e}", None
# Function to extract topic
def extract_topic(text):
    """Pick a topic for ``text`` from ``TOPIC_LABELS`` via ``topic_pipeline``.

    Returns:
        ``(topic, confidence)`` — the first entries of the result's
        ``"labels"`` and ``"scores"`` lists (presumably ordered best-first by
        the pipeline; verify against its documentation). If anything raises,
        returns ``("Error extracting topic: <exc>", None)`` instead of
        propagating the exception.
    """
    try:
        ranked = topic_pipeline(text, TOPIC_LABELS)
        return ranked["labels"][0], ranked["scores"][0]
    except Exception as e:
        return f"Error extracting topic: {e}", None
# Function to generate AI response along with sentiment and topic analysis
def chatbot_response(user_prompt):
    """Answer ``user_prompt`` with Gemini and analyze the prompt itself.

    Args:
        user_prompt: The user's message. Falsy values and strings containing
            only whitespace are treated as "no prompt".

    Returns:
        A 5-tuple ``(response_text, sentiment_label, sentiment_confidence,
        topic_label, topic_confidence)``.

        * No prompt: all five entries are ``None`` and nothing is called.
        * Any exception while generating or analyzing: the first entry is an
          error string and the remaining four are ``None``.
    """
    # A whitespace-only string is truthy but carries no content; short-circuit
    # it like an empty prompt instead of spending a Gemini call and two model
    # inferences on it. Non-string prompts keep the original truthiness check.
    if not user_prompt or (isinstance(user_prompt, str) and not user_prompt.strip()):
        return None, None, None, None, None
    try:
        # Generate AI Response using Gemini
        model_gen = genai.GenerativeModel("gemini-1.5-pro")
        ai_response = model_gen.generate_content(user_prompt)
        # Sentiment Analysis
        sentiment_label, sentiment_confidence = analyze_sentiment(user_prompt)
        # Topic Extraction
        topic_label, topic_confidence = extract_topic(user_prompt)
        # Reading .text stays inside the try so a failure there is reported
        # as an error tuple as well.
        return ai_response.text, sentiment_label, sentiment_confidence, topic_label, topic_confidence
    except Exception as e:
        return f"β Error: {e}", None, None, None, None