# Source: Hugging Face Space file app.py, uploaded by sajalmadan0909
# Commit 6952bc9 (verified): "Update app.py"
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings("ignore")
class LlamaAddressCompletion:
    """Prompt-based address tasks on top of a fine-tuned causal language model.

    On construction the tokenizer and model named by ``model_name`` are loaded
    via ``transformers`` and the model is switched to eval mode.  Three public
    tasks are offered, each of which wraps the user's text in a chat-style
    prompt and returns only the newly generated text:

    * ``extract_address_components``
    * ``complete_partial_address``
    * ``standardize_address``

    The task methods never raise for generation failures; they return a
    human-readable error string instead so the UI can display it.
    """

    def __init__(self):
        self.model_name = "shiprocket-ai/open-llama-1b-address-completion"
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the tokenizer and model, then put the model in eval mode.

        Uses float16 with ``device_map="auto"`` when CUDA is available,
        otherwise float32 with an explicit move to ``self.device``.

        Raises:
            Exception: whatever the loading calls raised; it is logged with
                ``print`` and re-raised unchanged.
        """
        try:
            print("Loading Llama 3.2-1B Address Completion model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            use_cuda = torch.cuda.is_available()
            # Load model with appropriate settings for the space.
            # NOTE(security): trust_remote_code=True lets the model repository
            # run its own Python code at load time; keep it only while the
            # repository is trusted.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )
            if not use_cuda:
                # With device_map=None nothing has placed the model yet.
                self.model = self.model.to(self.device)
            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as exc:
            print(f"❌ Error loading model: {exc}")
            # Bare raise keeps the original traceback intact.
            raise

    @staticmethod
    def _build_prompt(instruction):
        """Wrap *instruction* as a single user turn followed by an open assistant turn.

        Each ``<|end_header_id|>`` is followed by a blank line, as the Llama 3
        chat format specifies.
        """
        # NOTE(review): the prompt spells out <|begin_of_text|> itself; if the
        # tokenizer also prepends a BOS token the sequence starts with two of
        # them -- confirm against how the model was fine-tuned.
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{instruction}<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    def _generate(self, instruction, max_new_tokens, temperature):
        """Run one sampled generation for *instruction* and return the new text.

        The prompt is truncated to 512 tokens, moved to the device of the
        model's first parameter, and decoded with special tokens skipped.
        Only tokens produced after the prompt are returned, stripped of
        surrounding whitespace.
        """
        prompt = self._build_prompt(instruction)
        encoded = self.tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        )
        # Move inputs to the same device as the model.
        device = next(self.model.parameters()).device
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.05,
            )
        # The output sequence starts with the prompt; slice it off.
        prompt_length = encoded["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def _run_task(self, text, *, instruction_prefix, empty_message,
                  error_prefix, max_new_tokens, temperature):
        """Shared guard / generate / error-to-string flow for the public tasks.

        Returns *empty_message* when *text* is ``None``, empty or whitespace
        only; otherwise the generated text, or ``"<error_prefix>: <error>"``
        if anything in the generation pipeline raises.
        """
        if not text or not text.strip():
            return empty_message
        try:
            return self._generate(
                f"{instruction_prefix}{text}", max_new_tokens, temperature
            )
        except Exception as exc:
            # Deliberate best-effort: the UI shows the message instead of crashing.
            return f"{error_prefix}: {exc}"

    def extract_address_components(self, address, max_new_tokens=150):
        """Ask the model to extract the components of *address*.

        Args:
            address: Address text; ``None`` or blank input yields a prompt
                to provide an address.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a message describing the problem.
        """
        return self._run_task(
            address,
            instruction_prefix="Extract address components from: ",
            empty_message="Please provide an address to extract components from.",
            error_prefix="Error processing address",
            max_new_tokens=max_new_tokens,
            temperature=0.1,
        )

    def complete_partial_address(self, partial_address, max_new_tokens=100):
        """Ask the model to complete *partial_address*.

        Args:
            partial_address: Partial address text; ``None`` or blank input
                yields a prompt to provide one.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a message describing the problem.
        """
        return self._run_task(
            partial_address,
            instruction_prefix="Complete this partial address: ",
            empty_message="Please provide a partial address to complete.",
            error_prefix="Error completing address",
            max_new_tokens=max_new_tokens,
            temperature=0.2,
        )

    def standardize_address(self, address, max_new_tokens=150):
        """Ask the model to rewrite *address* in a standardized format.

        Args:
            address: Address text; ``None`` or blank input yields a prompt
                to provide an address.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's response, or a message describing the problem.
        """
        return self._run_task(
            address,
            instruction_prefix="Standardize this address into proper format: ",
            empty_message="Please provide an address to standardize.",
            error_prefix="Error standardizing address",
            max_new_tokens=max_new_tokens,
            temperature=0.1,
        )
# Build the shared model wrapper once, at import time.  A failure here is
# reported and leaves ``llama_system`` as None so the rest of the module can
# still be imported and the UI can show a "model not loaded" message.
print("Initializing Llama Address Completion system...")
try:
    llama_system = LlamaAddressCompletion()
    print("System ready!")
except Exception as init_error:
    print(f"Failed to initialize system: {init_error}")
    llama_system = None
def extract_components_interface(address_text):
    """UI callback: extract components from *address_text* and format them.

    Returns a Markdown-formatted string echoing the input followed by the
    extraction result, or a fixed error line when the module-level
    ``llama_system`` is None.
    """
    if llama_system is not None:
        components = llama_system.extract_address_components(address_text)
        return (
            f"**Input:** {address_text}\n\n"
            f"**Extracted Components:**\n{components}"
        )
    return "❌ Model not loaded. Please check the logs."
def complete_address_interface(partial_address):
    """UI callback: complete *partial_address* and format the outcome.

    Returns a Markdown-formatted string echoing the partial address, the
    completion, and a caveat about result quality, or a fixed error line
    when the module-level ``llama_system`` is None.
    """
    if llama_system is not None:
        completed = llama_system.complete_partial_address(partial_address)
        return (
            f"**Partial Address:** {partial_address}\n\n"
            f"**Completed Address:**\n{completed}\n\n"
            "*⚠️ Note: This feature has limited training data and results may vary in quality.*"
        )
    return "❌ Model not loaded. Please check the logs."
def standardize_address_interface(address_text):
    """UI callback: standardize *address_text* and format the outcome.

    Returns a Markdown-formatted string echoing the original text, the
    standardized version, and a caveat about result quality, or a fixed
    error line when the module-level ``llama_system`` is None.
    """
    if llama_system is not None:
        standardized = llama_system.standardize_address(address_text)
        return (
            f"**Original:** {address_text}\n\n"
            f"**Standardized:**\n{standardized}\n\n"
            "*⚠️ Note: This feature has limited training data and results may vary in quality.*"
        )
    return "❌ Model not loaded. Please check the logs."
# Sample data: each list feeds the one-click example buttons of one tab.
sample_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
    "Villa 141, Geown Oasis, V Kallahalli, Off Sarjapur, Bengaluru, Karnataka, 562125",
    "E401 Supertech Icon Indrapam 201301 UP",
    "Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
    "Flat 201, MG Road, Bangalore, Karnataka, 560001",
]
partial_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex",
    "Villa 141, Geown Oasis, V Kallahalli",
    "E401 Supertech Icon",
    "Shop No 123, Sunshine Apartments",
    "Flat 201, MG Road, Bangalore",
]
informal_addresses = [
    "c704 gayatri shivam thakur complex kandivali e 400101",
    "villa141 geown oasis vkallahalli off sarjapur blr kar 562125",
    "e401 supertech icon indrapam up 201301",
    "shop123 sunshine apts andheri w mumbai 400058",
]

# Markdown shown above the tabs.  Paragraphs are separated by blank lines so
# they render as separate paragraphs rather than one run-on block.
_HEADER_MD = """
# 🦙 Llama 3.2-1B Address Intelligence

