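"""Fast sentiment & emotion analysis app for Gradio / Hugging Face Spaces.

Loads two Hugging Face text-classification pipelines
(cardiffnlp/twitter-roberta-base-sentiment for sentiment,
bhadresh-savani/bert-base-uncased-emotion for emotion), runs both
inferences in parallel, and keeps a small thread-safe FIFO cache of
recent predictions.
"""
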
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
import gradio as gr
import torch
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

# Global cache and lock for thread-safety
CACHE_SIZE = 100
prediction_cache = {}
cache_lock = Lock()

# Map raw labels from cardiffnlp/twitter-roberta-base-sentiment to readable
# names (per the model card: LABEL_0 = negative, LABEL_1 = neutral, LABEL_2 = positive)
SENTIMENT_LABEL_MAPPING = {
    "LABEL_0": "negative",
    "LABEL_1": "neutral",
    "LABEL_2": "positive"
}

def load_model(model_name):
    """
    Loads the model with 8-bit quantization if a GPU is available;
    otherwise, loads the full-precision model on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if torch.cuda.is_available():
        # 8-bit loading requires the bitsandbytes package. With
        # device_map="auto", accelerate places the weights itself, so the
        # pipeline must NOT be given an explicit device (doing so raises an
        # error for accelerate-dispatched models).
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, load_in_8bit=True, device_map="auto"
        )
        return pipeline("text-classification", model=model, tokenizer=tokenizer)
    # CPU fallback: full model; device=-1 pins the pipeline to CPU.
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return pipeline("text-classification", model=model, tokenizer=tokenizer, device=-1)
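
# Note: newer transformers releases deprecate passing load_in_8bit= directly
# to from_pretrained in favor of a BitsAndBytesConfig. A minimal sketch,
# assuming a recent transformers release and bitsandbytes are installed:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForSequenceClassification.from_pretrained(
#       model_name,
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       device_map="auto",
#   )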

# Load both models concurrently at startup (one worker per model).
with ThreadPoolExecutor(max_workers=2) as executor:
    sentiment_future = executor.submit(load_model, "cardiffnlp/twitter-roberta-base-sentiment")
    emotion_future = executor.submit(load_model, "bhadresh-savani/bert-base-uncased-emotion")

sentiment_pipeline = sentiment_future.result()
emotion_pipeline = emotion_future.result()
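
# Both pipelines take a raw string and return a list whose first element is
# the top prediction, e.g. (score illustrative):
#   sentiment_pipeline("great!")  ->  [{'label': 'LABEL_2', 'score': 0.98}]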

def analyze_text(text):
    # Check cache first (thread-safe)
    with cache_lock:
        if text in prediction_cache:
            return prediction_cache[text]
    
    try:
        # Run both model inferences in parallel (two workers, one per model).
        with ThreadPoolExecutor(max_workers=2) as executor:
            future_sentiment = executor.submit(sentiment_pipeline, text)
            future_emotion = executor.submit(emotion_pipeline, text)
            sentiment_result = future_sentiment.result()[0]
            emotion_result = future_emotion.result()[0]
        
        # Remap the sentiment label to a human-readable format if available.
        raw_sentiment_label = sentiment_result.get("label", "")
        sentiment_label = SENTIMENT_LABEL_MAPPING.get(raw_sentiment_label, raw_sentiment_label)
        
        # Format the output with rounded scores.
        result = {
            "Sentiment": {sentiment_label: round(sentiment_result['score'], 4)},
            "Emotion": {emotion_result['label']: round(emotion_result['score'], 4)}
        }
    except Exception as e:
        # Return errors directly; do not cache failures, so a transient
        # error cannot poison the cache for this input.
        return {"error": str(e)}
    
    # Update the cache in a thread-safe manner. Eviction is FIFO: dicts
    # preserve insertion order, so the first key is the oldest entry.
    with cache_lock:
        if len(prediction_cache) >= CACHE_SIZE:
            prediction_cache.pop(next(iter(prediction_cache)))
        prediction_cache[text] = result
    
    return result
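
# Example result shape (scores illustrative):
#   analyze_text("I'm thrilled!")
#   -> {"Sentiment": {"positive": 0.9812}, "Emotion": {"joy": 0.9734}}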

# Define the Gradio interface.
demo = gr.Interface(
    fn=analyze_text,
    inputs=gr.Textbox(placeholder="Enter your text here...", label="Input Text"),
    outputs=gr.JSON(label="Analysis Results"),
    title="🚀 Fast Sentiment & Emotion Analysis",
    description="Runs sentiment and emotion classifiers in parallel and remaps raw sentiment labels to readable names.",
    examples=[
        ["I'm thrilled to start this new adventure!"],
        ["This situation is making me really frustrated."],
        ["I feel so heartbroken and lost."]
    ],
    theme="soft",
    allow_flagging="never"  # renamed to flagging_mode in newer Gradio releases
)

# Warm up the models with a sample input.
_ = analyze_text("Warming up models...")

if __name__ == "__main__":
    # Bind to all interfaces for Hugging Face Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860)