File size: 1,196 Bytes
7078ca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327190d
 
7078ca9
a063777
7078ca9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gradio as gr
from transformers import AutoModelForImageClassification, AutoProcessor
import torch

# Hugging Face Hub repository id; the classifier and its preprocessor are
# both loaded from this single repository.
model_name = "DeathDaDev/Materializer"

# Image-classification network used for inference in classify_image.
model = AutoModelForImageClassification.from_pretrained(model_name)

# Preprocessor from the same repository, so uploaded images are prepared
# with the configuration stored alongside the model.
processor = AutoProcessor.from_pretrained(model_name)

# Define the prediction function
def classify_image(image):
    """Classify *image* with the Materializer model and return its top label.

    Args:
        image: The image to classify. The Gradio input below is configured
            with ``type="pil"``, so this is normally a PIL image. ``None`` is
            also accepted (Gradio may pass it when nothing was uploaded).

    Returns:
        The name of the highest-scoring class, looked up in
        ``model.config.id2label``, or ``None`` when *image* is ``None``.
    """
    # Nothing to classify: return early instead of letting the processor
    # raise on a missing image.
    if image is None:
        return None

    # Texture uploads (e.g. PNGs) can be RGBA, palette or grayscale. Convert
    # PIL images to 3-channel RGB before preprocessing. Assumes the model
    # expects RGB input (TODO confirm against the processor config).
    # Non-PIL inputs have no ``mode`` attribute and are passed through as-is.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # Preprocess the image into PyTorch tensors for the model.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the highest-scoring class. ``.item()`` requires a single element,
    # i.e. one image in, one prediction out.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]

# Build the web UI: an uploaded image (delivered to classify_image as a PIL
# image) goes in, and a label component displays the result.
# NOTE(review): classify_image returns a single label string, so
# num_top_classes=3 presumably has no visible effect (it applies to
# {label: confidence} mappings). Confirm whether a top-3 display was intended.
image_input = gr.Image(type="pil")
label_output = gr.Label(num_top_classes=3)

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title="Image Classification with Materializer",
    description=(
        "This model has been trained on texture images that are commonly used "
        "for 3d models in an attempt to create an AI model that understands what "
        "image 'material' should be used on a specific object. Upload an image "
        "to classify it using the Materializer model."
    ),
)

# Start serving the app. This runs at module level, so it also happens
# whenever the file is executed or imported.
iface.launch()