parth parekh committed
Commit ddaad57 · 1 Parent(s): 838063e

added batch processing endpoint
Browse files:
- __pycache__/test.cpython-312.pyc +0 -0
- app.py +44 -6
- load_test.py +67 -0
- predictor.py +18 -0
__pycache__/test.cpython-312.pyc
ADDED
Binary file (11.6 kB)
app.py
CHANGED
@@ -3,11 +3,11 @@ from pydantic import BaseModel
 import torch
 from torch.nn.functional import softmax
 import re
-from predictor import predict
+from predictor import predict, batch_predict  # Assuming batch_predict is in predictor module
 
 app = FastAPI(
     title="Contact Information Detection API",
-    description="API for detecting contact information in text great thanks to xxparthparekhxx/ContactShieldAI for the model",
+    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
     version="1.0.0",
     docs_url="/"
 )
@@ -19,6 +19,8 @@ def preprocess_text(text):
 class TextInput(BaseModel):
     text: str
 
+class BatchTextInput(BaseModel):
+    texts: list[str]
 
 def check_regex_patterns(text):
     patterns = [
@@ -34,8 +36,6 @@ def check_regex_patterns(text):
         return True
     return False
 
-
-
 @app.post("/detect_contact", summary="Detect contact information in text")
 async def detect_contact(input: TextInput):
     try:
@@ -45,7 +45,6 @@ async def detect_contact(input: TextInput):
         if check_regex_patterns(preprocessed_text):
             return {
                 "text": input.text,
-                "contact_probability": 1.0,
                 "is_contact_info": True,
                 "method": "regex"
             }
@@ -54,9 +53,48 @@ async def detect_contact(input: TextInput):
         is_contact = predict(preprocessed_text)
         return {
             "text": input.text,
-            "contact_probability": 0.98,
             "is_contact_info": is_contact == 1,
             "method": "model"
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
+async def batch_detect_contact(inputs: BatchTextInput):
+    try:
+        # Preprocess all texts
+        preprocessed_texts = [preprocess_text(text) for text in inputs.texts]
+
+        # First, use regex to check patterns
+        regex_results = [check_regex_patterns(text) for text in preprocessed_texts]
+
+        # For texts where regex doesn't detect anything, use the model
+        texts_for_model = [text for text, regex_match in zip(preprocessed_texts, regex_results) if not regex_match]
+        if texts_for_model:
+            model_results = batch_predict(texts_for_model)
+        else:
+            model_results = []
+
+        # Prepare final results
+        results = []
+        model_idx = 0
+        for i, text in enumerate(preprocessed_texts):
+            if regex_results[i]:
+                results.append({
+                    "text": inputs.texts[i],
+                    "is_contact_info": True,
+                    "method": "regex"
+                })
+            else:
+                is_contact = model_results[model_idx]
+                results.append({
+                    "text": inputs.texts[i],
+                    "is_contact_info": is_contact == 1,
+                    "method": "model"
+                })
+                model_idx += 1
+
+        return results
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
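
For reference, a minimal sketch of calling the new endpoint once the app is running; the localhost base URL and the uvicorn invocation are assumptions, not part of this commit:

# Minimal sketch of exercising /batch_detect_contact; assumes the app was
# started locally, e.g. with `uvicorn app:app --port 8000`.
import requests

resp = requests.post(
    "http://localhost:8000/batch_detect_contact",
    json={"texts": ["call me at 555-123-4567", "what a lovely day"]},
)
resp.raise_for_status()
for item in resp.json():
    print(item["method"], item["is_contact_info"], item["text"])

Each element of the response mirrors the single-text endpoint's shape, with "method" indicating whether regex or the model made the call.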
load_test.py
ADDED
@@ -0,0 +1,67 @@
+import asyncio
+import aiohttp
+import json
+from tqdm.asyncio import tqdm
+import time
+from test import test_texts
+
+url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/detect_contact"
+concurrent_requests = 2
+
+async def process_text(session, text, semaphore):
+    payload = {"text": text}
+    headers = {"Content-Type": "application/json"}
+
+    async with semaphore:
+        start_time = time.time()
+        while True:
+            async with session.post(url, data=json.dumps(payload), headers=headers) as response:
+                if response.status == 200:
+                    result = await response.json()
+                    end_time = time.time()
+                    result['response_time'] = end_time - start_time
+                    return result
+                elif response.status == 429:
+                    print("Rate limit exceeded. Waiting for 60 seconds before retrying...")
+                    await asyncio.sleep(60)
+                else:
+                    print(f"Error for text: {text}")
+                    print(f"Status code: {response.status}")
+                    print(f"Response: {await response.text()}")
+                    return None
+
+async def main():
+    semaphore = asyncio.Semaphore(concurrent_requests)
+    async with aiohttp.ClientSession() as session:
+        tasks = [process_text(session, text, semaphore) for text in test_texts * 36]  # 36 copies of the test set for sustained load
+        results = await tqdm.gather(*tasks)
+
+    correct_predictions = 0
+    total_predictions = len(results)
+    total_response_time = 0
+
+    for result in results:  # iterate over every result, not just the first copy of the test set
+        if result:
+            print(f"Text: {result['text']}")
+            print(f"Contact Probability: {result.get('contact_probability', 'n/a')}")  # field may be absent since this commit drops it
+            print(f"Is Contact Info: {result['is_contact_info']}")
+            print(f"Response Time: {result['response_time']:.4f} seconds")
+            print("---")
+
+            if result['is_contact_info']:
+                correct_predictions += 1
+
+            total_response_time += result['response_time']
+
+    accuracy = correct_predictions / total_predictions
+    average_response_time = total_response_time / total_predictions
+    print(f"Accuracy: {accuracy:.2f}")
+    print(f"Average Response Time: {average_response_time:.4f} seconds")
+
+if __name__ == "__main__":
+    while True:
+        start_time = time.time()
+        asyncio.run(main())
+        end_time = time.time()
+        total_time = end_time - start_time
+        print(f"\nTotal execution time: {total_time:.2f} seconds")
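
The load test above still targets the single-text /detect_contact route. A hypothetical batch-mode probe of the new route could look like the sketch below; the batch URL simply swaps in the path added in app.py and assumes the deployed Space serves this commit:

# Hypothetical probe of the batch endpoint; batch_url assumes the same Space
# host with the /batch_detect_contact path introduced in this commit.
import asyncio
import aiohttp
from test import test_texts

batch_url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/batch_detect_contact"

async def probe_batch():
    async with aiohttp.ClientSession() as session:
        async with session.post(batch_url, json={"texts": list(test_texts)}) as response:
            response.raise_for_status()
            for item in await response.json():
                print(item["method"], item["is_contact_info"])

if __name__ == "__main__":
    asyncio.run(probe_batch())

One round trip here replaces len(test_texts) individual requests, which is the point of the batch endpoint under rate limiting.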
predictor.py
CHANGED
@@ -104,6 +104,24 @@ def predict(text):
     # Return predicted class
     return torch.argmax(outputs, dim=1).item()
 
+def batch_predict(texts):
+    with torch.inference_mode():  # Use inference mode for performance
+        # Tokenize and convert each text to tensor
+        inputs = [torch.tensor(text_pipeline(text)) for text in texts]
+
+        # Pad to a common length (at least max(FILTER_SIZES)) so torch.stack works on ragged batches
+        max_len = max(max(FILTER_SIZES), max(len(seq) for seq in inputs))
+        padded_inputs = torch.stack([
+            torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)]) if len(seq) < max_len else seq
+            for seq in inputs
+        ]).to(device)
+
+        # Pass the batch through the scripted model
+        outputs = scripted_model(padded_inputs)
+
+        # Return predicted classes for each sentence
+        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
+        return predictions
+
 # Test the sentences
 for i, sentence in enumerate(test_sentences, 1):
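
A quick local check of batch_predict, assuming predictor.py's module-level model, text_pipeline, and device are set up at import time just as predict already requires:

# Minimal sketch: batch_predict returns one class index per input text.
from predictor import batch_predict

texts = ["email me at foo@bar.com", "see you tomorrow"]
for text, pred in zip(texts, batch_predict(texts)):
    print(pred, text)

Feeding the model one padded tensor instead of looping predict over each text is what makes the /batch_detect_contact endpoint cheaper per item.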