# app.py by yhamidullah — commit 5bf9e80 "update app.py" (verified), 833 bytes.
import torch
import gradio as gr
from transformers import AutoConfig
from models import CustomClassifier, CustomClassificationConfig
# Repository id handed to both from_pretrained() calls below.
MODEL_ID = "yhamidullah/custom-classifier-demo"

# Loaded on its own because predict() reads config.input_dim to validate
# the length of every incoming vector.
config = CustomClassificationConfig.from_pretrained(MODEL_ID)

# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained into a single assignment.
model = CustomClassifier.from_pretrained(MODEL_ID).eval()
def predict(input_csv: str):
    """Classify a single comma-separated feature vector with the loaded model.

    Args:
        input_csv: Text from the Gradio textbox, e.g. ``"0.5, -1, 3.25"``.
            Every comma-separated field must be parseable by ``float()``
            (surrounding whitespace is allowed).

    Returns:
        ``"Predicted class: <index>"`` on success. Otherwise an
        ``"Error: Need <input_dim> floats"`` string when the text has a
        non-numeric or empty field, or the wrong number of values. Errors are
        returned, not raised, so the message appears in the text output.
    """
    # Treat a missing value like an empty box; "".split(",") yields [""],
    # which then fails float() and is reported below.
    fields = (input_csv or "").split(",")
    try:
        vec = [float(field) for field in fields]
    except ValueError:
        # Empty input, trailing commas and non-numeric tokens used to escape
        # as an uncaught ValueError; report them like a length mismatch.
        return f"Error: Need {config.input_dim} floats"
    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"
    # Wrap in a list to add a batch dimension: shape (1, input_dim).
    x = torch.tensor([vec])
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=x)["logits"]
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"
# Web UI: one free-text box in, plain text out. gr.Interface takes
# (fn, inputs, outputs) as its leading positional parameters.
demo = gr.Interface(
    predict,
    gr.Textbox(label="Input Vector (comma-separated)"),
    "text",
    title="Custom Classifier Demo",
)

# Start the Gradio server for the interface defined above.
demo.launch()