parth parekh
commited on
Commit
·
3cfd7e3
1
Parent(s):
645ea59
added new model from xxparthparekhxx/ContactShieldAI
Browse files- app.py +6 -25
- contact_sharing_epoch_1.pth +3 -0
- predictor.py +103 -0
app.py
CHANGED
@@ -1,37 +1,17 @@
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
import torch
|
4 |
-
from transformers import RobertaTokenizer, RobertaForSequenceClassification
|
5 |
from torch.nn.functional import softmax
|
6 |
import re
|
|
|
7 |
|
8 |
app = FastAPI(
|
9 |
title="Contact Information Detection API",
|
10 |
-
description="API for detecting contact information in text",
|
11 |
version="1.0.0",
|
12 |
docs_url="/"
|
13 |
)
|
14 |
|
15 |
-
class ContactDetector:
|
16 |
-
def __init__(self):
|
17 |
-
cache_dir = "/app/model_cache"
|
18 |
-
self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=cache_dir)
|
19 |
-
self.model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2, cache_dir=cache_dir)
|
20 |
-
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
21 |
-
self.model.to(self.device)
|
22 |
-
self.model.eval()
|
23 |
-
|
24 |
-
def detect_contact_info(self, text):
|
25 |
-
inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
|
26 |
-
with torch.no_grad():
|
27 |
-
outputs = self.model(**inputs)
|
28 |
-
probabilities = softmax(outputs.logits, dim=1)
|
29 |
-
return probabilities[0][1].item() # Probability of contact info
|
30 |
-
|
31 |
-
def is_contact_info(self, text, threshold=0.45):
|
32 |
-
return self.detect_contact_info(text) > threshold
|
33 |
-
|
34 |
-
detector = ContactDetector()
|
35 |
|
36 |
class TextInput(BaseModel):
|
37 |
text: str
|
@@ -65,9 +45,10 @@ async def detect_contact(input: TextInput):
|
|
65 |
"method": "regex"
|
66 |
}
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
71 |
return {
|
72 |
"text": input.text,
|
73 |
"contact_probability": probability,
|
|
|
1 |
from fastapi import FastAPI, HTTPException
|
2 |
from pydantic import BaseModel
|
3 |
import torch
|
|
|
4 |
from torch.nn.functional import softmax
|
5 |
import re
|
6 |
+
from .predictor import predict
|
7 |
|
8 |
app = FastAPI(
|
9 |
title="Contact Information Detection API",
|
10 |
+
description="API for detecting contact information in text great thanks to xxparthparekhxx/ContactShieldAI for the model",
|
11 |
version="1.0.0",
|
12 |
docs_url="/"
|
13 |
)
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
class TextInput(BaseModel):
|
17 |
text: str
|
|
|
45 |
"method": "regex"
|
46 |
}
|
47 |
|
48 |
+
# If no regex patterns match, use the model
|
49 |
+
probabilities = predict(input.text)
|
50 |
+
probability = probabilities[1] # Probability of containing contact info
|
51 |
+
is_contact = probability > 0.5 # You can adjust this threshold as needed
|
52 |
return {
|
53 |
"text": input.text,
|
54 |
"contact_probability": probability,
|
contact_sharing_epoch_1.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bdb70e711c212856ce3df95b82afbae57b8fc34243b3f541ecd65963fa81fd92
|
3 |
+
size 813497259
|
predictor.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torchtext.vocab import build_vocab_from_iterator, GloVe
|
5 |
+
from torchtext.data.utils import get_tokenizer
|
6 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
7 |
+
|
8 |
+
class ContactSharingClassifier(nn.Module):
|
9 |
+
def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
|
10 |
+
super().__init__()
|
11 |
+
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
|
12 |
+
self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
|
13 |
+
self.convs = nn.ModuleList([
|
14 |
+
nn.Conv1d(in_channels=lstm_hidden_dim*2, out_channels=num_filters, kernel_size=fs)
|
15 |
+
for fs in filter_sizes
|
16 |
+
])
|
17 |
+
self.fc1 = nn.Linear(len(filter_sizes) * num_filters, len(filter_sizes) * num_filters // 2)
|
18 |
+
self.fc2 = nn.Linear(len(filter_sizes) * num_filters // 2, output_dim)
|
19 |
+
self.dropout = nn.Dropout(dropout)
|
20 |
+
self.layer_norm = nn.LayerNorm(len(filter_sizes) * num_filters)
|
21 |
+
|
22 |
+
def forward(self, text):
|
23 |
+
embedded = self.embedding(text)
|
24 |
+
lstm_out, _ = self.lstm(embedded)
|
25 |
+
lstm_out = lstm_out.permute(0, 2, 1)
|
26 |
+
conved = [F.relu(conv(lstm_out)) for conv in self.convs]
|
27 |
+
pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
|
28 |
+
cat = self.dropout(torch.cat(pooled, dim=1))
|
29 |
+
cat = self.layer_norm(cat)
|
30 |
+
x = F.relu(self.fc1(cat))
|
31 |
+
x = self.dropout(x)
|
32 |
+
return self.fc2(x)
|
33 |
+
|
34 |
+
# Initialize tokenizer and vocabulary
|
35 |
+
tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
|
36 |
+
vocab = torch.load('vocab.pth') # Assuming you've saved the vocabulary
|
37 |
+
|
38 |
+
# Define text pipeline
|
39 |
+
def text_pipeline(x):
|
40 |
+
return [vocab[token] for token in tokenizer(x)]
|
41 |
+
|
42 |
+
# Model parameters
|
43 |
+
VOCAB_SIZE = len(vocab)
|
44 |
+
EMBED_DIM = 600
|
45 |
+
NUM_FILTERS = 600
|
46 |
+
FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]
|
47 |
+
LSTM_HIDDEN_DIM = 768
|
48 |
+
OUTPUT_DIM = 2
|
49 |
+
DROPOUT = 0.5
|
50 |
+
PAD_IDX = vocab["<pad>"]
|
51 |
+
|
52 |
+
# Load the model
|
53 |
+
|
54 |
+
model = ContactSharingClassifier(VOCAB_SIZE, EMBED_DIM, NUM_FILTERS, FILTER_SIZES, LSTM_HIDDEN_DIM, OUTPUT_DIM, DROPOUT, PAD_IDX)
|
55 |
+
model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
|
56 |
+
model.to(device)
|
57 |
+
model.eval()
|
58 |
+
|
59 |
+
# Test sentences
|
60 |
+
test_sentences = [
|
61 |
+
"You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
|
62 |
+
"Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
|
63 |
+
"Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
|
64 |
+
"Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
|
65 |
+
"My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
|
66 |
+
"Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
|
67 |
+
"My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
|
68 |
+
"You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
|
69 |
+
"I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
|
70 |
+
"My contact details are encrypted: Rot13('[email protected]')",
|
71 |
+
|
72 |
+
# New non-contact sharing examples
|
73 |
+
"The weather today is absolutely beautiful, perfect for a picnic in the park.",
|
74 |
+
"I'm really excited about the new sci-fi movie coming out next month.",
|
75 |
+
"Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
|
76 |
+
"I'm planning to go hiking this weekend in the nearby mountains.",
|
77 |
+
"The recipe calls for two cups of flour and a pinch of salt.",
|
78 |
+
"The annual tech conference will be held virtually this year due to ongoing health concerns.",
|
79 |
+
"I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
|
80 |
+
"The local farmer's market has the freshest produce every Saturday morning.",
|
81 |
+
"Did you catch the game last night? It was an incredible comeback in the final quarter!",
|
82 |
+
"Lets do '42069' tonight it will be really fun what do you say ?"
|
83 |
+
]
|
84 |
+
|
85 |
+
|
86 |
+
def predict(text):
|
87 |
+
with torch.no_grad():
|
88 |
+
inputs = torch.tensor([text_pipeline(text)])
|
89 |
+
if inputs.size(1) < max(FILTER_SIZES):
|
90 |
+
padding = torch.zeros(1, max(FILTER_SIZES) - inputs.size(1), dtype=torch.long)
|
91 |
+
inputs = torch.cat([inputs, padding], dim=1)
|
92 |
+
inputs = inputs.to(device)
|
93 |
+
outputs = model(inputs)
|
94 |
+
probabilities = F.softmax(outputs, dim=1)
|
95 |
+
return probabilities.squeeze().tolist()
|
96 |
+
|
97 |
+
|
98 |
+
# Test the sentences
|
99 |
+
for i, sentence in enumerate(test_sentences, 1):
|
100 |
+
prediction = predict(sentence)
|
101 |
+
result = "Contains contact info" if prediction == 1 else "No contact info"
|
102 |
+
print(f"Sentence {i}: {result}")
|
103 |
+
print(f"Text: {sentence}\n")
|