File size: 3,502 Bytes
d566fee a3cb4b9 ee36b3c ba699eb d566fee a3cb4b9 d566fee a3cb4b9 ee36b3c d566fee ee36b3c 8878677 77be94a 8878677 a27a94b a3cb4b9 ba699eb a3cb4b9 ee36b3c d566fee 6704e8f 4a077d0 a3cb4b9 6704e8f f132889 d566fee 6704e8f d566fee 4a077d0 d566fee 23fec88 41dcd30 d566fee ba699eb 2825722 ba699eb a8f5c0e a8296f6 2825722 d566fee 910566d 6704e8f 8a69f2c 6704e8f d566fee 910566d 6704e8f 910566d ba699eb d566fee fe44a5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import html
import os

import google.generativeai as genai
import gradio as gr
import markdown2
import numpy as np
import tensorflow as tf
from PIL import Image
# --- Module-level setup ------------------------------------------------------
# Load the exported TensorFlow SavedModel from the local "model" directory
# once, at import time.
model_path = 'model'
model = tf.saved_model.load(model_path)

# Gemini credentials come from the environment. If GEMINI_API_KEY is unset,
# api_key is None — NOTE(review): confirm how genai.configure handles that.
api_key = os.environ.get("GEMINI_API_KEY")
genai.configure(api_key=api_key)

# Class names, indexed by position in the model's output vector.
# NOTE(review): this order must match the label order used in training.
labels = [
    "cataract",
    "diabetic_retinopathy",
    "glaucoma",
    "normal",
]
def get_disease_detail(disease_name, *, model_name="gemini-1.5-flash"):
    """Ask Gemini for recommendations for a predicted diagnosis.

    The same prompt is used for every label, including ``"normal"``.

    Args:
        disease_name: Predicted class label (e.g. ``"glaucoma"``); it is
            interpolated verbatim into the prompt.
        model_name: Gemini model identifier to query. Defaults to the model
            this app has always used.

    Returns:
        An HTML string. It is either the model's Markdown reply rendered by
        ``markdown2`` with raw HTML escaped, or an HTML-escaped
        ``"Error: ..."`` message if generating or rendering the reply fails.
    """
    prompt = (
        "You are an Ophthalmologist with over 25 years of experience, you have treated thousands of patients with various eye diseases including cataracts, diabetic retinopathy and glaucoma. The entire medical process from disease identification to patient management is second nature to you and you are used to it. Your job is to critically and comprehensively make recommendations based on the diagnosis, the recommendations contain actions that can be taken on the patient, no need to re-explain the disease. In every recommendation you must remind the patient to always see the Ophthalmologist to validate the diagnosis and recommendation.\n"
        f"The diagnosis is {disease_name}, what are your recommendations?"
    )
    try:
        response = genai.GenerativeModel(model_name).generate_content(prompt)
        # The result is displayed as HTML, so escape any raw HTML the model
        # emits instead of letting it render; Markdown is still converted.
        return markdown2.markdown(response.text.strip(), safe_mode="escape")
    except Exception as e:
        # Deliberate best-effort: surface the failure in the UI rather than
        # failing the whole request, but escape it because the caller
        # renders this string as HTML.
        return html.escape(f"Error: {e}")
def _preprocess_image(image):
    """Turn a PIL image into a (1, 224, 224, 3) float32 batch scaled to [0, 1]."""
    # Force three channels: RGBA, grayscale or palette PNGs would otherwise
    # yield (224, 224, 4) or (224, 224) arrays the model cannot accept.
    rgb = image.convert("RGB").resize((224, 224))
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def predict_image(image):
    """Classify a fundus image and fetch recommendations for the top class.

    Args:
        image: A PIL image, or ``None`` when the user submitted nothing.

    Returns:
        A pair ``(label_scores, explanation)``. ``label_scores`` is a
        one-entry ``{label: score}`` dict for the highest-scoring class, or
        ``None`` when no image was given. ``explanation`` is the string
        returned by ``get_disease_detail`` for that label, or a prompt to
        upload an image.
    """
    if image is None:
        return None, "Please upload an image of an eye fundus."

    batch = _preprocess_image(image)
    outputs = model.signatures['serving_default'](
        tf.convert_to_tensor(batch, dtype=tf.float32)
    )
    # NOTE(review): assumes 'output_0' is (1, len(labels)) class scores
    # (presumably softmax probabilities) — confirm against the SavedModel.
    scores = outputs['output_0'].numpy()[0]
    top_index = int(np.argmax(scores))
    top_label = labels[top_index]

    explanation = get_disease_detail(top_label)
    # Convert np.float32 to a plain float so the value serializes cleanly.
    return {top_label: float(scores[top_index])}, explanation
# Bundled sample fundus images offered as one-click examples. Each inner list
# holds the value for the interface's single image input.
example_images = [
    [f"exp_eye_images/{file_name}"]
    for file_name in (
        "0_right_h.png",
        "03fd50da928d_dr.png",
        "108_right_h.png",
        "1062_right_c.png",
        "1084_right_c.png",
        "image_1002_g.jpg",
    )
]
# Stylesheet for elements with the "scrollable-html" class (the explanation
# panel). It gives them a fixed 206px height with a vertical scrollbar for
# overflowing content, a light grey border and inner padding.
css = """
.scrollable-html {
height: 206px;
overflow-y: auto;
border: 1px solid #ccc;
padding: 10px;
box-sizing: border-box;
}
"""
# Gradio interface: one fundus image in. Two outputs: the top predicted class,
# and the LLM recommendations rendered as HTML in a scrollable panel.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=1, label="Prediction"),
        gr.HTML(label="Explanation", elem_classes=["scrollable-html"]),
    ],
    examples=example_images,
    title="Eye Diseases Classifier",
    description=(
        "Upload an image of an eye fundus, and the model will predict which "
        "condition it shows.\n\n"
        "**Disclaimer:** This model was built as a learning exercise in "
        "health-related machine learning and was trained on a limited amount "
        "and variety of data (about 4,000 samples in total), so its "
        "predictions may not always be correct. It is not a substitute for a "
        "diagnosis by an ophthalmologist. There is still a lot of room for "
        "improvement on this model in the future."
    ),
    # NOTE(review): newer Gradio releases replace `allow_flagging` with
    # `flagging_mode` — check against the pinned Gradio version.
    allow_flagging="never",
    css=css,
)

# Start the server only when run as a script, so importing this module does
# not launch the app (and a public share link) as a side effect.
if __name__ == "__main__":
    interface.launch(share=True)
|