File size: 1,645 Bytes
bc3ec38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff8f460
bc3ec38
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Width of the single hidden layer.
        num_classes: Number of output scores produced per input.
    """

    def __init__(self, input_size=24, hidden_size=256, num_classes=5):
        super().__init__()
        # The attribute names become the state_dict key prefixes
        # ("l1.weight", "l2.bias", ...), so they must not be renamed.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return the raw class scores for ``x`` (no softmax is applied)."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)
    
class KeypointClassification:
    """Classify a keypoint feature vector into one of the labels in ``classes``.

    Wraps a ``NeuralNet`` whose weights are loaded from ``path_model`` and maps
    the index of the highest-scoring network output to a label name.
    """

    def __init__(self, path_model):
        """Store the configuration, pick a device and load the model.

        Args:
            path_model: Path to the saved ``state_dict`` of the network.
        """
        self.path_model = path_model
        # Position in this list is the class index produced by the network.
        # NOTE(review): presumably this order matches the label encoding used
        # at training time -- confirm against the training code.
        self.classes = ['Downdog', 'Goddess', 'Plank', 'Tree', 'Warrior2']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Build the network, load its weights and prepare it for inference."""
        self.model = NeuralNet()
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (or pass weights_only=True on
        # PyTorch versions that support it).
        state_dict = torch.load(self.path_model, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # load_state_dict copies values into the existing (CPU) parameters, so
        # the module itself still has to be moved to the selected device.
        self.model.to(self.device)
        # Inference only: switch off training-time behaviour of any layers.
        self.model.eval()

    def __call__(self, input_keypoint):
        """Predict the label(s) for ``input_keypoint``.

        Args:
            input_keypoint: A ``torch.Tensor`` or anything ``torch.tensor`` can
                convert to float32 (list, NumPy array, ...). The last dimension
                must match the network's input size; leading dimensions are
                treated as a batch.

        Returns:
            The predicted label as a ``str`` when there is exactly one
            prediction, otherwise a flat list of labels, one per input row.
        """
        if not isinstance(input_keypoint, torch.Tensor):
            input_keypoint = torch.tensor(
                input_keypoint, dtype=torch.float32
            )
        # Keep the input on the same device as the model weights.
        input_keypoint = input_keypoint.to(self.device)
        # No gradients are needed for prediction; skip building the graph.
        with torch.no_grad():
            out = self.model(input_keypoint)
        _, predict = torch.max(out, -1)
        if predict.numel() == 1:
            return self.classes[predict.item()]
        # Batched input: a list cannot be indexed by a multi-element tensor,
        # so translate each predicted index individually.
        return [self.classes[index] for index in predict.flatten().tolist()]

if __name__ == '__main__':
    # Smoke test: load the classifier and print its prediction for one random
    # input vector. An alternative checkpoint path may be given as the first
    # command-line argument; otherwise the original default path is used.
    import sys

    default_path_model = '/Users/nishantkaushik20/Me/source-code/AI/PoseEstimationYOLOv8/models/pose_classification.pt'
    path_model = sys.argv[1] if len(sys.argv) > 1 else default_path_model
    keypoint_classification = KeypointClassification(path_model=path_model)
    # The dummy vector must have as many features as the network's first
    # layer expects (24 by default); a length of 23 fails with a shape error.
    in_features = keypoint_classification.model.l1.in_features
    dummy_input = torch.randn(in_features)
    classification = keypoint_classification(dummy_input)
    print(classification)