import logging

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.nn.functional import softmax

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/",
)


class ContactDetector:
    """Binary DistilBERT classifier that scores text for contact information.

    The tokenizer and model are loaded once at construction time and the
    model is put in evaluation mode on the GPU when one is available,
    otherwise on the CPU.

    NOTE(review): with the default arguments this loads the base
    ``distilbert-base-uncased`` checkpoint with ``num_labels=2``. No
    fine-tuned weights are loaded anywhere in this file, so the returned
    probabilities are presumably not meaningful until ``model_name`` points
    at a checkpoint trained for this task -- confirm.
    """

    # Decision boundary used when a caller does not supply its own threshold.
    DEFAULT_THRESHOLD = 0.5

    def __init__(
        self,
        model_name: str = "distilbert-base-uncased",
        cache_dir: str = "/app/model_cache",
    ) -> None:
        """Load the tokenizer and the sequence-classification model.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            cache_dir: Directory passed to ``from_pretrained`` as ``cache_dir``.
        """
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        self.model = DistilBertForSequenceClassification.from_pretrained(
            model_name, num_labels=2, cache_dir=cache_dir
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def detect_contact_info(self, text: str) -> float:
        """Return the softmax probability of label index 1 for ``text``.

        Args:
            text: The text to classify; it is truncated by the tokenizer.

        Returns:
            The probability assigned to class index 1 (treated by this
            service as "contains contact information") as a Python float.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, padding=True
        ).to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = softmax(outputs.logits, dim=1)
        # Row 0 is the single input text; column 1 is the positive label.
        return probabilities[0][1].item()

    def is_contact_info(self, text: str, threshold: float = DEFAULT_THRESHOLD) -> bool:
        """Return True when the score for ``text`` is strictly above ``threshold``."""
        return self.detect_contact_info(text) > threshold


# Module-level singleton so the model is loaded once per process, at import time.
detector = ContactDetector()


class TextInput(BaseModel):
    """Request body for the contact-detection endpoint."""

    text: str


@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Score the submitted text and report whether it exceeds the default threshold.

    Any exception raised while scoring is logged and reported to the client
    as an HTTP 500 whose detail is the exception message.
    """
    try:
        # Run the model once and derive the boolean from the same score;
        # calling is_contact_info() here would repeat the whole forward pass.
        probability = detector.detect_contact_info(input.text)
        is_contact = probability > ContactDetector.DEFAULT_THRESHOLD
        return {
            "text": input.text,
            "contact_probability": probability,
            "is_contact_info": is_contact,
        }
    except Exception as exc:
        logger.exception("Contact detection failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc