andgrt's picture
upd: log
86e6582
raw
history blame
4.1 kB
import gradio as gr
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoProcessor,
AutoModelForDocumentQuestionAnswering,
pipeline,
)
import torch
import torchaudio
# Hub checkpoint identifiers, named once so each processor/tokenizer is
# guaranteed to be loaded from the same checkpoint as its model.
_DOCVQA_CHECKPOINT = "MariaK/layoutlmv2-base-uncased_finetuned_docvqa_v2"
_RU2EN_CHECKPOINT = "Helsinki-NLP/opus-mt-ru-en"
_EN2RU_CHECKPOINT = "Helsinki-NLP/opus-mt-en-ru"
_ASR_CHECKPOINT = "artyomboyko/whisper-base-fine_tuned-ru"

# Document question answering: processor plus model.
processor = AutoProcessor.from_pretrained(_DOCVQA_CHECKPOINT)
model = AutoModelForDocumentQuestionAnswering.from_pretrained(_DOCVQA_CHECKPOINT)

# Russian -> English translation (applied to the incoming question).
tokenizer_ru2en = AutoTokenizer.from_pretrained(_RU2EN_CHECKPOINT)
model_ru2en = AutoModelForSeq2SeqLM.from_pretrained(_RU2EN_CHECKPOINT)

# English -> Russian translation (applied to the outgoing answer).
tokenizer_en2ru = AutoTokenizer.from_pretrained(_EN2RU_CHECKPOINT)
model_en2ru = AutoModelForSeq2SeqLM.from_pretrained(_EN2RU_CHECKPOINT)

# Speech recognition pipeline used for the voice-question tab.
transcriber = pipeline("automatic-speech-recognition", model=_ASR_CHECKPOINT)
def translate_ru2en(text):
    """Translate Russian text to English.

    Uses the module-level ``tokenizer_ru2en`` / ``model_ru2en`` pair.

    Args:
        text: Russian source text.

    Returns:
        The decoded English translation with special tokens removed, or
        ``""`` when ``text`` is ``None``, empty, or whitespace only.
    """
    # Nothing to translate: skip the model call so the caller gets an empty
    # string back instead of whatever the decoder produces for a blank prompt.
    if not text or not text.strip():
        return ""

    # truncation=True clips over-long text to the tokenizer's maximum length,
    # matching how the question-answering path already encodes its input.
    inputs = tokenizer_ru2en(text, return_tensors="pt", truncation=True)
    outputs = model_ru2en.generate(**inputs)
    return tokenizer_ru2en.decode(outputs[0], skip_special_tokens=True)
def translate_en2ru(text):
    """Translate English text to Russian.

    Uses the module-level ``tokenizer_en2ru`` / ``model_en2ru`` pair.

    Args:
        text: English source text.

    Returns:
        The decoded Russian translation with special tokens removed, or
        ``""`` when ``text`` is ``None``, empty, or whitespace only.
    """
    # An empty answer can legitimately arrive here; return it unchanged
    # rather than asking the model to "translate" a blank prompt.
    if not text or not text.strip():
        return ""

    # truncation=True clips over-long text to the tokenizer's maximum length,
    # matching how the question-answering path already encodes its input.
    inputs = tokenizer_en2ru(text, return_tensors="pt", truncation=True)
    outputs = model_en2ru.generate(**inputs)
    return tokenizer_en2ru.decode(outputs[0], skip_special_tokens=True)
def generate_answer_git(image, question):
    """Extract an answer span for ``question`` from a document image.

    Encodes the image/question pair with the module-level ``processor``,
    runs the module-level ``model``, and decodes the tokens between the
    highest-scoring start and end positions.

    Args:
        image: Document image passed to the processor as ``images``.
        question: Question text passed to the processor as ``text``.

    Returns:
        The decoded answer text with special tokens removed, or ``""`` when
        the predicted end position precedes the predicted start position.
    """
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        encoding = processor(
            images=image,
            text=question,
            return_tensors="pt",
            max_length=512,  # inputs longer than 512 tokens are truncated
            truncation=True,
        )
        outputs = model(**encoding)

    # .item() requires a single-element tensor, i.e. one image/question pair.
    start_idx = outputs.start_logits.argmax(-1).item()
    end_idx = outputs.end_logits.argmax(-1).item()

    # The two argmaxes are independent and may describe an inverted span.
    if end_idx < start_idx:
        return ""

    answer_ids = encoding.input_ids[0][start_idx : end_idx + 1]
    # Drop special tokens so markers such as [CLS]/[SEP] never leak into the
    # answer that is subsequently handed to the translator.
    return processor.tokenizer.decode(answer_ids, skip_special_tokens=True)
