parth parekh committed on
Commit
1e494e3
·
1 Parent(s): 2303961

Added a basic DistilBERT classifier; it should most probably work

Browse files
Files changed (3) hide show
  1. Dockerfile +20 -0
  2. app.py +48 -0
  3. requirements.txt +4 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.8-slim

WORKDIR /app

# System libraries plus wget. libgl1 is used instead of libgl1-mesa-glx:
# the latter is a transitional package that newer Debian releases no longer
# ship, so installing it fails on current python:*-slim base images.
RUN apt-get update && apt-get install -y \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libgl1 \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the sources so this layer is
# reused when only application code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Serve the FastAPI app defined in app.py on port 7860 with 4 worker processes.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "4"]
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
5
+ from torch.nn.functional import softmax
6
+
7
# FastAPI application object; the interactive documentation is served at "/".
app = FastAPI(
    title="Contact Information Detection API",
    description="API for detecting contact information in text",
    version="1.0.0",
    docs_url="/",
)
13
+
14
class ContactDetector:
    """Two-label DistilBERT sequence classifier used to score text.

    The probability assigned to label index 1 is reported as the
    "contains contact information" score.
    """

    def __init__(self, model_name: str = 'distilbert-base-uncased') -> None:
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Checkpoint name or path handed to ``from_pretrained``
                for both the tokenizer and the model. Defaults to the base
                checkpoint the original code used.
                NOTE(review): the base checkpoint presumably has no trained
                classification head, so scores are only meaningful once a
                fine-tuned checkpoint is supplied here -- confirm.
        """
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        # Prefer the GPU when one is visible, otherwise stay on the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def detect_contact_info(self, text: str) -> float:
        """Return the softmax probability of label 1 for ``text``."""
        inputs = self.tokenizer(
            text, return_tensors='pt', truncation=True, padding=True
        ).to(self.device)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = softmax(outputs.logits, dim=1)
        return probabilities[0][1].item()  # Probability of contact info

    def is_contact_info(self, text: str, threshold: float = 0.5) -> bool:
        """Return True when the score for ``text`` is strictly above ``threshold``."""
        return self.detect_contact_info(text) > threshold
31
+
32
# Single detector instance created at import time and used by the endpoint.
detector = ContactDetector()
33
+
34
class TextInput(BaseModel):
    # Request body schema: the text to be scored.
    text: str
36
+
37
@app.post("/detect_contact", summary="Detect contact information in text")
async def detect_contact(input: TextInput):
    """Score the submitted text for contact information.

    Args:
        input: Request body carrying the ``text`` to classify.

    Returns:
        A dict with the echoed ``text``, the ``contact_probability`` returned
        by ``detector.detect_contact_info`` and the boolean
        ``is_contact_info`` flag.

    Raises:
        HTTPException: Status 500 with the error message as detail when
            scoring raises.
    """
    try:
        # Run the model once. Calling detector.is_contact_info() as well would
        # repeat the same forward pass just to compare against the threshold.
        probability = detector.detect_contact_info(input.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "text": input.text,
        "contact_probability": probability,
        # 0.5 mirrors the default threshold of ContactDetector.is_contact_info.
        "is_contact_info": probability > 0.5,
    }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fastapi==0.68.0
2
+ uvicorn==0.15.0
3
+ torch==2.4.1
4
+ transformers==4.10.0