"""Gradio demo for the Shiprocket Llama address-intelligence model.

The app exposes three tasks backed by one causal language model:
component extraction, partial-address completion and format
standardization.  The model is loaded once at import time; if loading
fails the UI still starts and every task reports that the model is
unavailable.
"""

import warnings

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

warnings.filterwarnings("ignore")


class LlamaAddressCompletion:
    """Wrap the fine-tuned address model and its tokenizer.

    Instantiating the class loads the model immediately (see
    ``load_model``); loading errors propagate to the caller.
    """

    # Llama 3 chat layout: each role header is followed by a blank line,
    # and the prompt ends with an open assistant turn for the model to fill.
    # NOTE(review): the template already contains <|begin_of_text|>; if the
    # tokenizer also prepends a BOS token the sequence starts with two of
    # them -- confirm against how the model was fine-tuned before changing.
    _PROMPT_TEMPLATE = (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        "{instruction}: {text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Prompts longer than this many tokens are truncated by the tokenizer.
    _MAX_PROMPT_TOKENS = 512

    def __init__(self):
        self.model_name = "shiprocket-ai/open-llama-1b-address-completion"
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the Llama model and tokenizer.

        Uses FP16 with ``device_map="auto"`` when CUDA is available and
        FP32 on the CPU otherwise, then switches the model to eval mode.

        Raises:
            Exception: whatever the underlying loaders raise; the error is
                printed and re-raised with its original traceback.
        """
        try:
            print("Loading Llama 3.2-1B Address Completion model...")

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            has_cuda = torch.cuda.is_available()
            # NOTE(review): trust_remote_code=True executes code shipped in
            # the model repository; confirm it is actually required.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if has_cuda else torch.float32,
                device_map="auto" if has_cuda else None,
                trust_remote_code=True,
            )

            # With device_map="auto" placement is already done; on CPU the
            # model has to be moved explicitly.
            if not has_cuda:
                self.model = self.model.to(self.device)

            self.model.eval()
            print("✅ Model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading model: {str(e)}")
            raise

    def _generate(self, instruction, text, max_new_tokens, temperature):
        """Run one chat-style generation and return only the new text.

        Args:
            instruction: Task sentence placed before the address.
            text: The user-supplied address text.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature passed to ``generate``.

        Returns:
            The decoded continuation with special tokens removed and
            surrounding whitespace stripped.
        """
        prompt = self._PROMPT_TEMPLATE.format(instruction=instruction, text=text)

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self._MAX_PROMPT_TOKENS,
        )

        # Follow the model's actual placement rather than self.device,
        # because device_map="auto" decides where the weights live.
        model_device = next(self.model.parameters()).device
        encoded = {key: tensor.to(model_device) for key, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.05,
            )

        # generate() returns prompt + continuation; keep the continuation.
        prompt_length = encoded["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def _run_task(self, text, instruction, max_new_tokens, temperature,
                  empty_message, error_prefix):
        """Validate input, run a generation and convert failures to text.

        Returns ``empty_message`` for missing or whitespace-only input and
        ``"<error_prefix>: <exception>"`` if generation raises, so the UI
        always receives a displayable string.
        """
        if not text or not text.strip():
            return empty_message

        try:
            return self._generate(instruction, text, max_new_tokens, temperature)
        except Exception as e:
            return f"{error_prefix}: {str(e)}"

    def extract_address_components(self, address, max_new_tokens=150):
        """Extract address components using the model.

        Args:
            address: Full address text.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a human-readable message when the
            input is empty or generation fails.
        """
        return self._run_task(
            address,
            instruction="Extract address components from",
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            empty_message="Please provide an address to extract components from.",
            error_prefix="Error processing address",
        )

    def complete_partial_address(self, partial_address, max_new_tokens=100):
        """Complete a partial address.

        Args:
            partial_address: Leading fragment of an address.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a human-readable message when the
            input is empty or generation fails.
        """
        return self._run_task(
            partial_address,
            instruction="Complete this partial address",
            max_new_tokens=max_new_tokens,
            temperature=0.2,
            empty_message="Please provide a partial address to complete.",
            error_prefix="Error completing address",
        )

    def standardize_address(self, address, max_new_tokens=150):
        """Standardize an address format.

        Args:
            address: Informally written address text.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a human-readable message when the
            input is empty or generation fails.
        """
        return self._run_task(
            address,
            instruction="Standardize this address into proper format",
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            empty_message="Please provide an address to standardize.",
            error_prefix="Error standardizing address",
        )


# Initialize the model.  A failure is reported but does not stop the app:
# the interface functions below check for None and show a message instead.
print("Initializing Llama Address Completion system...")
try:
    llama_system = LlamaAddressCompletion()
    print("System ready!")
except Exception as e:
    print(f"Failed to initialize system: {e}")
    llama_system = None


def extract_components_interface(address_text):
    """Interface function for component extraction.

    Returns Markdown showing the input and the extracted components, or a
    notice when the model failed to load.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.extract_address_components(address_text)
    return f"**Input:** {address_text}\n\n**Extracted Components:**\n{result}"


def complete_address_interface(partial_address):
    """Interface function for address completion.

    Returns Markdown showing the partial input and the completion, or a
    notice when the model failed to load.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.complete_partial_address(partial_address)
    return (
        f"**Partial Address:** {partial_address}\n\n"
        f"**Completed Address:**\n{result}\n\n"
        "*⚠️ Note: This feature has limited training data and results may vary in quality.*"
    )


def standardize_address_interface(address_text):
    """Interface function for address standardization.

    Returns Markdown showing the original and the standardized address, or
    a notice when the model failed to load.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.standardize_address(address_text)
    return (
        f"**Original:** {address_text}\n\n"
        f"**Standardized:**\n{result}\n\n"
        "*⚠️ Note: This feature has limited training data and results may vary in quality.*"
    )


# Sample data
sample_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
    "Villa 141, Geown Oasis, V Kallahalli, Off Sarjapur, Bengaluru, Karnataka, 562125",
    "E401 Supertech Icon Indrapam 201301 UP",
    "Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
    "Flat 201, MG Road, Bangalore, Karnataka, 560001",
]

