File size: 3,888 Bytes
7b21a18
 
 
 
 
 
 
 
 
 
0650c92
 
 
 
 
 
7b21a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adf7368
7b21a18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import functools

import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor, AutoTokenizer, AutoModelForQuestionAnswering
import torch
from PIL import Image
from transformers import pipeline
import numpy as np
from haystack import Document
from haystack.components.readers import ExtractiveReader
import wikipediaapi

# Hugging Face Hub checkpoints shared by more than one loader below.
_LEAF_MODEL_ID = "Dmitry43243242/banana-disease-leaf-model"
_QA_MODEL_ID = "deepset/roberta-base-squad2"

# Banana-leaf image classifier and the preprocessor saved with it.
model = ViTForImageClassification.from_pretrained(_LEAF_MODEL_ID)
feature_extractor = ViTFeatureExtractor.from_pretrained(_LEAF_MODEL_ID)

# NOTE(review): tokenizer_qa / model_qa are not referenced anywhere else in this
# file (the Haystack ExtractiveReader loads its own copy of the same checkpoint).
# Confirm nothing else imports them before removing to save memory.
tokenizer_qa = AutoTokenizer.from_pretrained(_QA_MODEL_ID)
model_qa = AutoModelForQuestionAnswering.from_pretrained(_QA_MODEL_ID)

# English summarizer and the multilingual (NLLB) translator.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
translator = pipeline(
    "translation",
    model="facebook/nllb-200-distilled-600M",
    torch_dtype=torch.bfloat16,
)

def translate_question(question, translator, *, src_lang="rus_Cyrl", tgt_lang="eng_Latn"):
    """Translate a user's question with a translation pipeline (Russian -> English by default).

    Args:
        question: Text to translate.
        translator: Callable with the transformers translation-pipeline
            interface: ``translator(text, src_lang=..., tgt_lang=...)`` returning
            a list whose first element has a ``'translation_text'`` key.
        src_lang: Source-language code passed to the translator.
        tgt_lang: Target-language code passed to the translator.

    Returns:
        The ``'translation_text'`` of the first translation result.
    """
    translated = translator(question, src_lang=src_lang, tgt_lang=tgt_lang)
    return translated[0]['translation_text']

def translate_summary(text_to_translate, translator, *, src_lang="eng_Latn", tgt_lang="rus_Cyrl"):
    """Translate a generated summary with a translation pipeline (English -> Russian by default).

    Args:
        text_to_translate: Text to translate.
        translator: Callable with the transformers translation-pipeline
            interface: ``translator(text, src_lang=..., tgt_lang=...)`` returning
            a list whose first element has a ``'translation_text'`` key.
        src_lang: Source-language code passed to the translator.
        tgt_lang: Target-language code passed to the translator.

    Returns:
        The ``'translation_text'`` of the first translation result.
    """
    translated = translator(text_to_translate, src_lang=src_lang, tgt_lang=tgt_lang)
    return translated[0]['translation_text']

# Fetch article text from Wikipedia.
def get_wiki_text(disease, language="en"):
    """Return the plain text of the Wikipedia article titled *disease*.

    Args:
        disease: Article title to look up.
        language: Wikipedia language edition to query (e.g. "en", "ru").
            Defaults to English, which is what the rest of the pipeline expects.

    Returns:
        The article text, or '' when the title is blank or the page does not exist.
    """
    # A blank title can never match an article; skip the network round-trip.
    if not disease or not disease.strip():
        return ''
    # NOTE(review): Wikimedia's User-Agent policy asks API clients to identify
    # themselves (app name + contact); this spoofed browser string may be
    # throttled or blocked. Consider replacing it with a descriptive one.
    wiki = wikipediaapi.Wikipedia(user_agent='Mozilla/5.0 (Windows; Windows NT 10.0;; en-US) AppleWebKit/601.41 (KHTML, like Gecko) Chrome/51.0.1308.224 Safari/601.6 Edge/14.17024', language=language)
    page = wiki.page(disease)
    return page.text if page.exists() else ''

