Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -10,6 +10,7 @@ import shutil
|
|
10 |
import requests
|
11 |
import pandas as pd
|
12 |
import difflib
|
|
|
13 |
|
14 |
# OCR Correction Model
|
15 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
@@ -162,22 +163,26 @@ def split_text(text, max_tokens=500):
|
|
162 |
|
163 |
|
164 |
# Function to generate text
|
165 |
-
|
166 |
-
def ocr_correction(prompt, max_new_tokens=500):
|
167 |
-
model.to(device)
|
168 |
-
|
169 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
170 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
171 |
|
|
|
|
|
|
|
172 |
# Generate text
|
173 |
-
|
|
|
|
|
|
|
174 |
max_new_tokens=max_new_tokens,
|
175 |
pad_token_id=tokenizer.eos_token_id,
|
176 |
top_k=50,
|
177 |
num_return_sequences=1,
|
178 |
-
do_sample=
|
179 |
-
temperature=0.7
|
180 |
)
|
|
|
|
|
181 |
# Decode and return the generated text
|
182 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
183 |
print(result)
|
|
|
10 |
import requests
|
11 |
import pandas as pd
|
12 |
import difflib
|
13 |
+
from concurrent.futures import ThreadPoolExecutor
|
14 |
|
15 |
# OCR Correction Model
|
16 |
ocr_model_name = "PleIAs/OCRonos-Vintage"
|
|
|
163 |
|
164 |
|
165 |
# Function to generate text
|
166 |
+
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
|
|
|
|
|
|
|
167 |
prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
|
168 |
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
169 |
|
170 |
+
# Set the number of threads for PyTorch
|
171 |
+
torch.set_num_threads(num_threads)
|
172 |
+
|
173 |
# Generate text
|
174 |
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
175 |
+
future = executor.submit(
|
176 |
+
model.generate,
|
177 |
+
input_ids,
|
178 |
max_new_tokens=max_new_tokens,
|
179 |
pad_token_id=tokenizer.eos_token_id,
|
180 |
top_k=50,
|
181 |
num_return_sequences=1,
|
182 |
+
do_sample=False
|
|
|
183 |
)
|
184 |
+
output = future.result()
|
185 |
+
|
186 |
# Decode and return the generated text
|
187 |
result = tokenizer.decode(output[0], skip_special_tokens=True)
|
188 |
print(result)
|