parth parekh
Added a basic DistilBERT model; it should most probably work.
1e494e3
raw
history blame
1.79 kB
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from torch.nn.functional import softmax
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Application object; the interactive API docs are served at the root path ("/").
app = FastAPI(
    docs_url="/",
    version="1.0.0",
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
)
class ContactDetector:
    """Score text for the presence of contact information with a DistilBERT classifier.

    Wraps a two-label ``DistilBertForSequenceClassification`` model and exposes
    the softmax probability of label index 1.

    NOTE(review): with the default ``distilbert-base-uncased`` checkpoint the
    two-label classification head is newly initialized rather than trained, so
    the scores are not meaningful until a fine-tuned checkpoint is supplied via
    ``model_name``.
    """

    def __init__(self, model_name='distilbert-base-uncased'):
        """Load the tokenizer and model, move the model to the best device, enable eval mode.

        Args:
            model_name: Hugging Face model id or local path of a DistilBERT
                checkpoint. Defaults to the base ``distilbert-base-uncased``
                checkpoint, which is what this class always loaded before the
                parameter existed.
        """
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        # Prefer the GPU when one is visible to torch, otherwise run on CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Inference only: disables dropout so repeated calls give the same score.
        self.model.eval()

    def detect_contact_info(self, text):
        """Return the model's probability (a Python float) that ``text`` contains contact info.

        Text longer than the tokenizer's maximum length is truncated before scoring.
        Only the first sequence of the batch is read, so pass a single string.
        """
        inputs = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(self.device)
        # No gradients are needed for inference; saves memory and time.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        probabilities = softmax(logits, dim=1)
        # Assumes label index 1 means "contains contact info" -- confirm against
        # the labels used when the checkpoint was fine-tuned.
        return probabilities[0][1].item()

    def is_contact_info(self, text, threshold=0.5):
        """Return True when the contact probability for ``text`` is strictly above ``threshold``."""
        return self.detect_contact_info(text) > threshold
# Single shared detector, built once at import time so every request reuses
# the already-loaded tokenizer and model.
detector: ContactDetector = ContactDetector()
class TextInput(BaseModel):
    # Request body for POST /detect_contact: the text to classify.
    # (A class docstring is deliberately avoided: pydantic would copy it into
    # the OpenAPI schema description.)
    text: str
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Score the submitted text and report whether it looks like contact information.

    Returns the original text, the model's contact probability, and a boolean
    that is true when that probability is above 0.5. Any failure during
    inference is reported as HTTP 500.
    """
    try:
        # The PyTorch forward pass is blocking; run it in a worker thread so it
        # does not stall the event loop, and run it only once per request.
        probability = await run_in_threadpool(
            detector.detect_contact_info, input.text
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Same rule as ContactDetector.is_contact_info with its default threshold,
    # applied to the score already computed instead of re-running the model.
    is_contact = probability > 0.5
    return {
        "text": input.text,
        "contact_probability": probability,
        "is_contact_info": is_contact,
    }