andgrt commited on
Commit
3c9287e
·
1 Parent(s): 9c7b30e

fix preprocess_image

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import (
6
  AutoModelForDocumentQuestionAnswering,
7
  )
8
  from transformers import pipeline
9
- import torch
10
 
11
  tokenizer_ru2en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
12
  model_ru2en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
@@ -16,19 +16,13 @@ model_en2ru = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru"
16
 
17
  git_processor_base = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
18
 
19
- git_model_base = AutoModelForDocumentQuestionAnswering.from_pretrained(
20
- "andgrt/layoutlmv2-base-uncased_finetuned_docvqa"
21
- )
22
-
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- git_model_base.to(device)
25
  image_processor = git_processor_base.image_processor
26
 
27
 
28
  def preprocess_image(image):
29
  """Преобразуем изображение для модели"""
30
  image_rgb = image.convert("RGB")
31
- return image_processor([image_rgb], return_tensors="pt").pixel_values.to(device)
32
 
33
 
34
  def translate_ru2en(text):
@@ -51,7 +45,7 @@ def generate_answer_git(image, question):
51
  "document-question-answering",
52
  model="andgrt/layoutlmv2-base-uncased_finetuned_docvqa",
53
  )
54
- return qa_pipeline(image, question)[0]["answer"]
55
  # pixel_values, _, _ = preprocess_image(image)
56
  # input_ids = processor(text=question, add_special_tokens=False).input_ids
57
  # input_ids = [processor.tokenizer.cls_token_id] + input_ids
 
6
  AutoModelForDocumentQuestionAnswering,
7
  )
8
  from transformers import pipeline
9
+
10
 
11
  tokenizer_ru2en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
12
  model_ru2en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
 
16
 
17
  git_processor_base = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
18
 
 
 
 
 
 
 
19
  image_processor = git_processor_base.image_processor
20
 
21
 
22
  def preprocess_image(image):
23
  """Преобразуем изображение для модели"""
24
  image_rgb = image.convert("RGB")
25
+ return image_processor([image_rgb]).pixel_values
26
 
27
 
28
  def translate_ru2en(text):
 
45
  "document-question-answering",
46
  model="andgrt/layoutlmv2-base-uncased_finetuned_docvqa",
47
  )
48
+ return qa_pipeline(preprocess_image(image), question)[0]["answer"]
49
  # pixel_values, _, _ = preprocess_image(image)
50
  # input_ids = processor(text=question, add_special_tokens=False).input_ids
51
  # input_ids = [processor.tokenizer.cls_token_id] + input_ids