# Hugging Face Space by Dmitry43243242 — commit 0650c92 ("Update app.py", verified).
import gradio as gr
import numpy as np
import torch
import wikipediaapi
from haystack import Document
from haystack.components.readers import ExtractiveReader
from PIL import Image
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    ViTFeatureExtractor,
    ViTForImageClassification,
    ViTImageProcessor,
    pipeline,
)
# Banana-leaf disease classifier (fine-tuned ViT) and its matching preprocessor.
model = ViTForImageClassification.from_pretrained("Dmitry43243242/banana-disease-leaf-model")
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor and is called the
# same way: feature_extractor(images=..., return_tensors="pt").
feature_extractor = ViTImageProcessor.from_pretrained("Dmitry43243242/banana-disease-leaf-model")
# NOTE(review): tokenizer_qa / model_qa are not referenced anywhere in the visible
# code; the ExtractiveReader in process_input loads its own copy of this checkpoint.
# Consider removing them to save memory and startup time.
tokenizer_qa = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model_qa = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
# English summarizer used to condense the answer spans extracted from Wikipedia.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# NLLB-200 translation pipeline (weights loaded in bfloat16); source/target
# language codes are supplied on each call.
translator = pipeline("translation", model="facebook/nllb-200-distilled-600M", torch_dtype=torch.bfloat16)
def _translate(text, translator, src_lang, tgt_lang):
    """Translate *text* with a translation pipeline and return the translated string."""
    result = translator(text, src_lang=src_lang, tgt_lang=tgt_lang)
    # A single input string yields a one-element list of result dicts.
    return result[0]['translation_text']


def translate_question(question, translator, *, src_lang="rus_Cyrl", tgt_lang="eng_Latn"):
    """Translate the user's question into the language used for QA.

    Args:
        question: Question text (Russian by default).
        translator: Callable translation pipeline accepting ``src_lang``/``tgt_lang``.
        src_lang: NLLB source language code (default ``"rus_Cyrl"``).
        tgt_lang: NLLB target language code (default ``"eng_Latn"``).

    Returns:
        The translated question string.
    """
    return _translate(question, translator, src_lang, tgt_lang)


def translate_summary(text_to_translate, translator, *, src_lang="eng_Latn", tgt_lang="rus_Cyrl"):
    """Translate the generated summary back into the user's language.

    Args:
        text_to_translate: Summary text (English by default).
        translator: Callable translation pipeline accepting ``src_lang``/``tgt_lang``.
        src_lang: NLLB source language code (default ``"eng_Latn"``).
        tgt_lang: NLLB target language code (default ``"rus_Cyrl"``).

    Returns:
        The translated summary string.
    """
    return _translate(text_to_translate, translator, src_lang, tgt_lang)
# Fetching article text from Wikipedia.

# Wikipedia API clients keyed by language code, reused across calls.
_WIKI_CLIENTS = {}

# Wikimedia's User-Agent policy asks API clients to identify themselves rather
# than copy a browser's user agent.
# TODO: append a contact URL or e-mail address, as the policy recommends.
_WIKI_USER_AGENT = "BananaLeafDiseaseClassifier/1.0 (Gradio demo; wikipedia-api)"


def get_wiki_text(disease, language="en"):
    """Return the plain text of the Wikipedia article titled *disease*.

    Args:
        disease: Article title to look up, e.g. ``"black sigatoka"``.
        language: Wikipedia language edition (``"en"`` by default; use ``"ru"``
            for Russian).

    Returns:
        The article text, or ``''`` if the title is empty or the page does not exist.
    """
    if not disease:
        return ''
    wiki = _WIKI_CLIENTS.get(language)
    if wiki is None:
        wiki = wikipediaapi.Wikipedia(user_agent=_WIKI_USER_AGENT, language=language)
        _WIKI_CLIENTS[language] = wiki
    page = wiki.page(disease)
    return page.text if page.exists() else ''
# Class names in the order of the classifier's output logits.
# NOTE(review): hard-coded; confirm it matches model.config.id2label.
_LABELS = ['banana_healthy_leaf', 'black_sigatoka', 'yellow_sigatoka', 'panama_disease', 'moko_disease', 'insect_pest', 'bract_mosaic_virus']

# Shown when no relevant answer could be extracted ("Could not find an answer to this question.").
_NO_ANSWER = "Не удалось найти ответ на этот вопрос."

# Lazily created ExtractiveReader shared across requests, so the QA model is
# loaded and warmed up once rather than on every button click.
_reader = None


def _get_reader():
    """Return the shared ExtractiveReader, creating and warming it up on first use."""
    global _reader
    if _reader is None:
        reader = ExtractiveReader(model="deepset/roberta-base-squad2")
        reader.warm_up()
        _reader = reader
    return _reader


def _classify(img):
    """Run the leaf classifier on *img* and return per-class probabilities as a list."""
    if isinstance(img, Image.Image):
        # Convert PIL input directly so palette/other modes are mapped to RGB by
        # PIL rather than reinterpreted as raw uint8 arrays.
        rgb = img.convert("RGB")
    else:
        rgb = Image.fromarray(np.uint8(img)).convert("RGB")
    inputs = feature_extractor(images=rgb, return_tensors="pt")
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0].tolist()


def _answer_question(question, diagnose):
    """Answer *question* (Russian) about *diagnose* using its English Wikipedia article.

    Returns the answer translated to Russian, ``''`` for an empty question, or
    ``_NO_ANSWER`` when no article or no answer span is found.
    """
    if not question or not question.strip():
        return ''
    context = get_wiki_text(diagnose)
    if not context:
        return _NO_ANSWER
    eng_question = translate_question(question, translator)
    result = _get_reader().run(query=eng_question, documents=[Document(content=context)])
    # The reader may include a "no answer" entry whose data is None; drop it.
    spans = [answer.data for answer in result['answers'] if answer.data is not None]
    if not spans:
        return _NO_ANSWER
    summary = summarizer(
        ". ".join(spans[:6]),
        max_length=150,
        min_length=60,
        do_sample=False,
        truncation=True,
    )
    return translate_summary(summary[0]['summary_text'], translator)


def process_input(img, text):
    """Classify a banana-leaf image and answer a question about the diagnosis.

    Args:
        img: Leaf image from the Gradio ``gr.Image`` input (a PIL image, or an
            array-like convertible with ``np.uint8``).
        text: The user's question, expected in Russian.

    Returns:
        Tuple ``(label_probs, answer)``: ``label_probs`` maps every class name to
        its probability (for ``gr.Label``), and ``answer`` is the answer in Russian.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Загрузите изображение листа")
    probs = _classify(img)
    best = max(range(len(probs)), key=probs.__getitem__)
    diagnose = _LABELS[best].replace('_', ' ')
    answer = _answer_question(text, diagnose)
    return {label: p for label, p in zip(_LABELS, probs)}, answer
# Gradio UI: inputs (leaf image, question, button) in the left column,
# results (diagnosis probabilities, answer text) in the right column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Классификация болезней банановых листьев 🌿")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Загрузите изображение листа")
            question_input = gr.Textbox(label="Ваш вопрос о болезни")
            submit_btn = gr.Button("Анализировать", variant="primary")

        # Right column: model outputs.
        with gr.Column():
            confidence_bar = gr.Label(label="Диагноз", num_top_classes=3)
            answer_output = gr.Textbox(label="Ответ на вопрос")

    # Wire the button: (image, question) -> (label probabilities, answer).
    submit_btn.click(
        process_input,
        inputs=[image_input, question_input],
        outputs=[confidence_bar, answer_output],
    )

demo.launch(debug=True, share=True)