shurik-p commited on
Commit
186680f
·
verified ·
1 Parent(s): 878c644

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import detectron2
3
+ except:
4
+ import os
5
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
6
+
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
8
+ import gradio as gr
9
+
10
+ qa_pipeline = pipeline("document-question-answering", model="shurik-p/llmv2-docvqa-finetuned")
11
+
12
+ ru_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
13
+ ru_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ru-en")
14
+
15
+ en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
16
+ en_ru_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ru")
17
+
18
+ def translate_ru2en(ru_question, ru_en_model, ru_tokenizer):
19
+ input_ids = ru_tokenizer.encode(ru_question, return_tensors="pt")
20
+ output_ids = ru_en_model.generate(input_ids, max_new_tokens=512)
21
+ en_question = ru_tokenizer.decode(output_ids[0], skip_special_tokens=True)
22
+ return en_question
23
+
24
+
25
+ def translate_en2ru(en_answer, en_ru_model, en_tokenizer):
26
+ input_ids = en_tokenizer.encode(en_answer, return_tensors="pt")
27
+ output_ids = en_ru_model.generate(input_ids, max_new_tokens=512)
28
+ ru_answer = en_tokenizer.decode(output_ids[0], skip_special_tokens=True)
29
+ return ru_answer
30
+
31
+ def ru_inference(image, ru_question):
32
+ en_question = translate_ru2en(ru_question, ru_en_model, ru_tokenizer)
33
+ en_answer = qa_pipeline(image=image, question=en_question)[0]['answer']
34
+ ru_answer = translate_en2ru(en_answer, en_ru_model, en_tokenizer)
35
+ return ru_answer
36
+
37
+
38
+ interface = gr.Interface(
39
+ fn=ru_inference,
40
+ inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
41
+ outputs=[gr.Text()],
42
+ title='Document answer questions'
43
+ ).launch(debug=True)