parth parekh commited on
Commit
3cfd7e3
·
1 Parent(s): 645ea59

added new model from xxparthparekhxx/ContactShieldAI

Browse files
Files changed (3) hide show
  1. app.py +6 -25
  2. contact_sharing_epoch_1.pth +3 -0
  3. predictor.py +103 -0
app.py CHANGED
@@ -1,37 +1,17 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
- from transformers import RobertaTokenizer, RobertaForSequenceClassification
5
  from torch.nn.functional import softmax
6
  import re
 
7
 
8
  app = FastAPI(
9
  title="Contact Information Detection API",
10
- description="API for detecting contact information in text",
11
  version="1.0.0",
12
  docs_url="/"
13
  )
14
 
15
- class ContactDetector:
16
- def __init__(self):
17
- cache_dir = "/app/model_cache"
18
- self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=cache_dir)
19
- self.model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2, cache_dir=cache_dir)
20
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
- self.model.to(self.device)
22
- self.model.eval()
23
-
24
- def detect_contact_info(self, text):
25
- inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(self.device)
26
- with torch.no_grad():
27
- outputs = self.model(**inputs)
28
- probabilities = softmax(outputs.logits, dim=1)
29
- return probabilities[0][1].item() # Probability of contact info
30
-
31
- def is_contact_info(self, text, threshold=0.45):
32
- return self.detect_contact_info(text) > threshold
33
-
34
- detector = ContactDetector()
35
 
36
  class TextInput(BaseModel):
37
  text: str
@@ -65,9 +45,10 @@ async def detect_contact(input: TextInput):
65
  "method": "regex"
66
  }
67
 
68
- # If no regex patterns match, use the model
69
- probability = detector.detect_contact_info(input.text)
70
- is_contact = detector.is_contact_info(input.text)
 
71
  return {
72
  "text": input.text,
73
  "contact_probability": probability,
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
 
4
  from torch.nn.functional import softmax
5
  import re
6
+ from .predictor import predict
7
 
8
  app = FastAPI(
9
  title="Contact Information Detection API",
10
+ description="API for detecting contact information in text great thanks to xxparthparekhxx/ContactShieldAI for the model",
11
  version="1.0.0",
12
  docs_url="/"
13
  )
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  class TextInput(BaseModel):
17
  text: str
 
45
  "method": "regex"
46
  }
47
 
