from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.nn.functional import softmax
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/"
)
class ContactDetector:
    def __init__(self):
        # Cache downloaded model files locally so restarts do not re-download them.
        cache_dir = "/app/model_cache"
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=cache_dir)
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2, cache_dir=cache_dir)
        # Use a GPU when available, otherwise fall back to CPU, and put the model in eval mode.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def detect_contact_info(self, text):
        # Tokenize the input and move the tensors to the model's device.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Softmax over the two classes; index 1 is the "contains contact info" class.
        probabilities = softmax(outputs.logits, dim=1)
        return probabilities[0][1].item()  # Probability of contact info

    def is_contact_info(self, text, threshold=0.45):
        # Binary decision against a fixed probability threshold.
        return self.detect_contact_info(text) > threshold


detector = ContactDetector()


class TextInput(BaseModel):
    text: str


@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    try:
        probability = detector.detect_contact_info(input.text)
        is_contact = detector.is_contact_info(input.text)
        return {
            "text": input.text,
            "contact_probability": probability,
            "is_contact_info": is_contact
        }
    except Exception as e:
        # Surface any tokenizer/model failure as an HTTP 500 with the error message.
        raise HTTPException(status_code=500, detail=str(e))
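# Usage sketch (not part of the original file): one way to run and exercise the
# service locally, assuming uvicorn is installed. The host/port below and the
# sample request body are illustrative choices, not values from the original
# deployment.
#
#   curl -X POST http://localhost:8000/detect_contact \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Reach me at jane.doe@example.com"}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)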