Commit ab99a02 by parth parekh · Parent(s): ddaad57

added working batch processing endpoint

__pycache__/app.cpython-312.pyc ADDED
Binary file (4.43 kB)
 
__pycache__/predictor.cpython-312.pyc CHANGED
Binary files a/__pycache__/predictor.cpython-312.pyc and b/__pycache__/predictor.cpython-312.pyc differ
 
app.py CHANGED
@@ -59,23 +59,23 @@ async def detect_contact(input: TextInput):
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
-
 @app.post("/batch_detect_contact", summary="Detect contact information in batch of texts")
 async def batch_detect_contact(inputs: BatchTextInput):
     try:
         # Preprocess all texts
         preprocessed_texts = [preprocess_text(text) for text in inputs.texts]
-
+
         # First, use regex to check patterns
         regex_results = [check_regex_patterns(text) for text in preprocessed_texts]
-
+
+
         # For texts where regex doesn't detect anything, use the model
         texts_for_model = [text for text, regex_match in zip(preprocessed_texts, regex_results) if not regex_match]
         if texts_for_model:
             model_results = batch_predict(texts_for_model)
         else:
             model_results = []
-
+
         # Prepare final results
         results = []
         model_idx = 0
@@ -90,11 +90,11 @@ async def batch_detect_contact(inputs: BatchTextInput):
                 is_contact = model_results[model_idx]
                 results.append({
                     "text": inputs.texts[i],
-                    "is_contact_info": is_contact == 1,
+                    "is_contact_info": bool(is_contact),  # Convert numpy bool
                     "method": "model"
                 })
                 model_idx += 1
-
+
         return results
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e))
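
Note on the endpoint above: the switch from `is_contact == 1` to `bool(is_contact)` matters because, per the inline comment, the model path returns a NumPy value, and a `numpy.bool_` is rejected by the default JSON encoder while a plain `bool` serializes cleanly. A minimal sketch of exercising the new route, assuming a local server on port 8000 (the same address test.py uses below) and using `requests` purely for illustration:

    import requests

    resp = requests.post(
        "http://localhost:8000/batch_detect_contact",
        json={"texts": ["call me at 555-0123", "nice weather today"]},
    )
    resp.raise_for_status()
    for item in resp.json():
        # Each entry carries the original text, a JSON-safe boolean, and the
        # detector that produced it; "model" is visible in the hunk above,
        # while the label for regex hits falls outside it.
        print(item["text"], item["is_contact_info"], item["method"])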
predictor.py CHANGED
@@ -105,16 +105,13 @@ def predict(text):
     return torch.argmax(outputs, dim=1).item()
 
 def batch_predict(texts):
-    with torch.inference_mode():  # Use inference mode for performance
-        # Tokenize and convert each text to tensor
+    with torch.inference_mode():  # Use inference mode for better performance
+        # Tokenize and convert to tensors
         inputs = [torch.tensor(text_pipeline(text)) for text in texts]
 
-        # Pad all sequences to the maximum filter size (max of FILTER_SIZES)
-        max_len = max(FILTER_SIZES)
-        padded_inputs = torch.stack([
-            torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)]) if len(seq) < max_len else seq
-            for seq in inputs
-        ]).to(device)
+        # Pad all sequences to the length of the longest one in the batch
+        max_len = max(len(seq) for seq in inputs)
+        padded_inputs = torch.stack([torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)]) for seq in inputs]).to(device)
 
         # Pass the batch through the scripted model
         outputs = scripted_model(padded_inputs)
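
Note on the padding fix: `torch.stack` requires tensors of identical shape, and the old code only padded sequences shorter than `max(FILTER_SIZES)` while leaving longer ones at their original length, so any batch mixing lengths above that floor crashed the stack. Padding to the batch maximum resolves this. A standalone sketch of the scheme (toy tensors in place of the real `text_pipeline`; the `max(FILTER_SIZES)` floor is an assumed extra safeguard, not part of the commit):

    import torch

    FILTER_SIZES = [2, 3, 4]  # assumed CNN filter sizes, echoing the old code

    def pad_batch(inputs):
        # Pad to the longest sequence, but never below the widest filter, so
        # a batch of uniformly short texts still satisfies the convolutions.
        max_len = max(max(FILTER_SIZES), max(len(seq) for seq in inputs))
        return torch.stack([
            torch.cat([seq, torch.zeros(max_len - len(seq), dtype=torch.long)])
            for seq in inputs
        ])

    batch = [torch.tensor([3, 1, 4]), torch.tensor([1, 5, 9, 2, 6, 5])]
    print(pad_batch(batch).shape)  # torch.Size([2, 6])

If the model is the multi-filter TextCNN that the FILTER_SIZES reference implies, the floor also guards the edge case that batch-maximum padding alone leaves open: a batch in which every text is shorter than the widest filter.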
test.py CHANGED
@@ -104,47 +104,56 @@ test_texts = [
 
 ]
 import time
-
-url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/detect_contact"
-
-async def process_text(session, text):
-    payload = {"text": text}
+# url = "https://vidhitmakvana1-contact-sharing-recognizer-api.hf.space/batch_detect_contact"
+url = "http://localhost:8000/batch_detect_contact"
+
+async def process_batch(session, texts):
+    payload = {"texts": texts}
     headers = {"Content-Type": "application/json"}
 
     start_time = time.time()
     async with session.post(url, data=json.dumps(payload), headers=headers) as response:
         if response.status == 200:
-            result = await response.json()
+            results = await response.json()
             end_time = time.time()
-            result['response_time'] = end_time - start_time
-            return result
+            for result in results:
+                result['response_time'] = (end_time - start_time) / len(texts)
+            return results
         else:
-            print(f"Error for text: {text}")
+            print("Error for batch")
             print(f"Status code: {response.status}")
             print(f"Response: {await response.text()}")
             return None
 
 async def main():
+    # Inflate test_texts
+    inflated_texts = test_texts * 100  # Multiply the test set by 100
+
     async with aiohttp.ClientSession() as session:
-        tasks = [process_text(session, text) for text in test_texts]
-        results = await tqdm.gather(*tasks)
+        batch_size = 1000
+        batches = [inflated_texts[i:i + batch_size] for i in range(0, len(inflated_texts), batch_size)]
+
+        tasks = [process_batch(session, batch) for batch in batches]
+        all_results = await tqdm.gather(*tasks)
+
+        results = [item for sublist in all_results if sublist for item in sublist]
 
     correct_predictions = 0
     total_predictions = len(results)
     total_response_time = 0
 
-    for text, result in zip(test_texts, results):
+    for result in results:
         if result:
             print(f"Text: {result['text']}")
-            print(f"Contact Probability: {result['contact_probability']:.4f}")
             print(f"Is Contact Info: {result['is_contact_info']}")
+            print(f"Method: {result['method']}")
             print(f"Response Time: {result['response_time']:.4f} seconds")
             print("---")
-
+
         # Assuming all texts in test_texts are actually contact information
         if result['is_contact_info']:
            correct_predictions += 1
-
+
        total_response_time += result['response_time']
 
     accuracy = correct_predictions / total_predictions
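
Note on the benchmark: every batch is dispatched at once via `tqdm.gather`, so the recorded per-item `response_time` is the batch's wall time split evenly across its texts, an amortized server-side figure rather than end-to-end latency; summing it will overstate elapsed time whenever batches overlap. The diff omits the script's preamble and entrypoint; a hypothetical sketch of what test.py needs in order to run:

    # Assumed preamble/entrypoint for test.py; none of these lines appear in
    # the diff. tqdm.asyncio's tqdm.gather is asyncio.gather plus a progress bar.
    import asyncio
    import json

    import aiohttp
    from tqdm.asyncio import tqdm

    if __name__ == "__main__":
        asyncio.run(main())  # main() as defined in test.py above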