iamomtiwari commited on
Commit
80725e9
·
verified ·
1 Parent(s): a32c184

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -16
app.py CHANGED
@@ -1,27 +1,29 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from transformers import ViTFeatureExtractor, ViTForImageClassification
 
4
 
5
- # Load the model and feature extractor
6
- feature_extractor = ViTFeatureExtractor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
7
- model = ViTForImageClassification.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
 
 
 
8
 
9
- # Define prediction function
10
- def predict(image):
11
- image = Image.fromarray(image) # Convert image from numpy array to PIL Image
12
- inputs = feature_extractor(images=image, return_tensors="pt")
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  predicted_class_idx = logits.argmax(-1).item()
16
  return model.config.id2label[predicted_class_idx]
17
 
18
- # Create Gradio interface
19
- iface = gr.Interface(
20
- fn=predict,
21
- inputs=gr.inputs.Image(type="numpy"), # Input type as a numpy array
22
- outputs="text",
23
- title="Crop Disease Detection",
24
- description="Upload an image of a crop leaf to detect diseases."
25
  )
26
 
27
- iface.launch()
 
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from transformers import ViTImageProcessor, ViTForImageClassification
4
+ import torch
5
 
6
+ # Load the image processor and model
7
+ processor = ViTImageProcessor.from_pretrained('wambugu1738/crop_leaf_diseases_vit')
8
+ model = ViTForImageClassification.from_pretrained(
9
+ 'wambugu1738/crop_leaf_diseases_vit',
10
+ ignore_mismatched_sizes=True
11
+ )
12
 
13
+ # Define a function to make predictions
14
+ def classify_image(image):
15
+ inputs = processor(images=image, return_tensors="pt")
 
16
  outputs = model(**inputs)
17
  logits = outputs.logits
18
  predicted_class_idx = logits.argmax(-1).item()
19
  return model.config.id2label[predicted_class_idx]
20
 
21
+ # Create the Gradio interface
22
+ app = gr.Interface(
23
+ fn=classify_image,
24
+ inputs=gr.Image(type="numpy"), # Corrected input type
25
+ outputs="text"
 
 
26
  )
27
 
28
+ # Launch the Gradio app
29
+ app.launch()