parth parekh
le lode
645ea59
raw
history blame
2.99 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from torch.nn.functional import softmax
import re
# Application object; the route handlers in this module register themselves on it.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    # Serve the interactive API docs at the site root rather than FastAPI's
    # default docs path.
    docs_url="/",
)
class ContactDetector:
    """Score text for contact information with a RoBERTa sequence classifier.

    The tokenizer and a two-label classification model are loaded once at
    construction time, moved to CUDA when available (CPU otherwise), and
    switched to eval mode.
    """

    def __init__(self, model_name='roberta-base', cache_dir="/app/model_cache"):
        """Load the tokenizer and model.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
                Defaults to ``'roberta-base'`` (the original hard-coded value).
            cache_dir: Directory where downloaded weights are cached.
                Defaults to ``"/app/model_cache"`` (the original value).
        """
        # NOTE(review): 'roberta-base' looks like the plain pretrained
        # checkpoint, in which case the two-label classification head is
        # freshly initialised and its scores are not meaningful. Presumably a
        # fine-tuned checkpoint should be passed via ``model_name`` -- confirm.
        self.tokenizer = RobertaTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        self.model = RobertaForSequenceClassification.from_pretrained(
            model_name, num_labels=2, cache_dir=cache_dir
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Inference only: eval mode disables dropout and similar training behaviour.
        self.model.eval()

    def detect_contact_info(self, text):
        """Return the model's probability that ``text`` contains contact info.

        Args:
            text: Input passed to the tokenizer (truncated to the tokenizer's
                limit). Only the first row of the resulting batch is scored.

        Returns:
            float: Softmax probability of label index 1.
        """
        inputs = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(self.device)
        # No gradients are needed for inference.
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = softmax(outputs.logits, dim=1)
        # Row 0 = first (only) input, column 1 = "contact info" label.
        return probabilities[0][1].item()

    def is_contact_info(self, text, threshold=0.45):
        """Return True when the contact probability is strictly above ``threshold``."""
        return self.detect_contact_info(text) > threshold
# Module-level detector shared by all requests; constructing it loads the
# tokenizer and model once, at import time.
detector = ContactDetector()
# Request body schema for the detection endpoint.
class TextInput(BaseModel):
    # Raw text to scan for contact information.
    text: str
# Compiled once at import time so requests do not rebuild the pattern list.
# All patterns are matched case-insensitively.
_CONTACT_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        # Email. The TLD class is letters only; the previous ``[A-Z|a-z]``
        # also accepted a literal ``|``.
        r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        # Phone number: 10 digits, optionally separated by '-' or '.'.
        r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b',
        # ZIP code: 5 digits with optional 4-digit extension.
        r'\b\d{5}(?:[-\s]\d{4})?\b',
        # Street address: house number, name, then a street-type suffix.
        # The ``\b`` before the suffix group makes the suffix a whole word,
        # so ordinary words ending in "st", "ct", "rd", ... ("5 best",
        # "25 and lost") are no longer reported as addresses.
        r'\b\d+\s+[\w\s]+\b'
        r'(?:street|st|avenue|ave|road|rd|highway|hwy|square|sq|trail|trl|drive|dr'
        r'|court|ct|park|parkway|pkwy|circle|cir|boulevard|blvd)\b'
        r'\s*(?:[a-z]+\s*\d{1,3})?'
        r'(?:,\s*(?:apt|bldg|dept|fl|hngr|lot|pier|rm|ste|unit|#)\s*[a-z0-9-]+)?'
        r'(?:,\s*[a-z]+\s*[a-z]{2}\s*\d{5}(?:-\d{4})?)?',
        # Website URL with an explicit http/https scheme.
        r'(?:http|https)://(?:www\.)?[a-zA-Z0-9-]+\.[a-zA-Z]{2,}(?:/[^\s]*)?',
    )
)


def check_regex_patterns(text):
    """Return True if ``text`` matches any known contact-information pattern.

    Checks, in order: email address, 10-digit phone number, ZIP code,
    street address, and http(s) URL.

    Args:
        text: The string to scan.

    Returns:
        bool: True on the first matching pattern, False when none match
        (including for an empty string).
    """
    return any(pattern.search(text) for pattern in _CONTACT_PATTERNS)
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Detect contact information in the submitted text.

    The cheap regex patterns are tried first; only when none match is the
    model consulted.

    Args:
        input: Request body carrying the text to analyse.

    Returns:
        dict: ``text`` (echoed input), ``contact_probability`` (1.0 for a
        regex hit, otherwise the model score), ``is_contact_info`` (bool) and
        ``method`` (``"regex"`` or ``"model"``).

    Raises:
        HTTPException: Status 500 carrying the error message when any step fails.
    """
    # Same value as the default ``threshold`` of ContactDetector.is_contact_info;
    # keep the two in sync. Comparing here avoids running the model a second
    # time just to apply the threshold.
    model_threshold = 0.45
    try:
        # First, check with regex patterns
        if check_regex_patterns(input.text):
            return {
                "text": input.text,
                "contact_probability": 1.0,
                "is_contact_info": True,
                "method": "regex",
            }
        # If no regex patterns match, use the model (single forward pass).
        # NOTE(review): this is a synchronous, potentially slow call inside an
        # async handler; consider offloading it to a worker thread -- confirm
        # against expected load.
        probability = detector.detect_contact_info(input.text)
        return {
            "text": input.text,
            "contact_probability": probability,
            "is_contact_info": probability > model_threshold,
            "method": "model",
        }
    except Exception as e:
        # Chain the original exception so the real traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e