QuefrencyGuardian / examples /example_usage_fastmodel_hf.py
tlemagueresse
Update read me and delete notebooks
eaba0c2
raw
history blame
1.02 kB
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from sklearn.metrics import accuracy_score
import importlib.util
repo_id = "tlmk22/QuefrencyGuardian"
model_file = "model.py"

# Download the model definition (model.py) from the Hub repo and import it as
# a module at runtime.
# NOTE(security): `exec_module` executes arbitrary Python fetched from the
# remote repository. Outside a trusted context, pin `revision=` in
# `hf_hub_download` to a reviewed commit.
model_path = hf_hub_download(repo_id=repo_id, filename=model_file)
spec = importlib.util.spec_from_file_location("model", model_path)
# spec_from_file_location may return None, or a spec whose loader is None.
# Fail with a clear message instead of an AttributeError on `None`.
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot load module from {model_path!r}")
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

FastModelHuggingFace = model_module.FastModelHuggingFace
# Presumably loads the pretrained config/weights from the same Hub repo
# (behavior defined in the downloaded model.py; confirm there).
fast_model = FastModelHuggingFace.from_pretrained(repo_id)
# Example: classify one local WAV file.
# The model outputs integer class indices; translate them to readable names.
map_labels = {
    0: "chainsaw",
    1: "environment",
}
wav_prediction = fast_model.predict("chainsaw.wav", device="cpu")
# predict() appears to return one index per input; a single file -> take [0].
predicted_name = map_labels[wav_prediction[0]]
print(f"Prediction : {predicted_name}")
# Example: evaluate the model on the test split of a Hugging Face dataset.
# NOTE: without `split=`, load_dataset returns every available split (a
# DatasetDict). Pass split="test" to fetch only the test split if the others
# are not needed.
dataset = load_dataset("rfcx/frugalai")
test_dataset = dataset["test"]
true_label = test_dataset["label"]
# NOTE(review): unlike the single-file example above, no `device` is passed
# here, so predict() uses its own default device. Confirm this is intended.
predictions = fast_model.predict(test_dataset)
# Fraction of test examples whose predicted label matches the ground truth.
print(accuracy_score(true_label, predictions))