48
+ # If no regex patterns match, use the model
49
+ probabilities = predict(input.text)
50
+ probability = probabilities[1] # Probability of containing contact info
51
+ is_contact = probability > 0.5 # You can adjust this threshold as needed
52
  return {
53
  "text": input.text,
54
  "contact_probability": probability,
contact_sharing_epoch_1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdb70e711c212856ce3df95b82afbae57b8fc34243b3f541ecd65963fa81fd92
3
+ size 813497259
predictor.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchtext.vocab import build_vocab_from_iterator, GloVe
5
+ from torchtext.data.utils import get_tokenizer
6
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
+
8
+ class ContactSharingClassifier(nn.Module):
9
+ def __init__(self, vocab_size, embed_dim, num_filters, filter_sizes, lstm_hidden_dim, output_dim, dropout, pad_idx):
10
+ super().__init__()
11
+ self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
12
+ self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, bidirectional=True, batch_first=True)
13
+ self.convs = nn.ModuleList([
14
+ nn.Conv1d(in_channels=lstm_hidden_dim*2, out_channels=num_filters, kernel_size=fs)
15
+ for fs in filter_sizes
16
+ ])
17
+ self.fc1 = nn.Linear(len(filter_sizes) * num_filters, len(filter_sizes) * num_filters // 2)
18
+ self.fc2 = nn.Linear(len(filter_sizes) * num_filters // 2, output_dim)
19
+ self.dropout = nn.Dropout(dropout)
20
+ self.layer_norm = nn.LayerNorm(len(filter_sizes) * num_filters)
21
+
22
+ def forward(self, text):
23
+ embedded = self.embedding(text)
24
+ lstm_out, _ = self.lstm(embedded)
25
+ lstm_out = lstm_out.permute(0, 2, 1)
26
+ conved = [F.relu(conv(lstm_out)) for conv in self.convs]
27
+ pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
28
+ cat = self.dropout(torch.cat(pooled, dim=1))
29
+ cat = self.layer_norm(cat)
30
+ x = F.relu(self.fc1(cat))
31
+ x = self.dropout(x)
32
+ return self.fc2(x)
33
+
34
+ # Initialize tokenizer and vocabulary
35
+ tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
36
+ vocab = torch.load('vocab.pth') # Assuming you've saved the vocabulary
37
+
38
+ # Define text pipeline
39
+ def text_pipeline(x):
40
+ return [vocab[token] for token in tokenizer(x)]
41
+
42
+ # Model parameters
43
+ VOCAB_SIZE = len(vocab)
44
+ EMBED_DIM = 600
45
+ NUM_FILTERS = 600
46
+ FILTER_SIZES = [3, 4, 5, 6, 7, 8, 9, 10]
47
+ LSTM_HIDDEN_DIM = 768
48
+ OUTPUT_DIM = 2
49
+ DROPOUT = 0.5
50
+ PAD_IDX = vocab["<pad>"]
51
+
52
+ # Load the model
53
+
54
+ model = ContactSharingClassifier(VOCAB_SIZE, EMBED_DIM, NUM_FILTERS, FILTER_SIZES, LSTM_HIDDEN_DIM, OUTPUT_DIM, DROPOUT, PAD_IDX)
55
+ model.load_state_dict(torch.load('contact_sharing_epoch_1.pth', map_location=device))
56
+ model.to(device)
57
+ model.eval()
58
+
59
+ # Test sentences
60
+ test_sentences = [
61
+ "You can reach me at my electronic mail address, it's my first name dot last name at that popular search engine company's mail service.",
62
+ "Call me on my cellular device, the digits are the same as the year the Declaration of Independence was signed, followed by my birth year, twice.",
63
+ "Visit my online presence at triple w dot my full name without spaces or punctuation dot com.",
64
+ "Send a message to username 'not_my_real_name' on that instant messaging platform that starts with 'disc' and ends with 'ord'.",
65
+ "My contact info is hidden in this sentence: Eight Six Seven Five Three Oh Nine.",
66
+ "Find me on the professional networking site, just search for my name plus 'software engineer in San Francisco'.",
67
+ "My handle on the bird-themed social media platform is at symbol followed by 'definitely_not_my_email_address'.",
68
+ "You know that video sharing site? My channel is there, just add 'cool_coder_' before my full name, all lowercase.",
69
+ "I'm listed in the phone book under 'Smith, John' but replace 'Smith' with my actual last name and 'John' with my first name.",
70
+ "My contact details are encrypted: Rot13('[email protected]')",
71
+
72
+ # New non-contact sharing examples
73
+ "The weather today is absolutely beautiful, perfect for a picnic in the park.",
74
+ "I'm really excited about the new sci-fi movie coming out next month.",
75
+ "Did you hear about the latest advancements in artificial intelligence? It's fascinating!",
76
+ "I'm planning to go hiking this weekend in the nearby mountains.",
77
+ "The recipe calls for two cups of flour and a pinch of salt.",
78
+ "The annual tech conference will be held virtually this year due to ongoing health concerns.",
79
+ "I've been learning to play the guitar for the past six months. It's challenging but rewarding.",
80
+ "The local farmer's market has the freshest produce every Saturday morning.",
81
+ "Did you catch the game last night? It was an incredible comeback in the final quarter!",
82
+ "Lets do '42069' tonight it will be really fun what do you say ?"
83
+ ]
84
+
85
+
86
+ def predict(text):
87
+ with torch.no_grad():
88
+ inputs = torch.tensor([text_pipeline(text)])
89
+ if inputs.size(1) < max(FILTER_SIZES):
90
+ padding = torch.zeros(1, max(FILTER_SIZES) - inputs.size(1), dtype=torch.long)
91
+ inputs = torch.cat([inputs, padding], dim=1)
92
+ inputs = inputs.to(device)
93
+ outputs = model(inputs)
94
+ probabilities = F.softmax(outputs, dim=1)
95
+ return probabilities.squeeze().tolist()
96
+
97
+
98
+ # Test the sentences
99
+ for i, sentence in enumerate(test_sentences, 1):
100
+ prediction = predict(sentence)
101
+ result = "Contains contact info" if prediction == 1 else "No contact info"
102
+ print(f"Sentence {i}: {result}")
103
+ print(f"Text: {sentence}\n")