Spaces:
Sleeping
Sleeping
File size: 833 Bytes
f51abca 5bf9e80 f51abca 5a5b916 5bf9e80 f51abca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
import gradio as gr
from transformers import AutoConfig
from models import CustomClassifier, CustomClassificationConfig
# Repository id the config and weights are loaded from (presumably a Hugging
# Face Hub model repo — confirm).
MODEL_ID = "yhamidullah/custom-classifier-demo"
# Loaded once at import time and shared by every predict() call; predict()
# reads config.input_dim to validate the length of the user's vector.
config = CustomClassificationConfig.from_pretrained(MODEL_ID)
model = CustomClassifier.from_pretrained(MODEL_ID)
# Put the model in evaluation mode for inference (assumes CustomClassifier is
# a torch nn.Module — TODO confirm in models.py).
model.eval()
def predict(input_csv: str) -> str:
    """Classify a comma-separated vector of floats and report the predicted class.

    Args:
        input_csv: Text from the Gradio textbox, e.g. ``"0.1, 2, -3.5"``.
            Whitespace around each value is ignored.

    Returns:
        ``"Predicted class: <index>"`` on success. Otherwise an ``"Error: ..."``
        message when the input is empty, contains a non-numeric entry, or does
        not have exactly ``config.input_dim`` values.
    """
    # Reject blank input up front; float("") would otherwise raise.
    if not (input_csv or "").strip():
        return "Error: Input is empty"

    # Parse token by token so malformed input produces a friendly message
    # naming the offending value, instead of an uncaught ValueError.
    vec = []
    for token in input_csv.split(","):
        try:
            vec.append(float(token))
        except ValueError:
            return f"Error: {token.strip()!r} is not a number"

    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"

    # Wrap in an outer list to form a batch of one: shape (1, input_dim).
    x = torch.tensor([vec])
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=x)["logits"]
    # Index of the highest-scoring class along the last axis; .item()
    # assumes the batch of one yields a single element — TODO confirm
    # logits shape is (1, num_classes).
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"
# Wire predict() into a minimal Gradio UI: one text box in, plain text out.
_interface_options = {
    "fn": predict,
    "inputs": gr.Textbox(label="Input Vector (comma-separated)"),
    "outputs": "text",
    "title": "Custom Classifier Demo",
}
demo = gr.Interface(**_interface_options)

# Start serving the app.
demo.launch()
|