andgrt committed on
Commit
36e06f1
·
1 Parent(s): 246c8f9

upd: generate_answer

Browse files
Files changed (1) hide show
  1. app.py +23 -13
app.py CHANGED
@@ -5,8 +5,14 @@ from transformers import (
5
  AutoProcessor,
6
  AutoModelForDocumentQuestionAnswering,
7
  )
8
- from transformers import pipeline
9
 
 
 
 
 
 
 
10
 
11
  tokenizer_ru2en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
12
  model_ru2en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
@@ -14,14 +20,6 @@ model_ru2en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en"
14
  tokenizer_en2ru = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
15
  model_en2ru = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
16
 
17
- git_processor_base = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
18
-
19
- image_processor = git_processor_base.image_processor
20
-
21
-
22
- def preprocess_image(image):
23
- return git_processor_base(images=image, return_tensors="pt").pixel_values
24
-
25
 
26
  def translate_ru2en(text):
27
  inputs = tokenizer_ru2en(text, return_tensors="pt")
@@ -39,11 +37,23 @@ def translate_en2ru(text):
39
 
40
  def generate_answer_git(image, question):
41
 
42
- qa_pipeline = pipeline(
43
- "document-question-answering",
44
- model="andgrt/layoutlmv2-base-uncased_finetuned_docvqa",
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  )
46
- return qa_pipeline(preprocess_image(image), question)[0]["answer"]
47
 
48
 
49
  def generate_answer(image, question):
 
5
  AutoProcessor,
6
  AutoModelForDocumentQuestionAnswering,
7
  )
8
+ import torch
9
 
10
+ processor = AutoProcessor.from_pretrained(
11
+ "MariaK/layoutlmv2-base-uncased_finetuned_docvqa_v2"
12
+ )
13
+ model = AutoModelForDocumentQuestionAnswering.from_pretrained(
14
+ "MariaK/layoutlmv2-base-uncased_finetuned_docvqa_v2"
15
+ )
16
 
17
  tokenizer_ru2en = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
18
  model_ru2en = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
 
20
  tokenizer_en2ru = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
21
  model_en2ru = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
22
 
 
 
 
 
 
 
 
 
23
 
24
  def translate_ru2en(text):
25
  inputs = tokenizer_ru2en(text, return_tensors="pt")
 
37
 
38
  def generate_answer_git(image, question):
39
 
40
+ with torch.no_grad():
41
+ encoding = processor(
42
+ images=image,
43
+ text=question,
44
+ return_tensors="pt",
45
+ max_length=512,
46
+ truncation=True,
47
+ )
48
+ outputs = model(**encoding)
49
+ start_logits = outputs.start_logits
50
+ end_logits = outputs.end_logits
51
+ predicted_start_idx = start_logits.argmax(-1).item()
52
+ predicted_end_idx = end_logits.argmax(-1).item()
53
+
54
+ return processor.tokenizer.decode(
55
+ encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
56
  )
 
57
 
58
 
59
  def generate_answer(image, question):