File size: 957 Bytes
84c4b50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import torch

class RAMPlusModel:
    """Wrapper that tags images with a Hugging Face image-classification checkpoint.

    The feature extractor and model are loaded once, at construction time, via
    ``from_pretrained``; :meth:`predict` returns the labels with the highest logits.
    """

    # Checkpoint used when no other model id/path is given.
    DEFAULT_MODEL_NAME = "xinyu1205/recognize-anything-plus-model"

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Load the feature extractor and model for *model_name*.

        Args:
            model_name: Hugging Face Hub id or local path of the checkpoint.
                Defaults to ``DEFAULT_MODEL_NAME``.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)
        # Inference only: put layers such as dropout into evaluation mode.
        self.model.eval()

    def predict(self, image, top_k=5):
        """Return up to *top_k* labels with the highest logits for *image*.

        Args:
            image: Input accepted by the feature extractor (e.g. a PIL image).
                If a batch is passed, only the first item's labels are returned.
            top_k: Maximum number of labels to return. It is clamped to the number
                of classes the model outputs, so small label sets do not raise.

        Returns:
            List of label strings ordered from highest to lowest logit.

        Raises:
            ValueError: If *top_k* is negative.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        inputs = self.feature_extractor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs.logits

        # NOTE(review): returns the top tags by raw logit. Whether this checkpoint
        # really loads as an AutoModelForImageClassification (and whether its head
        # is multi-label, where a sigmoid threshold may fit better) is unverified —
        # adjust to the model's actual output.
        k = min(top_k, logits.size(-1))
        top = torch.topk(logits, k=k)
        return [self.model.config.id2label[idx.item()] for idx in top.indices[0]]

# ๋ชจ๋ธ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ
model = RAMPlusModel()