---
library_name: transformers
base_model:
  - google/vit-base-patch16-224
---

# Model Card for Pokémon Type Classification

This model is a Vision Transformer (ViT), fine-tuned from `google/vit-base-patch16-224`, that classifies Pokémon images into one of 18 types.

It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.

## Model Details

- **Developer:** Xianglu (Steven) Zhu  
- **Purpose:** Pokémon type classification  
- **Model Type:** Vision Transformer (ViT) for image classification (see the quick checkpoint check below)
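
The easiest way to verify those details against the hosted checkpoint is to pull its configuration. This is a minimal sketch, assuming the config on the Hub records all 18 type labels; it downloads only the config file, not the weights.

```python
from transformers import AutoConfig

# Fetch just the configuration for the hosted checkpoint
config = AutoConfig.from_pretrained("NP-NP/pokemon_model")

print(config.model_type)   # expected: "vit"
print(config.num_labels)   # expected: 18, one per Pokémon type
```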

## Getting Started

Here’s how you can use the model for classification:

```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTImageProcessor

# Load the pretrained model and its image processor
# (ViTImageProcessor supersedes the deprecated ViTFeatureExtractor)
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_processor = ViTImageProcessor.from_pretrained("NP-NP/pokemon_model")

# Preprocessing that mirrors the processor's normalization statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_processor.image_mean, std=hf_processor.image_std),
])

# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}

# Load and preprocess an example image (swap in your own file path)
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 224, 224)

# Make a prediction
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
    predicted_class_idx = torch.argmax(logits, dim=1).item()

predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
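
If you want more than the single top prediction, the logits can be turned into a probability distribution and read off top-k. This continues the snippet above (reusing `logits` and `idx_to_label`):

```python
import torch.nn.functional as F

# Softmax over the 18 type logits gives a probability per type
probs = F.softmax(logits, dim=1).squeeze(0)
top_probs, top_idxs = torch.topk(probs, k=3)

for prob, idx in zip(top_probs.tolist(), top_idxs.tolist()):
    print(f"{idx_to_label[idx]}: {prob:.3f}")
```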