from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.nn.functional import softmax

app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/"
)


class ContactDetector:
    def __init__(self):
        # Load the tokenizer and a 2-label BERT classifier from the shared cache directory.
        cache_dir = "/app/model_cache"
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=cache_dir)
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2, cache_dir=cache_dir)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def detect_contact_info(self, text):
        # Tokenize the input and run a single forward pass without gradient tracking.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = softmax(outputs.logits, dim=1)
        return probabilities[0][1].item()  # Probability of contact info

    def is_contact_info(self, text, threshold=0.45):
        return self.detect_contact_info(text) > threshold


detector = ContactDetector()


class TextInput(BaseModel):
    text: str


@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    try:
        probability = detector.detect_contact_info(input.text)
        is_contact = detector.is_contact_info(input.text)
        return {
            "text": input.text,
            "contact_probability": probability,
            "is_contact_info": is_contact
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
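

# Illustrative usage sketch (not part of the original file): one way to exercise the
# endpoint locally. Assumes uvicorn is installed (common alongside FastAPI) and that
# port 8000 is free; the actual deployment (e.g. a container using /app/model_cache)
# may start the server differently.
#
#   curl -X POST http://localhost:8000/detect_contact \
#        -H "Content-Type: application/json" \
#        -d '{"text": "email me at jane@example.com"}'
#
# The response matches the dict returned by detect_contact:
#   {"text": "...", "contact_probability": <float in [0, 1]>, "is_contact_info": <bool>}

if __name__ == "__main__":
    import uvicorn

    # Minimal local entry point; host and port here are assumptions for local testing only.
    uvicorn.run(app, host="0.0.0.0", port=8000)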