import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    ViTFeatureExtractor,
    ViTForImageClassification,
    pipeline,
)
from haystack import Document
from haystack.components.readers import ExtractiveReader
import wikipediaapi

# ViT classifier fine-tuned on banana leaf diseases
model = ViTForImageClassification.from_pretrained("Dmitry43243242/banana-disease-leaf-model")
feature_extractor = ViTFeatureExtractor.from_pretrained("Dmitry43243242/banana-disease-leaf-model")

# QA tokenizer/model (not used directly below; the Haystack ExtractiveReader loads its own copy)
tokenizer_qa = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model_qa = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
translator = pipeline("translation", model="facebook/nllb-200-distilled-600M", torch_dtype=torch.bfloat16)

# Extractive QA reader is created and warmed up once at startup instead of on every request
reader = ExtractiveReader(model="deepset/roberta-base-squad2")
reader.warm_up()


def translate_question(question, translator):
    """Translate the user's question from Russian to English."""
    text_translated = translator(question, src_lang="rus_Cyrl", tgt_lang="eng_Latn")
    return text_translated[0]['translation_text']


def translate_summary(text_to_translate, translator):
    """Translate the summarized answer from English back to Russian."""
    text_translated = translator(text_to_translate, src_lang="eng_Latn", tgt_lang="rus_Cyrl")
    return text_translated[0]['translation_text']


# Fetch article text from Wikipedia
def get_wiki_text(disease):
    wiki = wikipediaapi.Wikipedia(
        user_agent='Mozilla/5.0 (Windows; Windows NT 10.0;; en-US) AppleWebKit/601.41 (KHTML, like Gecko) Chrome/51.0.1308.224 Safari/601.6 Edge/14.17024',
        language="en",  # use 'ru' for Russian
    )
    page = wiki.page(disease)
    return page.text if page.exists() else ''


def process_input(img, text):
    labels = [
        'banana_healthy_leaf', 'black_sigatoka', 'yellow_sigatoka', 'panama_disease',
        'moko_disease', 'insect_pest', 'bract_mosaic_virus'
    ]

    # Translate the question into English for the extractive QA model
    translate_eng_question = translate_question(text, translator)

    # gr.Image(type="pil") already delivers a PIL image; normalize to RGB either way
    if not isinstance(img, Image.Image):
        img = Image.fromarray(np.uint8(img))
    img = img.convert("RGB")

    # Classify the leaf image
    inputs = feature_extractor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    probas = torch.nn.functional.softmax(outputs.logits, dim=-1)
    max_index = probas[0].argmax().item()
    diagnose = labels[max_index].replace('_', ' ')

    # Retrieve background text about the predicted disease and run extractive QA over it
    docs = [Document(content=get_wiki_text(diagnose))]
    result = reader.run(query=translate_eng_question, documents=docs)
    text = [answer.data for answer in result['answers'] if answer.data is not None]
    result = '.'.join(text[0:6])

    # Summarize the extracted answer spans and translate the summary back to Russian
    summary = summarizer(
        result,
        max_length=150,
        min_length=60,
        do_sample=False,
        truncation=True
    )
    summary_text = summary[0]['summary_text']
    answer = translate_summary(summary_text, translator)

    return {labels[i]: float(probas[0][i]) for i in range(len(labels))}, answer


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Классификация болезней банановых листьев 🌿")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Загрузите изображение листа", type="pil")
            question_input = gr.Textbox(label="Ваш вопрос о болезни")
            submit_btn = gr.Button("Анализировать", variant="primary")
        with gr.Column():
            confidence_bar = gr.Label(num_top_classes=3, label="Диагноз")
            answer_output = gr.Textbox(label="Ответ на вопрос")

    submit_btn.click(
        fn=process_input,
        inputs=[image_input, question_input],
        outputs=[confidence_bar, answer_output]
    )

demo.launch(share=True, debug=True)