---
library_name: transformers
base_model:
- google/vit-base-patch16-224
---
# Model Card for Pokémon Type Classification
This model leverages a Vision Transformer (ViT) to classify Pokémon images into 18 different types.
It was developed as part of the CS 310 Final Project and trained on a Pokémon image dataset.
## Model Details
- **Developer:** Xianglu (Steven) Zhu
- **Purpose:** Pokémon type classification
- **Model Type:** Vision Transformer (ViT) for image classification
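The training recipe is not included in this card, but a classifier of this shape is typically produced by swapping the base checkpoint's ImageNet head for an 18-way classification head. The following is a minimal, hypothetical sketch of that setup, not the documented training script; the label list matches the mapping used in the inference example below, and the exact arguments are an assumption:

```python
from transformers import ViTForImageClassification

# Hypothetical fine-tuning setup (assumed, not taken from the original training code):
# start from the base checkpoint and attach a fresh 18-way classification head.
pokemon_types = [
    'Grass', 'Fire', 'Water', 'Bug', 'Normal', 'Poison', 'Electric', 'Ground', 'Fairy',
    'Fighting', 'Psychic', 'Rock', 'Ghost', 'Ice', 'Dragon', 'Dark', 'Steel', 'Flying'
]
id2label = {i: t for i, t in enumerate(pokemon_types)}
label2id = {t: i for i, t in id2label.items()}

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(pokemon_types),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,  # discard the original 1000-class ImageNet head
)
```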
## Getting Started
Here’s how you can use the model for classification:
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
# Load the pretrained model and feature extractor
hf_model = ViTForImageClassification.from_pretrained("NP-NP/pokemon_model")
hf_feature_extractor = ViTFeatureExtractor.from_pretrained("NP-NP/pokemon_model")
# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=hf_feature_extractor.image_mean, std=hf_feature_extractor.image_std)
])
# Mapping of labels to indices and vice versa
labels_dict = {
    'Grass': 0, 'Fire': 1, 'Water': 2, 'Bug': 3, 'Normal': 4, 'Poison': 5, 'Electric': 6,
    'Ground': 7, 'Fairy': 8, 'Fighting': 9, 'Psychic': 10, 'Rock': 11, 'Ghost': 12,
    'Ice': 13, 'Dragon': 14, 'Dark': 15, 'Steel': 16, 'Flying': 17
}
idx_to_label = {v: k for k, v in labels_dict.items()}
# Load and preprocess the image
image_path = "cute-pikachu-flowers-pokemon-desktop-wallpaper.jpg"
image = Image.open(image_path).convert("RGB")
input_tensor = transform(image).unsqueeze(0) # shape: (1, 3, 224, 224)
# Make a prediction
hf_model.eval()
with torch.no_grad():
    outputs = hf_model(input_tensor)
    logits = outputs.logits
predicted_class_idx = torch.argmax(logits, dim=1).item()
predicted_class = idx_to_label[predicted_class_idx]
print("Predicted Pokémon type:", predicted_class)
```
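If you want ranked predictions with confidences rather than a single argmax label, you can softmax the logits from the snippet above. This short sketch reuses the `logits` and `idx_to_label` variables defined there:

```python
# Convert logits to probabilities and report the top 3 candidate types.
probs = torch.softmax(logits, dim=1)[0]
top_probs, top_idx = torch.topk(probs, k=3)
for p, i in zip(top_probs.tolist(), top_idx.tolist()):
    print(f"{idx_to_label[i]}: {p:.3f}")
```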
|