Powered by a fine-tuned Llama 3.2-1B model specialized for Indian address processing.

**⭐ Best Performance**: Entity extraction from complete addresses

**⚠️ Limited Performance**: Address completion and standardization (limited training data)

**Model:** [shiprocket-ai/open-llama-1b-address-completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion)
"""

# Static content of the "Model Information" tab.
_MODEL_INFO_MD = """
## 🦙 About Llama 3.2-1B Address Completion

### Model Specifications
- **Base Model**: meta-llama/Llama-3.2-1B-Instruct
- **Parameters**: 1.24B parameters
- **Model Size**: ~2.47GB
- **Architecture**: Causal Language Model (Autoregressive)
- **Max Context**: 131,072 tokens
- **Precision**: FP16 for GPU, FP32 for CPU

### Key Features
- **Lightweight**: Only 1B parameters for fast inference
- **Specialized**: Fine-tuned specifically for Indian addresses
- **Versatile**: Handles extraction, completion, and standardization
- **Efficient**: Optimized for real-time applications
- **Context-Aware**: Understands relationships between address components

### Supported Address Components
- **Building Names**: Apartments, complexes, towers, malls
- **Localities**: Areas, neighborhoods, sectors
- **Pincodes**: 6-digit Indian postal codes
- **Cities**: Major and minor Indian cities
- **States**: All Indian states and union territories
- **Sub-localities**: Sectors, phases, blocks
- **Road Names**: Streets, lanes, main roads
- **Landmarks**: Notable reference points

