skypro1111 commited on
Commit
b3e895e
·
verified ·
1 Parent(s): 3bf7f00

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime
3
+ import numpy as np
4
+ from transformers import AutoTokenizer
5
+ import time
6
+ import os
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ model_name = "skypro1111/mbart-large-50-verbalization"
10
+
11
+ # Example inputs for the dropdown
12
+ EXAMPLES = [
13
+ ["мій телефон 0979456822"],
14
+ ["квартира площею 11 тис кв м."],
15
+ ["Пропонували хабар у 1 млрд грн."],
16
+ ["1 2 3 4 5 6 7 8 9 10."],
17
+ ["Крім того, парламентарій володіє шістьма ділянками землі (дві площею 25000 кв м, дві по 15000 кв м та дві по 10000 кв м) розташованими в Сосновій Балці Луганської області."],
18
+ ["Підписуючи цей документ у 2003 році, голови Росії та України мали намір зміцнити співпрацю та сприяти розширенню двосторонніх відносин."],
19
+ ["Очікується, що цей застосунок буде запущено 22.08.2025."],
20
+ ["За інформацією від Державної служби з надзвичайних ситуацій станом на 7 ранку 15 липня."],
21
+ ]
22
+
23
+ def download_model_from_hf(repo_id=model_name, model_dir="./"):
24
+ """Download ONNX models from HuggingFace Hub."""
25
+ files = ["onnx/encoder_model.onnx", "onnx/decoder_model.onnx", "onnx/decoder_model.onnx_data"]
26
+
27
+ for file in files:
28
+ hf_hub_download(
29
+ repo_id=repo_id,
30
+ filename=file,
31
+ local_dir=model_dir,
32
+ )
33
+
34
+ return files
35
+
36
+ def create_onnx_session(model_path, use_gpu=True):
37
+ """Create an ONNX inference session."""
38
+ session_options = onnxruntime.SessionOptions()
39
+ session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
40
+ session_options.enable_mem_pattern = True
41
+ session_options.enable_mem_reuse = True
42
+ session_options.intra_op_num_threads = 8
43
+ session_options.log_severity_level = 1
44
+
45
+ cuda_provider_options = {
46
+ 'device_id': 0,
47
+ 'arena_extend_strategy': 'kSameAsRequested',
48
+ 'gpu_mem_limit': 0, # 0 means no limit
49
+ 'cudnn_conv_algo_search': 'DEFAULT',
50
+ 'do_copy_in_default_stream': True,
51
+ }
52
+
53
+ if use_gpu and 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
54
+ providers = [('CUDAExecutionProvider', cuda_provider_options)]
55
+ else:
56
+ providers = ['CPUExecutionProvider']
57
+
58
+ session = onnxruntime.InferenceSession(
59
+ model_path,
60
+ providers=providers,
61
+ sess_options=session_options
62
+ )
63
+
64
+ return session
65
+
66
+ def generate_text(text, tokenizer, encoder_session, decoder_session, max_length=128):
67
+ """Generate text for a single input."""
68
+ # Prepare input
69
+ inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=512)
70
+ input_ids = inputs["input_ids"].astype(np.int64)
71
+ attention_mask = inputs["attention_mask"].astype(np.int64)
72
+
73
+ # Run encoder
74
+ encoder_outputs = encoder_session.run(
75
+ output_names=["last_hidden_state"],
76
+ input_feed={
77
+ "input_ids": input_ids,
78
+ "attention_mask": attention_mask,
79
+ }
80
+ )[0]
81
+
82
+ # Initialize decoder input
83
+ decoder_input_ids = np.array([[tokenizer.pad_token_id]], dtype=np.int64)
84
+
85
+ # Generate sequence
86
+ for _ in range(max_length):
87
+ # Run decoder
88
+ decoder_outputs = decoder_session.run(
89
+ output_names=["logits"],
90
+ input_feed={
91
+ "input_ids": decoder_input_ids,
92
+ "encoder_hidden_states": encoder_outputs,
93
+ "encoder_attention_mask": attention_mask,
94
+ }
95
+ )[0]
96
+
97
+ # Get next token
98
+ next_token = decoder_outputs[:, -1:].argmax(axis=-1)
99
+ decoder_input_ids = np.concatenate([decoder_input_ids, next_token], axis=-1)
100
+
101
+ # Check if sequence is complete
102
+ if tokenizer.eos_token_id in decoder_input_ids[0]:
103
+ break
104
+
105
+ # Decode sequence
106
+ output_text = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
107
+ return output_text
108
+
109
+ # Initialize models and tokenizer globally
110
+ print("Downloading models...")
111
+ files = download_model_from_hf()
112
+
113
+ print("Loading tokenizer...")
114
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
115
+ tokenizer.src_lang = "uk_UA"
116
+ tokenizer.tgt_lang = "uk_UA"
117
+
118
+ print("Creating ONNX sessions...")
119
+ encoder_session = create_onnx_session("onnx/encoder_model.onnx")
120
+ decoder_session = create_onnx_session("onnx/decoder_model.onnx")
121
+
122
+ def inference(text):
123
+ """Gradio inference function"""
124
+ start_time = time.time()
125
+
126
+ # Generate text
127
+ output = generate_text(text, tokenizer, encoder_session, decoder_session)
128
+
129
+ # Calculate inference time
130
+ inference_time = time.time() - start_time
131
+
132
+ return output, f"{inference_time:.2f} seconds"
133
+
134
+ # Create Gradio interface
135
+ with gr.Blocks(title="Numbers to Words ONNX Inference") as demo:
136
+ gr.Markdown("# Numbers to Words ONNX Inference")
137
+ gr.Markdown("Convert numbers in Ukrainian text to words using ONNX optimized model")
138
+
139
+ with gr.Row():
140
+ with gr.Column():
141
+ input_text = gr.Textbox(
142
+ label="Input Text",
143
+ placeholder="Enter Ukrainian text with numbers...",
144
+ lines=3
145
+ )
146
+ inference_button = gr.Button("Run Inference", variant="primary")
147
+
148
+ with gr.Column():
149
+ output_text = gr.Textbox(
150
+ label="Output Text",
151
+ lines=3,
152
+ interactive=False
153
+ )
154
+ inference_time = gr.Textbox(
155
+ label="Inference Time",
156
+ interactive=False
157
+ )
158
+
159
+ # Add examples
160
+ gr.Examples(
161
+ examples=EXAMPLES,
162
+ inputs=input_text,
163
+ label="Example Inputs"
164
+ )
165
+
166
+ # Set up inference button click event
167
+ inference_button.click(
168
+ fn=inference,
169
+ inputs=input_text,
170
+ outputs=[output_text, inference_time]
171
+ )
172
+
173
+ if __name__ == "__main__":
174
+ demo.launch(share=True)