partial_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex",
    "Villa 141, Geown Oasis, V Kallahalli",
    "E401 Supertech Icon",
    "Shop No 123, Sunshine Apartments",
    "Flat 201, MG Road, Bangalore",
]

informal_addresses = [
    "c704 gayatri shivam thakur complex kandivali e 400101",
    "villa141 geown oasis vkallahalli off sarjapur blr kar 562125",
    "e401 supertech icon indrapam up 201301",
    "shop123 sunshine apts andheri w mumbai 400058",
]

# Markdown copy is kept at column 0 so headings and lists are never
# affected by the indentation of the layout code that uses them.
HEADER_MARKDOWN = """
# 🦙 Llama 3.2-1B Address Intelligence

Powered by a fine-tuned Llama 3.2-1B model specialized for Indian address processing.

**⭐ Best Performance**: Entity extraction from complete addresses
**⚠️ Limited Performance**: Address completion and standardization (limited training data)

**Model:** [shiprocket-ai/open-llama-1b-address-completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion)
"""

MODEL_INFO_MARKDOWN = """
## 🦙 About Llama 3.2-1B Address Completion

### Model Specifications
- **Base Model**: meta-llama/Llama-3.2-1B-Instruct
- **Parameters**: 1.24B parameters
- **Model Size**: ~2.47GB
- **Architecture**: Causal Language Model (Autoregressive)
- **Max Context**: 131,072 tokens
- **Precision**: FP16 for GPU, FP32 for CPU

### Key Features
- **Lightweight**: Only 1B parameters for fast inference
- **Specialized**: Fine-tuned specifically for Indian addresses
- **Versatile**: Handles extraction, completion, and standardization
- **Efficient**: Optimized for real-time applications
- **Context-Aware**: Understands relationships between address components

### Supported Address Components
- **Building Names**: Apartments, complexes, towers, malls
- **Localities**: Areas, neighborhoods, sectors
- **Pincodes**: 6-digit Indian postal codes
- **Cities**: Major and minor Indian cities
- **States**: All Indian states and union territories
- **Sub-localities**: Sectors, phases, blocks
- **Road Names**: Streets, lanes, main roads
- **Landmarks**: Notable reference points

### Performance Notes
- **⭐ Entity Extraction**: Excellent performance - primary use case
- **⚠️ Address Completion**: Limited training data - experimental feature
- **⚠️ Format Standardization**: Limited training data - experimental feature

**Recommendation**: Use this model primarily for address component extraction.

### Use Cases
- **E-commerce**: Auto-complete checkout addresses
- **Forms**: Intelligent address suggestions
- **Data Cleaning**: Standardize legacy address databases
- **Mobile Apps**: On-device address processing
- **APIs**: Real-time address validation services

### Performance Tips
- Use lower temperatures (0.1-0.3) for factual outputs
- Keep prompts under 512 tokens for optimal speed
- Process in batches for high-throughput scenarios
- Works best with Llama chat format prompts
"""