def generate_answer(image, question):
    """Answer a Russian question about a document image, in Russian.

    The question is translated to English with ``translate_ru2en``, answered
    by ``generate_answer_git``, and the answer is translated back with
    ``translate_en2ru``. Both intermediate English strings are printed.

    Args:
        image: Document image forwarded to ``generate_answer_git``.
        question: Question text in Russian.

    Returns:
        The answer translated to Russian.
    """
    english_question = translate_ru2en(question)
    print(f"Вопрос на английском: {english_question}")

    english_answer = generate_answer_git(image, english_question)
    print(f"Ответ на английском: {english_answer}")

    return translate_en2ru(english_answer)
def transcribe(image, audio):
    """Answer a spoken Russian question about a document image.

    The recording is down-mixed to mono, resampled to 16 kHz, peak
    normalised, transcribed with the module-level ``transcriber``, and the
    resulting text is passed to ``generate_answer`` together with ``image``.

    Args:
        image: Document image forwarded unchanged to ``generate_answer``.
        audio: ``(sample_rate, samples)`` pair. ``samples`` must expose
            ``size``, ``ndim`` and ``mean`` (presumably the NumPy array
            produced by the Gradio audio component -- confirm); for
            multi-dimensional input, axis 1 is averaged away.

    Returns:
        Whatever ``generate_answer`` returns for the transcribed question,
        or ``None`` when the image or the audio is missing or the recording
        contains no samples.
    """
    if image is None or not audio:
        return None

    sr, y = audio
    # A zero-length recording has no peak to normalise by and nothing to
    # transcribe; bail out before any tensor work.
    if y.size == 0:
        return None

    if y.ndim > 1:
        y = y.mean(axis=1)
    y_tensor = torch.tensor(y, dtype=torch.float32)
    print(y.shape)

    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        y_tensor = resampler(y_tensor)
        sr = 16000

    # Peak-normalise to [-1, 1]. A completely silent recording has a peak of
    # zero; dividing by it would fill the waveform with NaNs, so it is left
    # as-is in that case.
    peak = torch.max(torch.abs(y_tensor))
    if peak > 0:
        y_tensor = y_tensor / peak
    y = y_tensor.numpy()
    print(y.shape)

    input_features = transcriber.feature_extractor(
        y, sampling_rate=sr, return_tensors="pt"
    ).input_features
    print(input_features.shape)
    print(input_features)

    transcription = transcriber.model.generate(input_features)
    transcription_text = transcriber.tokenizer.decode(
        transcription[0], skip_special_tokens=True
    )
    return generate_answer(image, transcription_text)
# Tab 1: typed question. Takes a document image and a Russian question and
# shows the Russian answer produced by generate_answer.
_qa_inputs = [
    gr.Image(type="pil"),
    gr.Textbox(label="Вопрос (на русском)", placeholder="Ваш вопрос"),
]
_qa_output = gr.Textbox(label="Ответ (на русском)")
_qa_examples = [["doc.png", "О чем данный документ?"]]

qa_interface = gr.Interface(
    fn=generate_answer,
    inputs=_qa_inputs,
    outputs=_qa_output,
    examples=_qa_examples,
    live=False,
)
# Tab 2: spoken question. Takes a document image and a microphone recording;
# live=True re-runs transcribe whenever an input changes.
speech_interface = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Image(type="pil"),
        gr.Audio(sources="microphone", label="Голосовой ввод"),
    ],
    # transcribe returns the translated answer, not the recognised speech,
    # so the output is labelled the same way as in the text-question tab.
    outputs=gr.Textbox(label="Ответ (на русском)"),
    live=True,
)
# Combine the two interfaces into one tabbed app; tab titles are listed in
# the same order as the interfaces they belong to.
_tabs = [qa_interface, speech_interface]
_tab_titles = ["Текстовый вопрос", "Голосовой вопрос"]

interface = gr.TabbedInterface(
    _tabs,
    _tab_titles,
    title="Демо визуального ответчика на вопросы (на русском)",
)

# Start the web server with debug output and a public share link requested.
interface.launch(debug=True, share=True)