# Hugging Face Spaces page status: Sleeping
import gradio as gr
import torch
from transformers import AutoConfig

from models import CustomClassifier, CustomClassificationConfig
# Hub repository that hosts both the classifier's configuration and weights.
MODEL_ID = "yhamidullah/custom-classifier-demo"

# Both artifacts are loaded once, when this module is imported, and are then
# read by the request handler. The tuple is evaluated left to right, so the
# config is fetched before the weights.
config, model = (
    CustomClassificationConfig.from_pretrained(MODEL_ID),
    CustomClassifier.from_pretrained(MODEL_ID),
)

# Switch the model to evaluation mode before it serves any request.
model.eval()
def predict(input_csv: str) -> str:
    """Classify one comma-separated feature vector with the loaded model.

    Args:
        input_csv: Text from the Gradio textbox, e.g. ``"0.1, 2, -3.5"``.
            Each comma-separated field is parsed with ``float()``, which
            tolerates surrounding whitespace. ``None`` is treated as empty.

    Returns:
        ``"Predicted class: <index>"`` on success, where ``<index>`` is the
        argmax of the model's logits over the last dimension. On bad input
        it returns a human-readable ``"Error: ..."`` string instead of
        raising, so the UI shows the problem rather than a stack trace.
    """
    text = (input_csv or "").strip()
    if not text:
        return "Error: Input is empty; enter comma-separated numbers"

    # Parse before touching the model so malformed input (letters, empty
    # fields such as "1,,2" or a trailing comma) yields a clear message.
    try:
        vec = [float(x) for x in text.split(",")]
    except ValueError:
        return (
            "Error: Could not parse input; every comma-separated value "
            "must be a number"
        )

    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"

    # Wrapping in a list adds a batch dimension: shape (1, input_dim).
    x = torch.tensor([vec])
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(input_ids=x)["logits"]
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"
# Single-textbox UI: the user types one comma-separated vector and the
# string returned by `predict` is displayed as plain text.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Input Vector (comma-separated)"),
    outputs="text",
    title="Custom Classifier Demo",
    # Tell users up front how many values the loaded model expects.
    description=f"Enter {config.input_dim} comma-separated numbers.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. from tests or another script) does not launch it.
if __name__ == "__main__":
    demo.launch()