FOOTER_MARKDOWN = """
---
**Powered by:** [Llama 3.2-1B Address Completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion) |
**License:** Apache 2.0 |
**Developed by:** Shiprocket AI Team

This model demonstrates the power of lightweight LLMs for specialized address intelligence tasks.
"""


def _build_task_tab(tab_title, description, input_label, input_placeholder,
                    button_label, samples_heading, samples, output_hint, handler):
    """Build one task tab; must be called inside an open ``gr.Blocks``.

    The tab has an input column (textbox, action button, one button per
    sample that copies the sample into the textbox) and an output column.
    Both the action button and pressing Enter in the textbox run
    ``handler`` on the textbox content.

    Returns:
        ``(textbox, action_button, output_markdown, sample_buttons)``.
    """
    with gr.Tab(tab_title):
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label=input_label,
                    placeholder=input_placeholder,
                    lines=3,
                )
                action_btn = gr.Button(button_label, variant="primary")

                gr.Markdown(samples_heading)
                sample_buttons = []
                for addr in samples:
                    btn = gr.Button(addr, size="sm")
                    # The default argument binds the current sample; a plain
                    # closure would see only the last value of the loop.
                    btn.click(fn=lambda x=addr: x, outputs=text_input)
                    sample_buttons.append(btn)

            with gr.Column(scale=1):
                output = gr.Markdown(value=output_hint)

        action_btn.click(fn=handler, inputs=text_input, outputs=output)
        text_input.submit(fn=handler, inputs=text_input, outputs=output)

    return text_input, action_btn, output, sample_buttons


# Create Gradio interface
with gr.Blocks(title="Llama Address Intelligence", theme=gr.themes.Soft()) as demo:
    gr.Markdown(HEADER_MARKDOWN)

    extract_input, extract_btn, extract_output, extract_samples = _build_task_tab(
        tab_title="📋 Extract Components",
        description="⭐ **BEST PERFORMANCE** - Extract structured components from complete addresses",
        input_label="Enter Address",
        input_placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
        button_label="🔍 Extract Components",
        samples_heading="### Sample Addresses:",
        samples=sample_addresses,
        output_hint="Enter an address and click 'Extract Components' to see structured breakdown.",
        handler=extract_components_interface,
    )

    complete_input, complete_btn, complete_output, complete_samples = _build_task_tab(
        tab_title="✨ Complete Address",
        description="⚠️ **EXPERIMENTAL** - Complete partial addresses (limited training data - results may vary)",
        input_label="Enter Partial Address",
        input_placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex",
        button_label="🚀 Complete Address",
        samples_heading="### Sample Partial Addresses:",
        samples=partial_addresses,
        output_hint="Enter a partial address and click 'Complete Address' to see the AI completion.",
        handler=complete_address_interface,
    )

    (standardize_input, standardize_btn,
     standardize_output, standardize_samples) = _build_task_tab(
        tab_title="📐 Standardize Format",
        description="⚠️ **EXPERIMENTAL** - Convert informal addresses to standardized format (limited training data - results may vary)",
        input_label="Enter Informal Address",
        input_placeholder="e.g., c704 gayatri shivam thakur complex kandivali e 400101",
        button_label="📏 Standardize Format",
        samples_heading="### Sample Informal Addresses:",
        samples=informal_addresses,
        output_hint="Enter an informal address and click 'Standardize Format' to see the cleaned version.",
        handler=standardize_address_interface,
    )

    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown(MODEL_INFO_MARKDOWN)

    gr.Markdown(FOOTER_MARKDOWN)


if __name__ == "__main__":
    demo.launch()