import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
import gc
import os
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download, HfFolder
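
# Resolve an HF access token from the local credential store or the HF_TOKEN
# environment variable; it is passed to the tokenizer and model downloads below.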
token = HfFolder.get_token() or os.getenv("HF_TOKEN")
HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
ONNX_MODEL_FILE = "model.onnx"

# Shared tokenizer (used for both the PyTorch and ONNX paths)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, token=token)
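
# Minimal greedy decoding loop for the ONNX export. Each step re-runs the full
# generated sequence through the graph (no KV cache is passed), takes the argmax
# of the last position's logits, and stops at EOS or after max_new_tokens.
# Assumes batch size 1, as produced by the single-prompt tokenizer call below.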
def greedy_decode_onnx(session, input_ids, attention_mask, max_new_tokens=50):
    generated = input_ids.copy()
    for _ in range(max_new_tokens):
        outputs = session.run(None, {
            "input_ids": generated,
            "attention_mask": attention_mask
        })
        next_token_logits = outputs[0][:, -1, :]
        next_token = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)
        generated = np.concatenate((generated, next_token), axis=1)
        attention_mask = np.concatenate(
            (attention_mask, np.ones((1, 1), dtype=np.int64)), axis=1)
        if next_token[0][0] == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated[0], skip_special_tokens=True)
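
# Runs both backends sequentially on the same prompt: the PyTorch model is
# loaded, generates, and is freed before the ONNX session is created, so only
# one copy of the fp32 weights is resident at a time.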
def compare_outputs(prompt):
    summary_log = []

    # 🔹 PyTorch Generate
    pt_output_text = ""
    pt_model = None
    pt_start = time.time()
    try:
        torch_inputs = tokenizer(prompt, return_tensors="pt")
        pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32, token=token)
        pt_model.eval()
        with torch.no_grad():
            pt_outputs = pt_model.generate(**torch_inputs, max_new_tokens=50)
        pt_output_text = tokenizer.decode(pt_outputs[0], skip_special_tokens=True)
        pt_time = time.time() - pt_start
        summary_log.append(f"🧠 PyTorch output length: {pt_outputs.shape[1]} tokens | Time: {pt_time:.2f}s")
    finally:
        # Guard the cleanup so a failed from_pretrained() call doesn't raise NameError here.
        if pt_model is not None:
            del pt_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # 🔹 ONNX Generate (Greedy)
    ort_start = time.time()
    ort_inputs = tokenizer(prompt, return_tensors="np")
    # hf_hub_download caches the file, so only the first request pays the download
    # cost, but that cost is still included in ort_time below.
    onnx_path = hf_hub_download(repo_id=HF_ONNX_REPO, filename=ONNX_MODEL_FILE)
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_output_text = greedy_decode_onnx(
        ort_session, ort_inputs["input_ids"], ort_inputs["attention_mask"], max_new_tokens=50
    )
    ort_time = time.time() - ort_start
    summary_log.append(f"⚙️ ONNX output length: {len(tokenizer(ort_output_text)['input_ids'])} tokens | Time: {ort_time:.2f}s")

    # Final notes
    summary_log.append(f"🧪 Tokenizer source: {tokenizer.name_or_path} | Vocab size: {tokenizer.vocab_size}")
    summary_log.append("💡 Note: Future versions will include quantized ONNX (INT8) + Vitis AI support.")

    return pt_output_text, ort_output_text, "\n".join(summary_log)

example_prompts = [
    "Who was the first president of the United States?",
    "If you have 3 apples and eat 1, how many are left?",
    "Write a short poem about memory and time.",
    "Explain the laws of motion in simple terms.",
    "What happens when you mix baking soda and vinegar?"
]

iface = gr.Interface(
    fn=compare_outputs,
    inputs=gr.Textbox(lines=2, placeholder="Enter a prompt..."),
    outputs=[
        gr.Textbox(label="PyTorch Output"),
        gr.Textbox(label="ONNX Output"),
        gr.Textbox(label="Test Summary & Metadata")
    ],
    title="ONNX vs PyTorch (Full Output Comparison)",
    description="Sequentially runs both models on your prompt and returns decoded output + metadata.",
    examples=[[p] for p in example_prompts]
)

iface.launch()