|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import warnings |
|
# Install a process-wide "ignore" filter for every Warning subclass,
# presumably to keep model-loading output quiet. This also hides warnings
# that can be useful while debugging.
warnings.simplefilter("ignore")
|
|
|
class LlamaAddressCompletion:
    """Prompt-based address tasks on top of a fine-tuned Llama 3.2-1B model.

    ``__init__`` loads the tokenizer and model immediately. Each public task
    method validates its input, wraps it in a Llama-3 chat prompt, samples a
    reply, and returns the decoded text. Generation failures are returned as
    strings rather than raised, so the Gradio callbacks can display them.
    """

    def __init__(self):
        self.model_name = "shiprocket-ai/open-llama-1b-address-completion"
        self.model = None
        self.tokenizer = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the Llama model and tokenizer.

        Uses float16 weights with ``device_map="auto"`` when CUDA is available,
        otherwise float32 weights moved explicitly to the CPU device.

        Raises:
            Exception: whatever the Hugging Face loaders raise is re-raised
                unchanged after a short message is printed.
        """
        use_cuda = torch.cuda.is_available()
        try:
            print("Loading Llama 3.2-1B Address Completion model...")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

            # NOTE(review): trust_remote_code=True executes Python code shipped
            # in the model repository. Llama is a built-in transformers
            # architecture, so this is presumably unnecessary; confirm and drop.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                trust_remote_code=True,
            )

            if not use_cuda:
                # No device_map on CPU, so placement has to be done by hand.
                self.model = self.model.to(self.device)

            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    @staticmethod
    def _build_prompt(instruction):
        """Wrap *instruction* as a single user turn in the Llama-3 chat format."""
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )

    def _generate(self, instruction, max_new_tokens, temperature):
        """Sample a reply to *instruction* and return only the new text, stripped.

        Args:
            instruction: The user-turn text to send to the model.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature passed to ``generate``.
        """
        prompt = self._build_prompt(instruction)

        # NOTE(review): the prompt already begins with <|begin_of_text|>; if the
        # tokenizer also prepends its BOS token the model sees it twice. Check
        # how the model was fine-tuned before passing add_special_tokens=False.
        # NOTE(review): with the tokenizer's default right-side truncation, an
        # input longer than 512 tokens would lose the trailing assistant header.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        # Follow wherever the weights were actually placed (device_map may
        # have decided this), rather than assuming self.device.
        device = next(self.model.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.05,
            )

        # generate() echoes the prompt tokens; decode only the continuation.
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def extract_address_components(self, address, max_new_tokens=150):
        """Extract structured components from a complete address.

        Args:
            address: Free-text address. ``None``, empty or whitespace-only
                input is rejected without running the model.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's reply, a validation message, or an
            ``"Error processing address: ..."`` string if generation fails.
        """
        if not address or not address.strip():
            return "Please provide an address to extract components from."

        try:
            return self._generate(
                f"Extract address components from: {address}",
                max_new_tokens,
                temperature=0.1,
            )
        except Exception as e:
            return f"Error processing address: {e}"

    def complete_partial_address(self, partial_address, max_new_tokens=100):
        """Complete a partial address.

        Args:
            partial_address: Address prefix. ``None``, empty or whitespace-only
                input is rejected without running the model.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's reply, a validation message, or an
            ``"Error completing address: ..."`` string if generation fails.
        """
        if not partial_address or not partial_address.strip():
            return "Please provide a partial address to complete."

        try:
            return self._generate(
                f"Complete this partial address: {partial_address}",
                max_new_tokens,
                temperature=0.2,
            )
        except Exception as e:
            return f"Error completing address: {e}"

    def standardize_address(self, address, max_new_tokens=150):
        """Rewrite an address into a standardized format.

        Args:
            address: Free-text address. ``None``, empty or whitespace-only
                input is rejected without running the model.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The model's reply, a validation message, or an
            ``"Error standardizing address: ..."`` string if generation fails.
        """
        if not address or not address.strip():
            return "Please provide an address to standardize."

        try:
            return self._generate(
                f"Standardize this address into proper format: {address}",
                max_new_tokens,
                temperature=0.1,
            )
        except Exception as e:
            return f"Error standardizing address: {e}"
|
|
|
|
|
# Build the shared model wrapper once at import time. If loading fails, the
# UI handlers see llama_system as None and show an error banner instead.
print("Initializing Llama Address Completion system...")
try:
    llama_system = LlamaAddressCompletion()
    print("System ready!")
except Exception as init_error:
    print(f"Failed to initialize system: {init_error}")
    llama_system = None
|
|
|
def extract_components_interface(address_text):
    """Gradio handler for the "Extract Components" tab.

    Args:
        address_text: Raw text from the address textbox.

    Returns:
        Markdown echoing the input followed by the extracted components, or an
        error banner when the module-level ``llama_system`` is None.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.extract_address_components(address_text)
    return f"**Input:** {address_text}\n\n**Extracted Components:**\n{result}"