# Class names, index-aligned with the classifier's output logits.
# NOTE(review): assumed to match the checkpoint's model.config.id2label order — confirm.
_LEAF_LABELS = [
    'banana_healthy_leaf', 'black_sigatoka', 'yellow_sigatoka', 'panama_disease',
    'moko_disease', 'insect_pest', 'bract_mosaic_virus',
]

# Reply shown when no supporting text could be found for the question.
_NO_ANSWER_TEXT = "Не удалось найти информацию для ответа на этот вопрос."


@functools.lru_cache(maxsize=1)
def _get_reader():
    """Create and warm up the extractive QA reader once; reuse it for later requests."""
    reader = ExtractiveReader(model="deepset/roberta-base-squad2")
    reader.warm_up()
    return reader


def _classify_leaf(img):
    """Return the classifier's softmax probabilities for one image as a 1-D tensor."""
    image = Image.fromarray(np.uint8(img)).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0]


def _answer_question(question, diagnose):
    """Answer a (Russian) *question* about *diagnose* from its English Wikipedia article.

    Steps: fetch the article, translate the question to English, extract answer
    spans with the QA reader, summarize the top spans, translate the summary back.
    Returns _NO_ANSWER_TEXT when there is nothing to summarize.
    """
    wiki_text = get_wiki_text(diagnose)
    if not wiki_text:
        # No matching article: summarizing empty text (with min_length=60)
        # would only produce invented output.
        return _NO_ANSWER_TEXT

    english_question = translate_question(question, translator)
    result = _get_reader().run(query=english_question, documents=[Document(content=wiki_text)])

    # Candidates whose data is None (e.g. the reader's "no answer" entry) carry no text.
    snippets = [answer.data for answer in result['answers'] if answer.data is not None]
    if not snippets:
        return _NO_ANSWER_TEXT

    # Join with ". " so the spans read as separate sentences to the summarizer.
    context = '. '.join(snippets[:6])
    summary = summarizer(
        context,
        max_length=150,
        min_length=60,
        do_sample=False,
        truncation=True,
    )
    return translate_summary(summary[0]['summary_text'], translator)


def process_input(img, text):
    """Classify a banana-leaf image and answer a free-text question about the diagnosis.

    Args:
        img: The uploaded image (the UI passes a PIL image; anything np.uint8()
            can convert is accepted).
        text: The user's question (the UI is in Russian); may be empty.

    Returns:
        A tuple of (label -> probability dict for gr.Label, answer text).
        The answer is '' when no question was asked.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if img is None:
        raise gr.Error("Пожалуйста, загрузите изображение листа.")

    probs = _classify_leaf(img)
    confidences = {label: float(p) for label, p in zip(_LEAF_LABELS, probs)}
    diagnose = _LEAF_LABELS[int(probs.argmax())].replace('_', ' ')

    question = (text or '').strip()
    answer = _answer_question(question, diagnose) if question else ''
    return confidences, answer

# Gradio UI: image + question on the left, diagnosis + answer on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Классификация болезней банановых листьев 🌿")

    with gr.Row():
        with gr.Column():
            # type="pil" makes process_input receive a PIL image.
            image_input = gr.Image(label="Загрузите изображение листа", type="pil")
            question_input = gr.Textbox(label="Ваш вопрос о болезни")
            submit_btn = gr.Button("Анализировать", variant="primary")

        with gr.Column():
            confidence_bar = gr.Label(num_top_classes=3, label="Диагноз")
            answer_output = gr.Textbox(label="Ответ на вопрос")

    # process_input returns (label -> probability dict, answer text), matching outputs.
    submit_btn.click(
        fn=process_input,
        inputs=[image_input, question_input],
        outputs=[confidence_bar, answer_output]
    )


if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests or another launcher) does not block on a running app.
    demo.launch(share=True, debug=True)