karthikmn commited on
Commit
e037c22
·
verified ·
1 Parent(s): 88f235c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
2
+ from flask import Flask, request, jsonify
3
+ import torch
4
+
5
+ # Initialize Flask app
6
+ app = Flask(__name__)
7
+
8
+ # Load pre-trained DistilBERT model and tokenizer
9
+ model_name = "distilbert-base-uncased"
10
+ tokenizer = DistilBertTokenizer.from_pretrained(model_name)
11
+ model = DistilBertForSequenceClassification.from_pretrained(model_name)
12
+
13
+ # Ensure the model is in evaluation mode
14
+ model.eval()
15
+
16
+ # Define a function to predict score and risk
17
+ def predict_deal_qualification(inputs):
18
+ # Prepare input text
19
+ input_text = f"{inputs['industry']} {inputs['stage']} {inputs['amount']} {inputs['lead_score']}"
20
+
21
+ # Tokenize the input
22
+ tokens = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=128)
23
+
24
+ # Run model inference
25
+ with torch.no_grad():
26
+ outputs = model(**tokens)
27
+
28
+ # Get prediction logits
29
+ logits = outputs.logits
30
+ score = torch.sigmoid(logits).item() # Convert logits to probability (0-1 range)
31
+
32
+ # Risk classification (for simplicity, let's map logits to Low, Medium, High)
33
+ risk_classes = ["Low", "Medium", "High"]
34
+ risk = risk_classes[torch.argmax(logits).item()]
35
+
36
+ # Dummy recommendation (customize this as needed)
37
+ recommendation = "Schedule another meeting before sending proposal."
38
+
39
+ return {
40
+ "score": round(score * 100, 2), # Scale the score to 0-100
41
+ "confidence": round(torch.max(torch.softmax(logits, dim=-1)).item(), 2),
42
+ "risk": risk,
43
+ "recommendation": recommendation
44
+ }
45
+
46
+ # Define an endpoint for deal qualification prediction
47
+ @app.route('/predict', methods=['POST'])
48
+ def predict():
49
+ # Get input JSON data from POST request
50
+ data = request.get_json()
51
+
52
+ # Validate input structure
53
+ if not all(key in data for key in ['industry', 'stage', 'amount', 'lead_score']):
54
+ return jsonify({"error": "Missing required input data"}), 400
55
+
56
+ # Predict using the pre-trained model
57
+ result = predict_deal_qualification(data)
58
+
59
+ # Return the prediction result as JSON
60
+ return jsonify(result)
61
+
62
+ # Run the app
63
+ if __name__ == '__main__':
64
+ app.run(debug=True, host="0.0.0.0", port=5000)