JulianPhillips commited on
Commit
1ea0f2c
·
verified ·
1 Parent(s): 65c9437

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ import torch
3
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
4
+
5
+ # Load Meta Sapiens Pose model
6
+ sapiens_model = torch.jit.load('/models/sapiens_pose/model.pt')
7
+ sapiens_model.eval()
8
+
9
+ # Load MotionBERT model
10
+ motionbert_model = AutoModelForSequenceClassification.from_pretrained('/models/motionbert')
11
+ motionbert_tokenizer = AutoTokenizer.from_pretrained('/models/motionbert')
12
+
13
+ app = Flask(__name__)
14
+
15
+ @app.route('/pose_estimation', methods=['POST'])
16
+ def pose_estimation():
17
+ # Accept an image file as input for pose estimation
18
+ image = request.files['image'].read()
19
+ # Perform pose estimation
20
+ with torch.no_grad():
21
+ pose_result = sapiens_model(torch.tensor(image))
22
+ return jsonify({"pose_result": pose_result.tolist()})
23
+
24
+ @app.route('/sequence_analysis', methods=['POST'])
25
+ def sequence_analysis():
26
+ # Accept keypoint data as input for sequence analysis
27
+ keypoints = request.json['keypoints']
28
+ inputs = motionbert_tokenizer(keypoints, return_tensors="pt")
29
+ with torch.no_grad():
30
+ sequence_output = motionbert_model(**inputs)
31
+ return jsonify({"sequence_analysis": sequence_output.logits.tolist()})
32
+
33
+ if __name__ == '__main__':
34
+ app.run(host='0.0.0.0', port=7860)