import gradio as gr
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from torch import Tensor
from typing import Iterable, List
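
# --- Preprocessing components (hedged sketch) --------------------------------
# translate() below relies on SRC_LANGUAGE, TGT_LANGUAGE, BOS_IDX, EOS_IDX,
# text_transform and vocab_transform. The setup here mirrors the torchtext
# Multi30k translation tutorial (spaCy tokenizers, vocab built from the
# training split). The special-token indices and tokenizer models are
# assumptions and must match whatever ./transformer_model.pth was trained
# with; in practice, saving the vocab alongside the checkpoint is more robust
# than rebuilding it from Multi30k at startup.

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Special symbols and their indices (assumed order, must match training).
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}

def yield_tokens(data_iter: Iterable, language: str) -> List[str]:
    # Multi30k yields (de, en) pairs; pick the side we are building a vocab for.
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}
    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

vocab_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)
    vocab_transform[ln].set_default_index(UNK_IDX)

def sequential_transforms(*transforms):
    # Chain tokenization -> numericalization -> tensor conversion.
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

def tensor_transform(token_ids: List[int]) -> Tensor:
    # Wrap a list of token ids with <bos>/<eos> markers and convert to a tensor.
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln],   # tokenize
                                               vocab_transform[ln],   # tokens -> ids
                                               tensor_transform)      # add BOS/EOS, to tensor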
# Define your model and any remaining components here; everything must match
# the setup used to train the checkpoint loaded below
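
# --- Model definition (hedged sketch) -----------------------------------------
# Seq2SeqTransformer and its hyperparameters are assumptions: this sketch
# follows the PyTorch "Language Translation with nn.Transformer and torchtext"
# tutorial. The layer counts and sizes below must match the architecture that
# produced ./transformer_model.pth, otherwise load_state_dict() will fail with
# shape mismatches.
import math
from torch import nn
from torch.nn import Transformer

class PositionalEncoding(nn.Module):
    # Adds fixed sinusoidal position information to token embeddings.
    def __init__(self, emb_size: int, dropout: float, maxlen: int = 5000):
        super().__init__()
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor) -> Tensor:
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

class TokenEmbedding(nn.Module):
    # Scaled embedding lookup, as in "Attention Is All You Need".
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers: int, num_decoder_layers: int,
                 emb_size: int, nhead: int, src_vocab_size: int,
                 tgt_vocab_size: int, dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size, nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor,
                src_padding_mask: Tensor, tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor) -> Tensor:
        # Training-time forward pass; kept so the state_dict keys line up.
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor) -> Tensor:
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

# Hyperparameters used by the constructor call below (assumed tutorial values;
# they must match the checkpoint).
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3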
# Load the trained transformer checkpoint
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                           NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('./transformer_model.pth', map_location=device))
model.to(device)
model.eval()
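
# --- Greedy decoding (hedged sketch) ------------------------------------------
# translate() calls greedy_decode(), which is sketched here in the tutorial
# style: encode the source once, then extend the target one token at a time
# until <eos> is produced or max_len is reached.
def generate_square_subsequent_mask(sz: int) -> Tensor:
    # Causal mask: position i may only attend to positions <= i.
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def greedy_decode(model: torch.nn.Module, src: Tensor, src_mask: Tensor,
                  max_len: int, start_symbol: int) -> Tensor:
    src = src.to(device)
    src_mask = src_mask.to(device)
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
    for _ in range(max_len - 1):
        memory = memory.to(device)
        tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(device)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys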
def translate(model: torch.nn.Module, src_sentence: str) -> str:
    # Greedily decode a single German sentence into English.
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(
        list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")
if __name__ == "__main__":
    # Create the Gradio interface; use gr.Textbox directly, since the
    # gr.inputs namespace is deprecated/removed in recent Gradio versions.
    iface = gr.Interface(
        fn=lambda sentence: translate(model, sentence),  # Gradio passes only the textbox value
        inputs=[
            gr.Textbox(label="Text")
        ],
        outputs=["text"],         # Return the translation as plain text
        cache_examples=False,     # Disable caching of examples
        title="germanToenglish",  # Title shown at the top of the interface
    )
    # Launch the interface
    iface.launch(share=True)