### Performance Notes
- **⭐ Entity Extraction**: Excellent performance - primary use case
- **⚠️ Address Completion**: Limited training data - experimental feature
- **⚠️ Format Standardization**: Limited training data - experimental feature

**Recommendation**: Use this model primarily for address component extraction.

### Use Cases
- **E-commerce**: Auto-complete checkout addresses
- **Forms**: Intelligent address suggestions
- **Data Cleaning**: Standardize legacy address databases
- **Mobile Apps**: On-device address processing
- **APIs**: Real-time address validation services

### Performance Tips
- Use lower temperatures (0.1-0.3) for factual outputs
- Keep prompts under 512 tokens for optimal speed
- Process in batches for high-throughput scenarios
- Works best with Llama chat format prompts
"""

# Markdown shown below the tabs.
_FOOTER_MD = """
---

**Powered by:** [Llama 3.2-1B Address Completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion) |
**License:** Apache 2.0 |
**Developed by:** Shiprocket AI Team

This model demonstrates the power of lightweight LLMs for specialized address intelligence tasks.
"""


def _build_task_tab(*, intro, input_label, placeholder, button_label,
                    samples_heading, samples, output_hint, handler):
    """Lay out one task tab and wire its events; call inside ``with gr.Tab(...)``.

    The tab has an intro line and two equal columns: on the left a text box,
    a primary action button and one small button per entry of *samples*; on
    the right a Markdown output area initialised with *output_hint*.
    Clicking the action button or submitting the text box calls *handler*
    with the text box value and writes the result to the output area.

    Returns:
        ``(text_input, action_button, sample_buttons, output)``.
    """
    gr.Markdown(intro)
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label=input_label,
                placeholder=placeholder,
                lines=3,
            )
            action_button = gr.Button(button_label, variant="primary")
            gr.Markdown(samples_heading)
            sample_buttons = []
            for addr in samples:
                sample_button = gr.Button(addr, size="sm")
                # Bind addr as a default argument: a plain closure would see
                # only the last address once the loop has finished.
                sample_button.click(fn=lambda x=addr: x, outputs=text_input)
                sample_buttons.append(sample_button)
        with gr.Column(scale=1):
            output = gr.Markdown(value=output_hint)
    action_button.click(fn=handler, inputs=text_input, outputs=output)
    text_input.submit(fn=handler, inputs=text_input, outputs=output)
    return text_input, action_button, sample_buttons, output


# Create Gradio interface
with gr.Blocks(title="Llama Address Intelligence", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Tab("📋 Extract Components"):
        extract_input, extract_btn, extract_samples, extract_output = _build_task_tab(
            intro="⭐ **BEST PERFORMANCE** - Extract structured components from complete addresses",
            input_label="Enter Address",
            placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
            button_label="🔍 Extract Components",
            samples_heading="### Sample Addresses:",
            samples=sample_addresses,
            output_hint="Enter an address and click 'Extract Components' to see structured breakdown.",
            handler=extract_components_interface,
        )

    with gr.Tab("✨ Complete Address"):
        complete_input, complete_btn, complete_samples, complete_output = _build_task_tab(
            intro="⚠️ **EXPERIMENTAL** - Complete partial addresses (limited training data - results may vary)",
            input_label="Enter Partial Address",
            placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex",
            button_label="🚀 Complete Address",
            samples_heading="### Sample Partial Addresses:",
            samples=partial_addresses,
            output_hint="Enter a partial address and click 'Complete Address' to see the AI completion.",
            handler=complete_address_interface,
        )

    with gr.Tab("📝 Standardize Format"):
        (standardize_input, standardize_btn,
         standardize_samples, standardize_output) = _build_task_tab(
            intro="⚠️ **EXPERIMENTAL** - Convert informal addresses to standardized format (limited training data - results may vary)",
            input_label="Enter Informal Address",
            placeholder="e.g., c704 gayatri shivam thakur complex kandivali e 400101",
            button_label="📝 Standardize Format",
            samples_heading="### Sample Informal Addresses:",
            samples=informal_addresses,
            output_hint="Enter an informal address and click 'Standardize Format' to see the cleaned version.",
            handler=standardize_address_interface,
        )

    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown(_MODEL_INFO_MD)

    gr.Markdown(_FOOTER_MD)

if __name__ == "__main__":
    demo.launch()