File size: 10,344 Bytes
b04682c
 
 
c5922b9
 
 
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5922b9
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5922b9
b04682c
 
 
 
 
fb510e6
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d25649c
b04682c
d25649c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
d25649c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28094fc
b04682c
d25649c
b04682c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5922b9
b04682c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
import os
from typing import Dict, List, Any, Optional, Union
from smolagents import Tool

class NamedEntityRecognitionTool(Tool):
    """Tool that runs a ``transformers`` NER pipeline and renders the result as text.

    The pipeline is created lazily on the first ``forward`` call and is
    re-created whenever a different model name is requested. Results can be
    rendered as a flat list, grouped by label, or as annotated text.
    """

    name = "ner_tool"
    description = """
    Identifies and labels named entities in text using customizable NER models.
    Can recognize entities such as persons, organizations, locations, dates, etc.
    Returns a structured analysis of all entities found in the input text.
    """
    inputs = {
        "text": {
            "type": "string",
            "description": "The text to analyze for named entities",
        },
        "model": {
            "type": "string",
            "description": "The NER model to use (default: 'dslim/bert-base-NER')",
            "nullable": True
        },
        "aggregation": {
            "type": "string",
            "description": "How to aggregate entities: 'simple' (just list), 'grouped' (by label), or 'detailed' (with confidence scores)",
            "nullable": True
        },
        "min_score": {
            "type": "number",
            "description": "Minimum confidence score threshold (0.0-1.0) for including entities",
            "nullable": True
        }
    }
    output_type = "string"

    def __init__(self):
        """Initialize the NER Tool with default settings."""
        super().__init__()
        self.default_model = "dslim/bert-base-NER"
        # Model id -> short human-readable description.
        self.available_models = {
            "dslim/bert-base-NER": "Standard NER (English)",
            "jean-baptiste/camembert-ner": "French NER",
            "Davlan/bert-base-multilingual-cased-ner-hrl": "Multilingual NER",
            "Babelscape/wikineural-multilingual-ner": "WikiNeural Multilingual NER",
            "flair/ner-english-ontonotes-large": "OntoNotes English (fine-grained)",
            "elastic/distilbert-base-cased-finetuned-conll03-english": "CoNLL (fast)"
        }
        # Entity label -> display string. The icons are written as escape
        # sequences so the source stays pure ASCII and cannot be corrupted by
        # a wrong file encoding.
        self.entity_colors = {
            "PER": "\U0001F7E5 Person",                  # red square
            "PERSON": "\U0001F7E5 Person",
            "LOC": "\U0001F7E8 Location",                # yellow square
            "LOCATION": "\U0001F7E8 Location",
            "GPE": "\U0001F7E8 Location",
            "ORG": "\U0001F7E6 Organization",            # blue square
            "ORGANIZATION": "\U0001F7E6 Organization",
            "MISC": "\U0001F7E9 Miscellaneous",          # green square
            "DATE": "\U0001F7EA Date",                   # purple square
            "TIME": "\U0001F7EA Time",
            "MONEY": "\U0001F4B0 Money",                 # money bag
            "PERCENT": "\U0001F4CA Percentage",          # bar chart
            "PRODUCT": "\U0001F6D2 Product",             # shopping cart
            "EVENT": "\U0001F3AB Event",                 # ticket
            "WORK_OF_ART": "\U0001F3A8 Work of Art",     # artist palette
            "LAW": "\u2696\uFE0F Law",                   # scales
            "LANGUAGE": "\U0001F5E3\uFE0F Language",     # speaking head
            "FAC": "\U0001F3E2 Facility"                 # office building
        }
        # Pipeline will be lazily loaded
        self._pipeline = None

    def _load_pipeline(self, model_name: str):
        """Load the NER pipeline for ``model_name``, falling back to the default model.

        Args:
            model_name: Model identifier handed to ``transformers.pipeline``.

        Returns:
            True when a pipeline (requested or fallback) was stored in
            ``self._pipeline``, False when every attempt failed. On failure a
            previously loaded pipeline is left in place.
        """
        candidates = [model_name]
        # Retrying the default model is pointless when it is the one that failed.
        if model_name != self.default_model:
            candidates.append(self.default_model)

        for attempt, candidate in enumerate(candidates):
            try:
                # Imported here so the dependency is only needed once a
                # pipeline is actually requested.
                from transformers import pipeline
                self._pipeline = pipeline("ner", model=candidate, aggregation_strategy="simple")
                return True
            except Exception as error:
                # Best-effort boundary: report and move on to the next candidate.
                if attempt == 0:
                    print(f"Error loading model {candidate}: {str(error)}")
                else:
                    print(f"Error loading fallback model: {str(error)}")
        return False

    def _entity_label(self, entity: Dict[str, Any]) -> str:
        """Return the raw label of one pipeline result.

        Aggregated results store the label under ``entity_group`` while
        per-token results use ``entity``; both are accepted.
        """
        return entity.get("entity_group") or entity.get("entity") or "UNKNOWN"

    def _strip_bio_prefix(self, label: str) -> str:
        """Remove a leading ``B-``/``I-`` tag, leaving the rest of the label untouched."""
        if label.startswith(("B-", "I-")):
            return label[2:]
        return label

    def _get_friendly_label(self, label: str) -> str:
        """Convert technical entity labels to friendly descriptions with color indicators."""
        clean_label = self._strip_bio_prefix(label)
        # Unknown labels fall back to a blue diamond followed by the label itself.
        return self.entity_colors.get(clean_label, f"\U0001F537 {clean_label}")

    def _format_entity_line(self, entity: Dict[str, Any]) -> str:
        """Render one entity as a bullet line with its friendly label and score."""
        word = entity.get("word", "")
        score = entity.get("score", 0)
        friendly_label = self._get_friendly_label(self._entity_label(entity))
        return f"\u2022 {word} - {friendly_label} (confidence: {score:.2f})"

    def forward(self, text: str, model: Optional[str] = None, aggregation: Optional[str] = None, min_score: Optional[float] = None) -> str:
        """
        Perform Named Entity Recognition on the input text.

        Args:
            text: The text to analyze
            model: NER model to use (default: dslim/bert-base-NER)
            aggregation: How to aggregate results (simple, grouped, detailed);
                matched case-insensitively, anything else means grouped
            min_score: Minimum confidence threshold (0.0-1.0), default 0.8

        Returns:
            Formatted string with NER analysis results, or a human-readable
            message when the model is unknown, cannot be loaded, or the
            analysis fails. Errors are reported in the return value rather
            than raised.
        """
        # Set default values if parameters are None
        if model is None:
            model = self.default_model
        if min_score is None:
            min_score = 0.8
        if aggregation is None:
            aggregation = "grouped"
        else:
            aggregation = str(aggregation).strip().lower()

        # Validate model choice
        if model not in self.available_models and not model.startswith("dslim/"):
            return f"Model '{model}' not recognized. Available models: {', '.join(self.available_models.keys())}"

        # Nothing to analyze: answer without loading a model.
        if not text or not text.strip():
            return "No entities were detected in the text with the current settings."

        # Load the model if not already loaded or if different from current
        if self._pipeline is None or self._pipeline.model.name_or_path != model:
            if not self._load_pipeline(model):
                return "Failed to load NER model. Please try a different model."

        # Perform NER analysis
        try:
            # Keep only entities that reach the confidence threshold.
            entities = [e for e in self._pipeline(text) if e.get("score", 0) >= min_score]

            if not entities:
                return "No entities were detected in the text with the current settings."

            # Process results based on aggregation method
            if aggregation == "simple":
                return self._format_simple(text, entities)
            if aggregation == "detailed":
                return self._format_detailed(text, entities)
            # default to grouped
            return self._format_grouped(text, entities)

        except Exception as e:
            return f"Error analyzing text: {str(e)}"

    def _format_simple(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities as a simple list, one bullet line per entity."""
        lines = [self._format_entity_line(entity) for entity in entities]
        return "Named Entities Found:\n\n" + "".join(f"{line}\n" for line in lines)

    def _format_grouped(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities grouped by their category.

        Each distinct word is listed once per label, in order of first
        appearance, so the output is the same on every run.
        """
        # label -> {word: None}; a dict is used as an insertion-ordered set.
        grouped = {}
        for entity in entities:
            label = self._strip_bio_prefix(self._entity_label(entity))
            grouped.setdefault(label, {})[entity.get("word", "")] = None

        lines = [
            f"{self._get_friendly_label(label)}: {', '.join(words)}"
            for label, words in grouped.items()
        ]
        return "Named Entities by Category:\n\n" + "".join(f"{line}\n" for line in lines)

    def _format_detailed(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """Format entities with detailed information including position in text.

        The text is reproduced verbatim with every entity span rewritten as
        ``[span](LABEL)``, followed by one bullet line per entity. Entities
        without character offsets appear only in the bullet list.
        """
        # Map every character to the index of the entity covering it. Using
        # the index (not the label) keeps adjacent same-label entities apart.
        # When spans overlap, the later entity wins.
        owners = [None] * len(text)
        for index, entity in enumerate(entities):
            start = entity.get("start")
            end = entity.get("end")
            if start is None or end is None:
                continue
            for position in range(max(start, 0), min(end, len(text))):
                owners[position] = index

        # Cut the text wherever the owning entity changes. No separator is
        # inserted: the original whitespace already lives in the segments.
        pieces = []
        segment_start = 0
        for position in range(1, len(text) + 1):
            if position < len(text) and owners[position] == owners[segment_start]:
                continue
            segment = text[segment_start:position]
            owner = owners[segment_start]
            if owner is None:
                pieces.append(segment)
            else:
                clean_label = self._strip_bio_prefix(self._entity_label(entities[owner]))
                pieces.append(f"[{segment}]({clean_label})")
            segment_start = position
        highlighted_text = "".join(pieces)

        entity_details = [self._format_entity_line(entity) for entity in entities]

        # Combine into final result
        result = "Entity Analysis:\n\n"
        result += "Text with Entities Marked:\n"
        result += highlighted_text + "\n\n"
        result += "Entity Details:\n"
        result += "\n".join(entity_details)

        return result

    def get_available_models(self) -> Dict[str, str]:
        """Return a copy of the available models mapped to their descriptions."""
        # A copy, so callers cannot alter the table used for validation.
        return dict(self.available_models)

# Example usage:
# ner_tool = NamedEntityRecognitionTool()
# result = ner_tool("Apple Inc. is planning to open a new store in Paris, France next year.", model="dslim/bert-base-NER")
# print(result)