parth parekh committed
Commit ddaad57 · Parent(s): 838063e

added batch processing endpoint

Files changed (4)
  1. __pycache__/test.cpython-312.pyc +0 -0
  2. app.py +44 -6
  3. load_test.py +67 -0
  4. predictor.py +18 -0
__pycache__/test.cpython-312.pyc ADDED
Binary file (11.6 kB)
 
app.py CHANGED
@@ -3,11 +3,11 @@ from pydantic import BaseModel
 import torch
 from torch.nn.functional import softmax
 import re
-from predictor import predict
+from predictor import predict, batch_predict  # Assuming batch_predict is in predictor module
 
 app = FastAPI(
     title="Contact Information Detection API",
-    description="API for detecting contact information in text great thanks to xxparthparekhxx/ContactShieldAI for the model",
+    description="API for detecting contact information in text, great thanks to xxparthparekhxx/ContactShieldAI for the model",
     version="1.0.0",
     docs_url="/"
 )
@@ -19,6 +19,8 @@ def preprocess_text(text):
 class TextInput(BaseModel):
     text: str
 
+class BatchTextInput(BaseModel):
+    texts: list[str]
 
 def check_regex_patterns(text):
     patterns = [
@@ -34,8 +36,6 @@ def check_regex_patterns(text):
         return True
     return False
 
-
-
 @app.post("/detect_contact", summary="Detect contact information in text")
 async def detect_contact(input: TextInput):
     try:
@@ -45,7 +45,6 @@ async def detect_contact(input: TextInput):
         if check_regex_patterns(preprocessed_text):
             return {
                 "text": input.text,
-                "contact_probability": 1.0,
                 "is_contact_info": True,
                 "method": "regex"
             }
@@ -54,9 +53,48 @@ async def detect_contact(input: TextInput):
         is_contact = predict(preprocessed_text)
         return {
             "text": input.text,
-            "contact_probability": 0.98,
             "is_contact_info": is_contact == 1,
             "method": "model"
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
+async def batch_detect_contact(inputs: BatchTextInput):
+    try:
+        # Preprocess all texts
+        preprocessed_texts = [preprocess_text(text) for text in inputs.texts]
+
+        # First, use regex to check patterns
+        regex_results = [check_regex_patterns(text) for text in preprocessed_texts]
+
+        # For texts where regex doesn't detect anything, use the model
+        texts_for_model = [text for text, regex_match in zip(preprocessed_texts, regex_results) if not regex_match]
+        if texts_for_model:
+            model_results = batch_predict(texts_for_model)
+        else:
+            model_results = []
+
+        # Prepare final results
+        results = []
+        model_idx = 0
+        for i, text in enumerate(preprocessed_texts):
+            if regex_results[i]:
+                results.append({
+                    "text": inputs.texts[i],
+                    "is_contact_info": True,
+                    "method": "regex"
+                })
+            else:
+                is_contact = model_results[model_idx]
+                results.append({
+                    "text": inputs.texts[i],
+                    "is_contact_info": is_contact == 1,
+                    "method": "model"
+                })
+                model_idx += 1
+
+        return results
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
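
For context, a client call against the new endpoint might look like the sketch below. This is a hypothetical example, not part of the commit; the Space URL is borrowed from load_test.py, and the response shape follows the handler above.

import requests

# Hypothetical call to the new batch endpoint (URL assumed from load_test.py)
response = requests.post(
    "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/batch_detect_contact",
    json={"texts": ["call me at 555-123-4567", "nice weather today"]},
)
# Each entry reports is_contact_info and whether regex or the model decided
for item in response.json():
    print(item["is_contact_info"], item["method"])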
load_test.py ADDED
@@ -0,0 +1,67 @@
+import asyncio
+import aiohttp
+import json
+from tqdm.asyncio import tqdm
+import time
+from test import test_texts
+
+url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/detect_contact"
+concurrent_requests = 2
+
+async def process_text(session, text, semaphore):
+    payload = {"text": text}
+    headers = {"Content-Type": "application/json"}
+
+    async with semaphore:
+        start_time = time.time()
+        while True:
+            async with session.post(url, data=json.dumps(payload), headers=headers) as response:
+                if response.status == 200:
+                    result = await response.json()
+                    end_time = time.time()
+                    result['response_time'] = end_time - start_time
+                    return result
+                elif response.status == 429:
+                    print("Rate limit exceeded. Waiting for 60 seconds before retrying...")
+                    await asyncio.sleep(60)
+                else:
+                    print(f"Error for text: {text}")
+                    print(f"Status code: {response.status}")
+                    print(f"Response: {await response.text()}")
+                    return None
+
+async def main():
+    semaphore = asyncio.Semaphore(concurrent_requests)
+    async with aiohttp.ClientSession() as session:
+        tasks = [process_text(session, text, semaphore) for text in test_texts * 36]
+        results = await tqdm.gather(*tasks)
+
+    correct_predictions = 0
+    total_predictions = len(results)
+    total_response_time = 0
+
+    for result in results:
+        if result:
+            print(f"Text: {result['text']}")
+            print(f"Contact Probability: {result.get('contact_probability', 'n/a')}")  # field removed from the API response in this commit
+            print(f"Is Contact Info: {result['is_contact_info']}")
+            print(f"Response Time: {result['response_time']:.4f} seconds")
+            print("---")
+
+            if result['is_contact_info']:
+                correct_predictions += 1
+
+            total_response_time += result['response_time']
+
+    accuracy = correct_predictions / total_predictions
+    average_response_time = total_response_time / total_predictions
+    print(f"Accuracy: {accuracy:.2f}")
+    print(f"Average Response Time: {average_response_time:.4f} seconds")
+
+if __name__ == "__main__":
+    while True:
+        start_time = time.time()
+        asyncio.run(main())
+        end_time = time.time()
+        total_time = end_time - start_time
+        print(f"\nTotal execution time: {total_time:.2f} seconds")
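
Note that load_test.py still drives the single-text /detect_contact route. A variant exercising the new batch endpoint could be appended along these lines; this is a sketch reusing the script's imports and session, and process_batch is a hypothetical helper, not part of the commit.

batch_url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/batch_detect_contact"

async def process_batch(session, texts):
    # One request carries the whole batch, matching the BatchTextInput schema
    async with session.post(batch_url, json={"texts": texts}) as response:
        if response.status == 200:
            return await response.json()
        print(f"Batch request failed with status {response.status}")
        return None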
predictor.py CHANGED
@@ -104,6 +104,24 @@ def predict(text):
     # Return predicted class
     return torch.argmax(outputs, dim=1).item()
 
+def batch_predict(texts):
+    with torch.inference_mode():  # Use inference mode for performance
+        # Tokenize and convert each text to a tensor of token ids
+        inputs = [torch.tensor(text_pipeline(text)) for text in texts]
+
+        # Pad to a common length, at least max(FILTER_SIZES); padding only to the filter size would break torch.stack for longer, unequal sequences
+        max_len = max(max(FILTER_SIZES), max(len(seq) for seq in inputs))
+        padded_inputs = torch.stack([
+            torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)]) if len(seq) < max_len else seq
+            for seq in inputs
+        ]).to(device)
+
+        # Pass the batch through the scripted model
+        outputs = scripted_model(padded_inputs)
+
+        # Return predicted classes for each sentence
+        predictions = torch.argmax(outputs, dim=1).cpu().numpy()
+        return predictions
 
 # Test the sentences
 for i, sentence in enumerate(test_sentences, 1):
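
As a quick sanity check, the batched path can be compared against the existing single-text path using the same test_sentences the module already loops over. A sketch, assuming it runs inside predictor.py where predict, batch_predict, and test_sentences are in scope:

# Sketch: flag any disagreement between the single and batched paths
batch_preds = batch_predict(test_sentences)
for sentence, batched in zip(test_sentences, batch_preds):
    single = predict(sentence)
    if single != batched:
        print(f"Mismatch for {sentence!r}: single={single}, batch={batched}")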