import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import warnings

warnings.filterwarnings("ignore")
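
# Gradio demo that loads three fine-tuned token-classification checkpoints
# and compares how each labels the components of a free-form Indian address.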

class MultiModelIndianAddressNER:

    def __init__(self):
        self.models_config = {
            "TinyBERT": {
                "name": "shiprocket-ai/open-tinybert-indian-address-ner",
                "description": "Lightweight and fast - 66.4M parameters",
                "base_model": "TinyBERT"
            },
            "ModernBERT": {
                "name": "shiprocket-ai/open-modernbert-indian-address-ner",
                "description": "Modern architecture - 150M parameters",
                "base_model": "ModernBERT"
            },
            "IndicBERT": {
                "name": "shiprocket-ai/open-indicbert-indian-address-ner",
                "description": "Indic language optimized - 32.9M parameters",
                "base_model": "IndicBERT"
            }
        }

        self.loaded_models = {}
        self.loaded_tokenizers = {}
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.id2entity = {
            "0": "O",
            "1": "B-building_name",
            "2": "I-building_name",
            "3": "B-city",
            "4": "I-city",
            "5": "B-country",
            "6": "I-country",
            "7": "B-floor",
            "8": "I-floor",
            "9": "B-house_details",
            "10": "I-house_details",
            "11": "B-locality",
            "12": "I-locality",
            "13": "B-pincode",
            "14": "I-pincode",
            "15": "B-road",
            "16": "I-road",
            "17": "B-state",
            "18": "I-state",
            "19": "B-sub_locality",
            "20": "I-sub_locality",
            "21": "B-landmarks",
            "22": "I-landmarks"
        }

        self.load_model("TinyBERT")
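
    # Checkpoints are downloaded from the Hugging Face Hub on first use and
    # cached in memory, so switching models in the UI only pays the load
    # cost once.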
    def load_model(self, model_key):
        """Load a specific model if not already loaded"""
        if model_key not in self.loaded_models:
            print(f"Loading {model_key} model...")
            model_name = self.models_config[model_key]["name"]

            try:
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                model = AutoModelForTokenClassification.from_pretrained(model_name)
                model.to(self.device)
                model.eval()

                self.loaded_tokenizers[model_key] = tokenizer
                self.loaded_models[model_key] = model
                print(f"✅ {model_key} model loaded successfully!")

            except Exception as e:
                print(f"❌ Error loading {model_key}: {str(e)}")
                raise e

        return self.loaded_tokenizers[model_key], self.loaded_models[model_key]

    def predict(self, address, model_key="TinyBERT"):
        """Extract entities from an Indian address using specified model"""
        if not address.strip():
            return {}, f"Using {model_key} model"

        try:
            tokenizer, model = self.load_model(model_key)

            if model_key == "IndicBERT":
                entities = self._predict_token_based(address, tokenizer, model)
            else:
                entities = self._predict_offset_based(address, tokenizer, model)

            model_info = f"Using {model_key} ({self.models_config[model_key]['description']})"
            return entities, model_info

        except Exception as e:
            return {}, f"Error with {model_key}: {str(e)}"
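
    # The fast tokenizers behind TinyBERT and ModernBERT can return a
    # character offset for every token, letting entity spans be sliced
    # directly out of the original input string.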

    def _predict_offset_based(self, address, tokenizer, model):
        """Offset-based prediction for TinyBERT and ModernBERT"""
        inputs = tokenizer(
            address,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128,
            return_offsets_mapping=True
        )

        offset_mapping = inputs.pop("offset_mapping")[0]
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_ids = torch.argmax(predictions, dim=-1)
            confidence_scores = torch.max(predictions, dim=-1)[0]

        return self.extract_entities_with_offsets(
            address,
            predicted_ids[0],
            confidence_scores[0],
            offset_mapping
        )
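
    # IndicBERT uses a SentencePiece tokenizer, so entities are reconstructed
    # from the tokens themselves rather than from character offsets.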

    def _predict_token_based(self, address, tokenizer, model):
        """Token-based prediction for IndicBERT (SentencePiece)"""
        inputs = tokenizer(
            address,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=128
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_ids = torch.argmax(predictions, dim=-1)
            confidence_scores = torch.max(predictions, dim=-1)[0]

        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        predicted_labels = [self.id2entity.get(str(pred_id.item()), "O") for pred_id in predicted_ids[0]]
        confidences = confidence_scores[0].cpu().numpy()

        return self.group_entities_sentencepiece(tokens, predicted_labels, confidences)
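
    # Both decoders walk the BIO label sequence: a "B-" label opens a new
    # span, a matching "I-" label extends it, and "O" (or the next "B-")
    # closes it, e.g. [B-city, I-city] merges into a single "city" entity.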

    def extract_entities_with_offsets(self, original_text, predicted_ids, confidences, offset_mapping):
        """Extract entities using offset mapping for accurate text reconstruction"""
        entities = {}
        current_entity = None

        for i, (pred_id, conf) in enumerate(zip(predicted_ids, confidences)):
            if i >= len(offset_mapping):
                break

            start, end = offset_mapping[i]

            # Special tokens ([CLS], [SEP], padding) map to the empty span (0, 0).
            if start == end == 0:
                continue

            label = self.id2entity.get(str(pred_id.item()), "O")

            if label.startswith("B-"):
                # Flush any entity that is still open before starting a new one.
                if current_entity:
                    entity_type = current_entity["type"]
                    if entity_type not in entities:
                        entities[entity_type] = []
                    entities[entity_type].append({
                        "text": current_entity["text"],
                        "confidence": current_entity["confidence"]
                    })

                entity_type = label[2:]
                current_entity = {
                    "type": entity_type,
                    "text": original_text[start:end],
                    "confidence": conf.item(),
                    "start": start,
                    "end": end
                }

            elif label.startswith("I-") and current_entity:
                entity_type = label[2:]
                if entity_type == current_entity["type"]:
                    # Re-slice from the original string so spacing stays intact.
                    current_entity["text"] = original_text[current_entity["start"]:end]
                    current_entity["confidence"] = (current_entity["confidence"] + conf.item()) / 2
                    current_entity["end"] = end

            elif label == "O" and current_entity:
                entity_type = current_entity["type"]
                if entity_type not in entities:
                    entities[entity_type] = []
                entities[entity_type].append({
                    "text": current_entity["text"],
                    "confidence": current_entity["confidence"]
                })
                current_entity = None

        # Flush the last entity if the sequence ended inside one.
        if current_entity:
            entity_type = current_entity["type"]
            if entity_type not in entities:
                entities[entity_type] = []
            entities[entity_type].append({
                "text": current_entity["text"],
                "confidence": current_entity["confidence"]
            })

        return entities
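
    # SentencePiece marks word-initial pieces with a leading "▁"; pieces
    # without the marker continue the previous word, so they are joined
    # with and without spaces respectively.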

    def group_entities_sentencepiece(self, tokens, labels, confidences):
        """Group entities for SentencePiece tokenization (IndicBERT) with proper text reconstruction"""
        entities = {}
        current_entity = None

        for i, (token, label, conf) in enumerate(zip(tokens, labels, confidences)):
            # Skip special tokens.
            if token in ["<s>", "</s>", "<pad>", "<unk>"]:
                continue

            if label.startswith("B-"):
                # Flush any open entity before starting a new one.
                if current_entity:
                    entity_type = current_entity["type"]
                    if entity_type not in entities:
                        entities[entity_type] = []

                    clean_text = self._clean_sentencepiece_text(current_entity["text"])
                    entities[entity_type].append({
                        "text": clean_text,
                        "confidence": current_entity["confidence"]
                    })

                entity_type = label[2:]
                clean_token = token.replace("▁", " ").strip()
                current_entity = {
                    "type": entity_type,
                    "text": clean_token,
                    "confidence": conf
                }

            elif label.startswith("I-") and current_entity:
                entity_type = label[2:]
                if entity_type == current_entity["type"]:
                    if token.startswith("▁"):
                        # Word-initial piece: add a space before appending.
                        current_entity["text"] += " " + token.replace("▁", "")
                    else:
                        # Continuation piece: append directly.
                        current_entity["text"] += token
                    current_entity["confidence"] = (current_entity["confidence"] + conf) / 2

            elif label == "O" and current_entity:
                entity_type = current_entity["type"]
                if entity_type not in entities:
                    entities[entity_type] = []

                clean_text = self._clean_sentencepiece_text(current_entity["text"])
                entities[entity_type].append({
                    "text": clean_text,
                    "confidence": current_entity["confidence"]
                })
                current_entity = None

        # Flush the last entity if the sequence ended inside one.
        if current_entity:
            entity_type = current_entity["type"]
            if entity_type not in entities:
                entities[entity_type] = []

            clean_text = self._clean_sentencepiece_text(current_entity["text"])
            entities[entity_type].append({
                "text": clean_text,
                "confidence": current_entity["confidence"]
            })

        return entities
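
    # e.g. _clean_sentencepiece_text("▁Andheri ▁West ,") -> "Andheri West"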
    def _clean_sentencepiece_text(self, text):
        """Clean SentencePiece text by removing markers and fixing spacing"""
        clean_text = text.replace("▁", " ")
        clean_text = " ".join(clean_text.split())
        clean_text = clean_text.strip().rstrip(",").strip()
        return clean_text

print("Initializing Multi-Model Indian Address NER...")
ner_system = MultiModelIndianAddressNER()
print("System ready!")
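
# Example of the prediction API (confidence values below are illustrative):
#   entities, info = ner_system.predict("Flat 201, MG Road, Bangalore, Karnataka, 560001")
#   # entities -> {"house_details": [{"text": "Flat 201", "confidence": 0.97}], ...}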

def process_address(address_text, selected_model):
    """Process address and return formatted results with selected model"""
    if not address_text.strip():
        return "Please enter an address to analyze."

    try:
        entities, model_info = ner_system.predict(address_text, selected_model)

        if not entities:
            return f"❌ No entities found in the provided address.\n\n**{model_info}**"

        result = f"📍 **Input Address:** {address_text}\n\n"
        result += f"🤖 **{model_info}**\n\n"
        result += "🏷️ **Extracted Entities:**\n\n"

        # Display entity types in a natural address order first.
        entity_order = [
            'building_name', 'floor', 'house_details', 'road',
            'sub_locality', 'locality', 'landmarks', 'city',
            'state', 'country', 'pincode'
        ]

        displayed_entities = set()

        for entity_type in entity_order:
            if entity_type in entities and entity_type not in displayed_entities:
                result += f"**{entity_type.replace('_', ' ').title()}:**\n"
                for entity in entities[entity_type]:
                    confidence = entity['confidence']
                    text = entity['text']
                    confidence_icon = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.6 else "🔴"
                    result += f" {confidence_icon} {text} (confidence: {confidence:.3f})\n"
                result += "\n"
                displayed_entities.add(entity_type)

        # Then show any entity types not covered by the order above.
        for entity_type, entity_list in entities.items():
            if entity_type not in displayed_entities:
                result += f"**{entity_type.replace('_', ' ').title()}:**\n"
                for entity in entity_list:
                    confidence = entity['confidence']
                    text = entity['text']
                    confidence_icon = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.6 else "🔴"
                    result += f" {confidence_icon} {text} (confidence: {confidence:.3f})\n"
                result += "\n"

        result += "\n**Legend:**\n"
        result += "🟢 High confidence (>0.8)\n"
        result += "🟡 Medium confidence (0.6-0.8)\n"
        result += "🔴 Low confidence (<0.6)\n"

        return result

    except Exception as e:
        return f"❌ Error processing address: {str(e)}"

def compare_models(address_text):
    """Compare results from all models"""
    if not address_text.strip():
        return "Please enter an address to compare models."

    result = f"📍 **Address:** {address_text}\n\n"
    result += "📊 **Model Comparison:**\n\n"

    for model_key in ner_system.models_config.keys():
        try:
            entities, model_info = ner_system.predict(address_text, model_key)
            result += f"### {model_key}\n"
            result += f"*{ner_system.models_config[model_key]['description']}*\n\n"

            if entities:
                entity_count = sum(len(entity_list) for entity_list in entities.values())
                result += f"**Found {entity_count} entities:**\n"

                for entity_type, entity_list in sorted(entities.items()):
                    for entity in entity_list:
                        confidence = entity['confidence']
                        text = entity['text']
                        confidence_icon = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.6 else "🔴"
                        result += f" {confidence_icon} {entity_type}: {text} ({confidence:.3f})\n"
            else:
                result += "❌ No entities found\n"

            result += "\n---\n\n"

        except Exception as e:
            result += f"### {model_key}\n❌ Error: {str(e)}\n\n---\n\n"

    return result

sample_addresses = [
    "Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
    "DLF Cyber City, Sector 25, Gurgaon, Haryana",
    "Flat 201, MG Road, Bangalore, Karnataka, 560001",
    "Phoenix Mall, Kurla West, Mumbai",
    "House No 456, Green Park Extension, New Delhi, 110016",
    "Office 302, Tech Park, Electronic City, Bangalore, Karnataka, 560100"
]

with gr.Blocks(title="Multi-Model Indian Address NER", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏠 Multi-Model Indian Address Named Entity Recognition

    Compare different transformer models for extracting components from Indian addresses. Choose between TinyBERT (fast), ModernBERT (modern), and IndicBERT (Indic-optimized).

    **Supported entities:** Building Name, Floor, House Details, Road, Sub-locality, Locality, Landmarks, City, State, Country, Pincode
    """)
    with gr.Tab("Single Model Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                model_dropdown = gr.Dropdown(
                    choices=list(ner_system.models_config.keys()),
                    value="TinyBERT",
                    label="Select Model",
                    info="Choose which model to use for entity extraction"
                )

                address_input = gr.Textbox(
                    label="Enter Indian Address",
                    placeholder="e.g., Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
                    lines=3,
                    max_lines=5
                )

                submit_btn = gr.Button("🔍 Extract Entities", variant="primary")

                gr.Markdown("### 📝 Sample Addresses (click to use):")
                sample_buttons = []
                for addr in sample_addresses:
                    btn = gr.Button(addr, size="sm")
                    # addr is bound as a default argument so each button
                    # fills in its own sample address.
                    btn.click(fn=lambda x=addr: x, outputs=address_input)
                    sample_buttons.append(btn)

            with gr.Column(scale=1):
                output_text = gr.Markdown(
                    label="Extracted Entities",
                    value="Select a model, enter an address, and click 'Extract Entities' to see the results."
                )

    submit_btn.click(
        fn=process_address,
        inputs=[address_input, model_dropdown],
        outputs=output_text
    )

    address_input.submit(
        fn=process_address,
        inputs=[address_input, model_dropdown],
        outputs=output_text
    )
    with gr.Tab("Model Comparison"):
        with gr.Row():
            with gr.Column(scale=1):
                address_compare = gr.Textbox(
                    label="Enter Indian Address for Comparison",
                    placeholder="e.g., Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
                    lines=3,
                    max_lines=5
                )

                compare_btn = gr.Button("🔄 Compare All Models", variant="secondary")

                gr.Markdown("### 📝 Sample Addresses (click to use):")
                sample_buttons_compare = []
                for addr in sample_addresses:
                    btn = gr.Button(addr, size="sm")
                    btn.click(fn=lambda x=addr: x, outputs=address_compare)
                    sample_buttons_compare.append(btn)

            with gr.Column(scale=1):
                comparison_output = gr.Markdown(
                    label="Model Comparison Results",
                    value="Enter an address and click 'Compare All Models' to see how different models perform."
                )

    compare_btn.click(
        fn=compare_models,
        inputs=address_compare,
        outputs=comparison_output
    )

    address_compare.submit(
        fn=compare_models,
        inputs=address_compare,
        outputs=comparison_output
    )
    with gr.Tab("Model Information"):
        gr.Markdown("""
        ## 📊 Available Models

        ### TinyBERT
        - **Base Model**: huawei-noah/TinyBERT_General_6L_768D
        - **Model Size**: ~66.4M parameters
        - **Advantages**: Fastest inference, lowest memory usage, mobile-friendly
        - **Best for**: Real-time applications, edge deployment

        ### ModernBERT
        - **Base Model**: answerdotai/ModernBERT-base
        - **Model Size**: ~150M parameters
        - **Advantages**: Latest architectural improvements, balanced performance
        - **Best for**: High-accuracy requirements with reasonable speed

        ### IndicBERT
        - **Base Model**: ai4bharat/indic-bert
        - **Model Size**: ~32.9M parameters
        - **Advantages**: Optimized for Indian languages and contexts
        - **Best for**: Mixed language addresses, regional Indian contexts

        ## 🎯 Entity Types Supported

        All models can extract the following entities:
        - **Building Name**: Apartment/building names
        - **Floor**: Floor numbers and details
        - **House Details**: House/flat numbers
        - **Road**: Street and road names
        - **Sub-locality**: Sector, block details
        - **Locality**: Area, neighborhood names
        - **Landmarks**: Notable nearby locations
        - **City**: City names
        - **State**: State names
        - **Country**: Country names
        - **Pincode**: Postal codes
        """)
    gr.Markdown("""
    ---
    **Models:**
    - [TinyBERT](https://huggingface.co/shiprocket-ai/open-tinybert-indian-address-ner) |
    [ModernBERT](https://huggingface.co/shiprocket-ai/open-modernbert-indian-address-ner) |
    [IndicBERT](https://huggingface.co/shiprocket-ai/open-indicbert-indian-address-ner)

    **About:** These models are specifically trained on Indian address patterns and can handle various formats and styles common in Indian addresses.
    """)

if __name__ == "__main__":
    demo.launch()