|
|
|
def complete_address_interface(partial_address):
    """Gradio handler for the "Complete Address" tab.

    Args:
        partial_address: Raw text from the partial-address textbox.

    Returns:
        Markdown echoing the input, the model's completion and an
        experimental-quality note, or an error banner when the module-level
        ``llama_system`` is None.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.complete_partial_address(partial_address)
    return f"**Partial Address:** {partial_address}\n\n**Completed Address:**\n{result}\n\n*⚠️ Note: This feature has limited training data and results may vary in quality.*"
|
|
|
def standardize_address_interface(address_text):
    """Gradio handler for the "Standardize Format" tab.

    Args:
        address_text: Raw text from the informal-address textbox.

    Returns:
        Markdown echoing the input, the standardized address and an
        experimental-quality note, or an error banner when the module-level
        ``llama_system`` is None.
    """
    if llama_system is None:
        return "❌ Model not loaded. Please check the logs."

    result = llama_system.standardize_address(address_text)
    return f"**Original:** {address_text}\n\n**Standardized:**\n{result}\n\n*⚠️ Note: This feature has limited training data and results may vary in quality.*"
|
|
|
|
|
# One-click example inputs shown under each task tab.

# Complete addresses for the "Extract Components" tab.
sample_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
    "Villa 141, Geown Oasis, V Kallahalli, Off Sarjapur, Bengaluru, Karnataka, 562125",
    "E401 Supertech Icon Indrapam 201301 UP",
    "Shop No 123, Sunshine Apartments, Andheri West, Mumbai, 400058",
    "Flat 201, MG Road, Bangalore, Karnataka, 560001",
]

# Leading fragments of the complete samples above, for the "Complete Address" tab.
partial_addresses = [
    "C-704, Gayatri Shivam, Thakur Complex",
    "Villa 141, Geown Oasis, V Kallahalli",
    "E401 Supertech Icon",
    "Shop No 123, Sunshine Apartments",
    "Flat 201, MG Road, Bangalore",
]

# Lowercase, abbreviated, separator-free variants for the "Standardize Format" tab.
informal_addresses = [
    "c704 gayatri shivam thakur complex kandivali e 400101",
    "villa141 geown oasis vkallahalli off sarjapur blr kar 562125",
    "e401 supertech icon indrapam up 201301",
    "shop123 sunshine apts andheri w mumbai 400058",
]
|
|
|
|
|
# Gradio UI: three task tabs (extract / complete / standardize) plus a static
# model-information tab. In each task tab, both the button click and pressing
# Enter in the textbox call the same interface function, and each sample
# button copies its text into that tab's input box.
with gr.Blocks(title="Llama Address Intelligence", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🦙 Llama 3.2-1B Address Intelligence

    Powered by a fine-tuned Llama 3.2-1B model specialized for Indian address processing.

    **⭐ Best Performance**: Entity extraction from complete addresses

    **⚠️ Limited Performance**: Address completion and standardization (limited training data)

    **Model:** [shiprocket-ai/open-llama-1b-address-completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion)
    """)

    with gr.Tab("🔍 Extract Components"):
        gr.Markdown("⭐ **BEST PERFORMANCE** - Extract structured components from complete addresses")
        with gr.Row():
            with gr.Column(scale=1):
                extract_input = gr.Textbox(
                    label="Enter Address",
                    placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex, Kandivali East, 400101",
                    lines=3,
                )
                extract_btn = gr.Button("🔍 Extract Components", variant="primary")

                gr.Markdown("### Sample Addresses:")
                extract_samples = []
                for addr in sample_addresses:
                    btn = gr.Button(addr, size="sm")
                    # Bind addr as a default argument so each button returns its
                    # own sample instead of the loop's last value.
                    btn.click(fn=lambda x=addr: x, outputs=extract_input)
                    extract_samples.append(btn)

            with gr.Column(scale=1):
                extract_output = gr.Markdown(
                    value="Enter an address and click 'Extract Components' to see structured breakdown."
                )

        extract_btn.click(
            fn=extract_components_interface,
            inputs=extract_input,
            outputs=extract_output,
        )
        extract_input.submit(
            fn=extract_components_interface,
            inputs=extract_input,
            outputs=extract_output,
        )

    with gr.Tab("✨ Complete Address"):
        gr.Markdown("⚠️ **EXPERIMENTAL** - Complete partial addresses (limited training data - results may vary)")
        with gr.Row():
            with gr.Column(scale=1):
                complete_input = gr.Textbox(
                    label="Enter Partial Address",
                    placeholder="e.g., C-704, Gayatri Shivam, Thakur Complex",
                    lines=3,
                )
                complete_btn = gr.Button("🚀 Complete Address", variant="primary")

                gr.Markdown("### Sample Partial Addresses:")
                complete_samples = []
                for addr in partial_addresses:
                    btn = gr.Button(addr, size="sm")
                    btn.click(fn=lambda x=addr: x, outputs=complete_input)
                    complete_samples.append(btn)

            with gr.Column(scale=1):
                complete_output = gr.Markdown(
                    value="Enter a partial address and click 'Complete Address' to see the AI completion."
                )

        complete_btn.click(
            fn=complete_address_interface,
            inputs=complete_input,
            outputs=complete_output,
        )
        complete_input.submit(
            fn=complete_address_interface,
            inputs=complete_input,
            outputs=complete_output,
        )

    with gr.Tab("📝 Standardize Format"):
        gr.Markdown("⚠️ **EXPERIMENTAL** - Convert informal addresses to standardized format (limited training data - results may vary)")
        with gr.Row():
            with gr.Column(scale=1):
                standardize_input = gr.Textbox(
                    label="Enter Informal Address",
                    placeholder="e.g., c704 gayatri shivam thakur complex kandivali e 400101",
                    lines=3,
                )
                standardize_btn = gr.Button("📝 Standardize Format", variant="primary")

                gr.Markdown("### Sample Informal Addresses:")
                standardize_samples = []
                for addr in informal_addresses:
                    btn = gr.Button(addr, size="sm")
                    btn.click(fn=lambda x=addr: x, outputs=standardize_input)
                    standardize_samples.append(btn)

            with gr.Column(scale=1):
                standardize_output = gr.Markdown(
                    value="Enter an informal address and click 'Standardize Format' to see the cleaned version."
                )

        standardize_btn.click(
            fn=standardize_address_interface,
            inputs=standardize_input,
            outputs=standardize_output,
        )
        standardize_input.submit(
            fn=standardize_address_interface,
            inputs=standardize_input,
            outputs=standardize_output,
        )

    with gr.Tab("ℹ️ Model Information"):
        gr.Markdown("""
        ## 🦙 About Llama 3.2-1B Address Completion

        ### Model Specifications
        - **Base Model**: meta-llama/Llama-3.2-1B-Instruct
        - **Parameters**: 1.24B parameters
        - **Model Size**: ~2.47GB
        - **Architecture**: Causal Language Model (Autoregressive)
        - **Max Context**: 131,072 tokens
        - **Precision**: FP16 for GPU, FP32 for CPU

        ### Key Features
        - **Lightweight**: Only 1B parameters for fast inference
        - **Specialized**: Fine-tuned specifically for Indian addresses
        - **Versatile**: Handles extraction, completion, and standardization
        - **Efficient**: Optimized for real-time applications
        - **Context-Aware**: Understands relationships between address components

        ### Supported Address Components
        - **Building Names**: Apartments, complexes, towers, malls
        - **Localities**: Areas, neighborhoods, sectors
        - **Pincodes**: 6-digit Indian postal codes
        - **Cities**: Major and minor Indian cities
        - **States**: All Indian states and union territories
        - **Sub-localities**: Sectors, phases, blocks
        - **Road Names**: Streets, lanes, main roads
        - **Landmarks**: Notable reference points

        ### Performance Notes
        - **⭐ Entity Extraction**: Excellent performance - primary use case
        - **⚠️ Address Completion**: Limited training data - experimental feature
        - **⚠️ Format Standardization**: Limited training data - experimental feature

        **Recommendation**: Use this model primarily for address component extraction.

        ### Use Cases
        - **E-commerce**: Auto-complete checkout addresses
        - **Forms**: Intelligent address suggestions
        - **Data Cleaning**: Standardize legacy address databases
        - **Mobile Apps**: On-device address processing
        - **APIs**: Real-time address validation services

        ### Performance Tips
        - Use lower temperatures (0.1-0.3) for factual outputs
        - Keep prompts under 512 tokens for optimal speed
        - Process in batches for high-throughput scenarios
        - Works best with Llama chat format prompts
        """)

    gr.Markdown("""
    ---
    **Powered by:** [Llama 3.2-1B Address Completion](https://huggingface.co/shiprocket-ai/open-llama-1b-address-completion)
    **License:** Apache 2.0
    **Developed by:** Shiprocket AI Team

    This model demonstrates the power of lightweight LLMs for specialized address intelligence tasks.
    """)
|
|
|
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()