parth parekh commited on
Commit
cce759a
·
1 Parent(s): 329c4c3

reverting the model to the running one

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -1,13 +1,13 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
  from torch.nn.functional import softmax
6
  import re
7
 
8
  app = FastAPI(
9
- title="ContactShieldAI Detection API",
10
- description="API for detecting contact information in text using ContactShieldAI",
11
  version="1.0.0",
12
  docs_url="/"
13
  )
@@ -15,8 +15,8 @@ app = FastAPI(
15
  class ContactDetector:
16
  def __init__(self):
17
  cache_dir = "/app/model_cache"
18
- self.tokenizer = AutoTokenizer.from_pretrained('xxparthparekhxx/ContactShieldAI', cache_dir=cache_dir)
19
- self.model = AutoModelForSequenceClassification.from_pretrained('xxparthparekhxx/ContactShieldAI', cache_dir=cache_dir)
20
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.model.to(self.device)
22
  self.model.eval()
@@ -36,6 +36,7 @@ detector = ContactDetector()
36
  class TextInput(BaseModel):
37
  text: str
38
 
 
39
  def check_regex_patterns(text):
40
  patterns = [
41
  r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', # Email
@@ -50,7 +51,8 @@ def check_regex_patterns(text):
50
  return True
51
  return False
52
 
53
- @app.post("/detect_contact", summary="Detect contact information in text using ContactShieldAI")
 
54
  async def detect_contact(input: TextInput):
55
  try:
56
  # First, check with regex patterns
@@ -62,14 +64,14 @@ async def detect_contact(input: TextInput):
62
  "method": "regex"
63
  }
64
 
65
- # If no regex patterns match, use the ContactShieldAI model
66
  probability = detector.detect_contact_info(input.text)
67
  is_contact = detector.is_contact_info(input.text)
68
  return {
69
  "text": input.text,
70
  "contact_probability": probability,
71
  "is_contact_info": is_contact,
72
- "method": "ContactShieldAI"
73
  }
74
  except Exception as e:
75
- raise HTTPException(status_code=500, detail=str(e))
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  import torch
4
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
5
  from torch.nn.functional import softmax
6
  import re
7
 
8
  app = FastAPI(
9
+ title="Contact Information Detection API",
10
+ description="API for detecting contact information in text",
11
  version="1.0.0",
12
  docs_url="/"
13
  )
 
15
  class ContactDetector:
16
  def __init__(self):
17
  cache_dir = "/app/model_cache"
18
+ self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=cache_dir)
19
+ self.model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2, cache_dir=cache_dir)
20
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.model.to(self.device)
22
  self.model.eval()
 
36
  class TextInput(BaseModel):
37
  text: str
38
 
39
+
40
  def check_regex_patterns(text):
41
  patterns = [
42
  r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', # Email
 
51
  return True
52
  return False
53
 
54
+
55
+ @app.post("/detect_contact", summary="Detect contact information in text")
56
  async def detect_contact(input: TextInput):
57
  try:
58
  # First, check with regex patterns
 
64
  "method": "regex"
65
  }
66
 
67
+ # If no regex patterns match, use the model
68
  probability = detector.detect_contact_info(input.text)
69
  is_contact = detector.is_contact_info(input.text)
70
  return {
71
  "text": input.text,
72
  "contact_probability": probability,
73
  "is_contact_info": is_contact,
74
+ "method": "model"
75
  }
76
  except Exception as e:
77
+ raise HTTPException(status_code=500, detail=str(e))