Acetde commited on
Commit
9269f29
·
verified ·
1 Parent(s): 691e861

add FastAPI

Browse files
Files changed (1) hide show
  1. app.py +38 -44
app.py CHANGED
@@ -4,71 +4,65 @@ import onnxruntime as rt
4
  from torchvision import transforms as T
5
  from PIL import Image
6
  from tokenizer_base import Tokenizer
7
- import pathlib
8
- import os
9
- import gradio as gr
10
- from huggingface_hub import Repository
11
-
12
-
13
 
 
14
  model_file = "captcha.onnx"
15
  img_size = (32,128)
16
  charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
17
  tokenizer_base = Tokenizer(charset)
18
 
19
  def get_transform(img_size):
20
- transforms = []
21
- transforms.extend([
22
- T.Resize(img_size, T.InterpolationMode.BICUBIC),
23
- T.ToTensor(),
24
- T.Normalize(0.5, 0.5)
25
- ])
26
- return T.Compose(transforms)
27
 
28
  def to_numpy(tensor):
29
  return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
30
 
31
  def initialize_model(model_file):
32
  transform = get_transform(img_size)
33
- # Onnx model loading
34
  onnx_model = onnx.load(model_file)
35
  onnx.checker.check_model(onnx_model)
36
  ort_session = rt.InferenceSession(model_file)
37
- return transform,ort_session
 
 
 
 
 
38
 
 
39
  def get_text(img_org):
40
- # img_org = Image.open(image_path)
41
- # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
42
  x = transform(img_org.convert('RGB')).unsqueeze(0)
43
-
44
- # compute ONNX Runtime output prediction
45
  ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
46
  logits = ort_session.run(None, ort_inputs)[0]
47
  probs = torch.tensor(logits).softmax(-1)
48
- preds, probs = tokenizer_base.decode(probs)
49
- preds = preds[0]
50
- print(preds)
51
- return preds
52
-
53
- transform,ort_session = initialize_model(model_file=model_file)
54
-
55
- # Создание интерфейса
56
- with gr.Blocks() as demo:
57
- image_input = gr.Image(type="pil")
58
- text_output = gr.Textbox()
59
-
60
- # Кнопка для обработки изображения
61
- submit_button = gr.Button("Распознать текст")
62
-
63
- # Связываем функцию с кнопкой
64
- submit_button.click(fn=get_text, inputs=image_input, outputs=text_output)
65
-
66
- # Запуск с включенной поддержкой очереди
67
- demo.queue().launch()
68
 
69
- # if __name__ == "__main__":
70
- # image_path = "8000.png"
71
- # preds,probs = get_text(image_path)
72
- # print(preds[0])
73
-
 
 
 
 
 
 
 
 
 
 
74
 
 
 
 
4
  from torchvision import transforms as T
5
  from PIL import Image
6
  from tokenizer_base import Tokenizer
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from io import BytesIO
9
+ from fastapi.responses import JSONResponse
 
 
 
10
 
11
+ # Инициализация модели
12
  model_file = "captcha.onnx"
13
  img_size = (32,128)
14
  charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
15
  tokenizer_base = Tokenizer(charset)
16
 
17
  def get_transform(img_size):
18
+ transforms = []
19
+ transforms.extend([
20
+ T.Resize(img_size, T.InterpolationMode.BICUBIC),
21
+ T.ToTensor(),
22
+ T.Normalize(0.5, 0.5)
23
+ ])
24
+ return T.Compose(transforms)
25
 
26
  def to_numpy(tensor):
27
  return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
28
 
29
  def initialize_model(model_file):
30
  transform = get_transform(img_size)
31
+ # Загрузка модели ONNX
32
  onnx_model = onnx.load(model_file)
33
  onnx.checker.check_model(onnx_model)
34
  ort_session = rt.InferenceSession(model_file)
35
+ return transform, ort_session
36
+
37
+ transform, ort_session = initialize_model(model_file=model_file)
38
+
39
+ # Создаем FastAPI приложение
40
+ app = FastAPI()
41
 
42
+ # Функция для получения текста
43
  def get_text(img_org):
 
 
44
  x = transform(img_org.convert('RGB')).unsqueeze(0)
 
 
45
  ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
46
  logits = ort_session.run(None, ort_inputs)[0]
47
  probs = torch.tensor(logits).softmax(-1)
48
+ preds, _ = tokenizer_base.decode(probs)
49
+ return preds[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # Маршрут для обработки POST-запросов с изображениями
52
+ @app.post("/predict")
53
+ async def predict(file: UploadFile = File(...)):
54
+ try:
55
+ # Получаем изображение из запроса
56
+ image_bytes = await file.read()
57
+ img = Image.open(BytesIO(image_bytes))
58
+
59
+ # Получаем текст с изображения
60
+ result = get_text(img)
61
+
62
+ # Возвращаем распознанный текст
63
+ return JSONResponse(content={"text": result})
64
+ except Exception as e:
65
+ return JSONResponse(status_code=500, content={"message": str(e)})
66
 
67
+ # Для запуска FastAPI приложения
68
+ # uvicorn main:app --reload