import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from torch.nn.functional import softmax
from transformers import BertTokenizer, BertForSequenceClassification
|
|
|
# ASGI application object. docs_url="/" serves the interactive API docs at the
# site root rather than FastAPI's default "/docs" path.
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/",
)
|
|
|
class ContactDetector:
    """Score text for contact information with a BERT sequence classifier.

    The tokenizer and a two-label ``BertForSequenceClassification`` model are
    loaded once in the constructor, moved to CUDA when it is available (CPU
    otherwise), and switched to eval mode.

    NOTE(review): the default ``bert-base-uncased`` checkpoint is a base
    model; its two-label classification head is presumably not trained for
    this task, so scores are only meaningful once ``model_name`` points at a
    fine-tuned checkpoint -- confirm which checkpoint is intended.
    """

    def __init__(
        self,
        model_name: str = 'bert-base-uncased',
        cache_dir: str = "/app/model_cache",
    ) -> None:
        """Load the tokenizer and model and place the model on a device.

        Args:
            model_name: Checkpoint name or path handed to ``from_pretrained``
                for both the tokenizer and the model.
            cache_dir: Directory passed to ``from_pretrained`` as
                ``cache_dir`` for both the tokenizer and the model.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=2, cache_dir=cache_dir
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Inference only: put the module in evaluation mode.
        self.model.eval()

    def detect_contact_info(self, text: str) -> float:
        """Return the model's probability for label 1 on ``text``.

        The text is tokenized (with truncation and padding enabled), run
        through the model without gradient tracking, and the logits are
        soft-maxed over the label dimension.

        Args:
            text: Text to score.

        Returns:
            The softmax probability of label index 1 for the first (only)
            sequence in the batch, as a Python float.
        """
        inputs = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = softmax(outputs.logits, dim=1)
        # Row 0 is the single input sequence; column 1 is the positive label.
        return probabilities[0][1].item()

    def is_contact_info(self, text: str, threshold: float = 0.45) -> bool:
        """Return True when the score for ``text`` is strictly above ``threshold``.

        Args:
            text: Text to score.
            threshold: Cut-off applied to ``detect_contact_info(text)``.
        """
        return self.detect_contact_info(text) > threshold
|
|
|
# Shared detector instance, constructed at import time so the tokenizer and
# model are loaded once rather than on every request.
detector = ContactDetector()
|
|
|
class TextInput(BaseModel):
    """Request body accepted by the ``/detect_contact`` endpoint."""

    # Text to be scored for contact information.
    text: str
|
|
|
# Decision cut-off applied to the model score. Must stay equal to the default
# ``threshold`` of ``ContactDetector.is_contact_info``; it is applied here so
# the model is run once per request instead of twice.
_CONTACT_THRESHOLD = 0.45


@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Score the submitted text and report whether it looks like contact info.

    Args:
        input: Request body carrying the text to score.

    Returns:
        A dict with the original ``text``, the model's ``contact_probability``
        and the boolean ``is_contact_info`` (probability strictly above the
        cut-off).

    Raises:
        HTTPException: With status 500 and the error text as ``detail`` when
            scoring fails.
    """
    try:
        # Model inference is blocking work; run it in the worker thread pool
        # so the event loop keeps serving other requests meanwhile.
        probability = await run_in_threadpool(
            detector.detect_contact_info, input.text
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "text": input.text,
        "contact_probability": probability,
        "is_contact_info": probability > _CONTACT_THRESHOLD,
    }
|
|