|
import gradio as gr |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoTokenizer, AutoModelForQuestionAnswering |
|
import torch |
|
from PIL import Image |
|
from transformers import pipeline |
|
import numpy as np |
|
from haystack import Document |
|
from haystack.components.readers import ExtractiveReader |
|
import wikipediaapi |
|
|
|
# Hugging Face Hub repositories used below; each id appears more than once.
_CLASSIFIER_REPO = "Dmitry43243242/banana-disease-leaf-model"
_QA_REPO = "deepset/roberta-base-squad2"

# Leaf-disease image classifier and its matching image preprocessor.
model = ViTForImageClassification.from_pretrained(_CLASSIFIER_REPO)
feature_extractor = ViTFeatureExtractor.from_pretrained(_CLASSIFIER_REPO)

# NOTE(review): tokenizer_qa / model_qa are not referenced anywhere else in this
# file (process_input loads the same repo through haystack's ExtractiveReader);
# they cost download time and memory — consider removing them.
tokenizer_qa = AutoTokenizer.from_pretrained(_QA_REPO)
model_qa = AutoModelForQuestionAnswering.from_pretrained(_QA_REPO)

# Summarizes the extracted answer spans into one English paragraph.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# NLLB translator shared by the Russian<->English helpers.
translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    torch_dtype=torch.bfloat16,
)
|
|
|
def translate_question(question, translator):
    """Translate a Russian question into English.

    Args:
        question: Source text in Russian.
        translator: A translation callable (here the NLLB ``pipeline``) that
            accepts ``src_lang``/``tgt_lang`` keywords and returns a list of
            dicts carrying a ``'translation_text'`` key.

    Returns:
        The ``'translation_text'`` of the first result.
    """
    direction = {"src_lang": "rus_Cyrl", "tgt_lang": "eng_Latn"}
    outputs = translator(question, **direction)
    best = outputs[0]
    return best["translation_text"]
|
|
|
def translate_summary(text_to_translate, translator):
    """Translate an English text (the generated summary) into Russian.

    Args:
        text_to_translate: Source text in English.
        translator: A translation callable (here the NLLB ``pipeline``) that
            accepts ``src_lang``/``tgt_lang`` keywords and returns a list of
            dicts carrying a ``'translation_text'`` key.

    Returns:
        The ``'translation_text'`` of the first result.
    """
    direction = {"src_lang": "eng_Latn", "tgt_lang": "rus_Cyrl"}
    outputs = translator(text_to_translate, **direction)
    best = outputs[0]
    return best["translation_text"]
|
|
|
|
|
# Wikimedia's User-Agent policy asks API clients to identify themselves with a
# descriptive agent (tool name + contact); spoofed browser agents may be blocked.
# TODO(owner): add a contact e-mail or project URL you control.
_WIKI_USER_AGENT = (
    "BananaLeafDiseaseClassifier/1.0 "
    "(https://huggingface.co/Dmitry43243242/banana-disease-leaf-model)"
)

# Lazily created, shared Wikipedia client (see get_wiki_text).
_wiki_client = None


def get_wiki_text(disease):
    """Return the plain text of the English Wikipedia article titled *disease*.

    Args:
        disease: Article title to look up, e.g. ``"black sigatoka"``.

    Returns:
        The article text, or ``''`` when the title is blank or no such page
        exists.
    """
    global _wiki_client
    if not disease or not disease.strip():
        return ''
    if _wiki_client is None:
        # Build the client once and reuse it rather than creating a new one
        # (and its HTTP session) on every request.
        _wiki_client = wikipediaapi.Wikipedia(user_agent=_WIKI_USER_AGENT, language="en")
    page = _wiki_client.page(disease)
    return page.text if page.exists() else ''
|
|
|
# Class names in the order of the classifier's output logits.
# NOTE(review): assumed to match model.config.id2label of the checkpoint — confirm.
LEAF_LABELS = [
    'banana_healthy_leaf',
    'black_sigatoka',
    'yellow_sigatoka',
    'panama_disease',
    'moko_disease',
    'insect_pest',
    'bract_mosaic_virus',
]

# Shown when Wikipedia has no usable text about the predicted class.
_NO_INFO_MESSAGE = "Не удалось найти информацию об этой болезни."

# Lazily created ExtractiveReader, shared across requests (see _get_reader).
_reader = None


def _get_reader():
    """Return the shared ExtractiveReader, loading and warming it up on first use."""
    global _reader
    if _reader is None:
        # warm_up() loads the model; doing it once instead of on every click
        # avoids reloading it for each request.
        reader = ExtractiveReader(model="deepset/roberta-base-squad2")
        reader.warm_up()
        _reader = reader
    return _reader


def _to_rgb_image(img):
    """Convert the Gradio image input (PIL image or array-like) to an RGB PIL image."""
    if img is None:
        raise gr.Error("Загрузите изображение листа")
    if isinstance(img, Image.Image):
        return img.convert("RGB")
    return Image.fromarray(np.asarray(img, dtype=np.uint8)).convert("RGB")


def _classify(img):
    """Return the classifier's softmax probabilities for a single RGB image."""
    inputs = feature_extractor(images=img, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    # One image in the batch, so take row 0.
    return torch.nn.functional.softmax(logits, dim=-1)[0]


def _answer_question(diagnosis, question):
    """Answer a Russian *question* about *diagnosis* using its Wikipedia article.

    Returns a Russian answer, or _NO_INFO_MESSAGE when no article or no
    answer span is found.
    """
    wiki_text = get_wiki_text(diagnosis)
    if not wiki_text:
        return _NO_INFO_MESSAGE

    english_question = translate_question(question, translator)
    result = _get_reader().run(
        query=english_question,
        documents=[Document(content=wiki_text)],
    )
    # The reader may include a "no answer" candidate whose data is None.
    spans = [answer.data for answer in result['answers'] if answer.data is not None]
    if not spans:
        return _NO_INFO_MESSAGE

    # Join the top spans as sentences so the summarizer sees readable prose.
    context = ". ".join(span.strip().rstrip(".") for span in spans[:6])
    # NOTE(review): with very short contexts, min_length=60 can push BART to
    # generate text not supported by the spans — consider lowering it.
    summary = summarizer(
        context,
        max_length=150,
        min_length=60,
        do_sample=False,
        truncation=True,
    )
    return translate_summary(summary[0]['summary_text'], translator)


def process_input(img, text):
    """Classify a banana-leaf image and answer a question about the predicted class.

    Args:
        img: The uploaded leaf image; the UI passes a PIL image
            (``type="pil"``), but an array-like image is also accepted.
        text: The user's question in Russian (translated rus_Cyrl -> eng_Latn
            before question answering).

    Returns:
        A ``(confidences, answer)`` tuple: a dict mapping every label in
        LEAF_LABELS to its softmax probability, and the Russian answer
        (``''`` when no question was asked).

    Raises:
        gr.Error: If no image was uploaded.
    """
    probs = _classify(_to_rgb_image(img))
    diagnosis = LEAF_LABELS[probs.argmax().item()].replace('_', ' ')
    confidences = {label: float(p) for label, p in zip(LEAF_LABELS, probs)}

    # Skip the translation/QA/summarization chain when there is nothing to answer.
    if not text or not text.strip():
        return confidences, ""
    return confidences, _answer_question(diagnosis, text)
|
|
|
# Single-page UI: leaf image + question on the left, diagnosis + answer on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Классификация болезней банановых листьев 🌿")

    with gr.Row():
        # Inputs column; type="pil" hands process_input a PIL image.
        with gr.Column():
            image_input = gr.Image(label="Загрузите изображение листа", type="pil")
            question_input = gr.Textbox(label="Ваш вопрос о болезни")
            submit_btn = gr.Button("Анализировать", variant="primary")

        # Outputs column: top-3 class confidences and the translated answer.
        with gr.Column():
            confidence_bar = gr.Label(num_top_classes=3, label="Диагноз")
            answer_output = gr.Textbox(label="Ответ на вопрос")

    # A single click runs both classification and question answering.
    submit_btn.click(
        process_input,
        inputs=[image_input, question_input],
        outputs=[confidence_bar, answer_output],
    )

demo.launch(debug=True, share=True)