mjsolidarios's picture
Full feature upload.
c87ad69
raw
history blame
3.1 kB
import streamlit as st
import tempfile
from PIL import Image
import os
import tensorflow as tf
import numpy as np
import io
import json
from google import genai
from dotenv import dotenv_values, load_dotenv
# Populate os.environ from .env, and separately keep a dict of the values
# defined in the .env file itself.
load_dotenv()
config = dotenv_values(".env")
MODEL_ID = "gemini-2.0-flash"
# Lower-case alias kept so any existing importers of `model_id` keep working;
# the app itself passes MODEL_ID to Gemini.
model_id = MODEL_ID
# Prefer the key from .env, but fall back to the process environment so a
# deployment that injects the secret as an environment variable (with no .env
# entry) still starts instead of crashing on config["GEMINI_API_KEY"] KeyError.
_api_key = config.get("GEMINI_API_KEY") or os.environ.get("GEMINI_API_KEY")
client = genai.Client(api_key=_api_key)
# Load the class labels from JSON. The model's argmax output index is used to
# pick an entry from `labels`.
# NOTE(review): entries are read as labels[i]['name'], so this is presumably a
# JSON list of objects with a "name" field — confirm against model/labels.json.
# Decode explicitly as UTF-8: without an encoding, open() uses the locale's
# preferred encoding (e.g. cp1252 on Windows), which can garble or reject
# non-ASCII label names.
with open('model/labels.json', 'r', encoding='utf-8') as f:
    labels = json.load(f)
def load_tflite_model(model_path="model/model.tflite"):
    """Create a TensorFlow Lite interpreter and allocate its tensors.

    Args:
        model_path: Path to the ``.tflite`` model file. Defaults to the path
            the app has always used, so existing zero-argument calls behave
            exactly as before.

    Returns:
        A ``tf.lite.Interpreter`` whose tensors are allocated, ready for
        ``set_tensor`` / ``invoke`` / ``get_tensor``.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    # Tensors must be allocated before inputs can be set or outputs read.
    interpreter.allocate_tensors()
    return interpreter
# Build the interpreter once at import time and look up its input/output
# tensor metadata here, so classification code can reuse it directly.
interpreter = load_tflite_model()
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
def _preprocess_image(image, input_shape):
    """Resize and scale an RGB PIL image into a batched float32 model input.

    Args:
        image: PIL image already converted to RGB.
        input_shape: The model's input tensor shape from
            ``input_details[0]['shape']``. Indices 1 and 2 are read as height
            and width. NOTE(review): assumes an NHWC layout such as
            ``[1, H, W, 3]`` (the original code already relied on this by
            adding the batch axis at position 0) — confirm against the model.

    Returns:
        A float32 array of shape ``(1, H, W, C)`` with values scaled to [0, 1].
    """
    height, width = int(input_shape[1]), int(input_shape[2])
    # PIL's resize() takes (width, height), while the tensor shape lists height
    # first; passing them in tensor order breaks set_tensor for non-square
    # models.
    resized = image.resize((width, height))
    image_array = np.asarray(resized, dtype=np.float32) / 255.0
    # Add the leading batch dimension the interpreter expects.
    return np.expand_dims(image_array, axis=0)


def _predict_label(image):
    """Run the TFLite classifier on *image* and return its ``labels`` entry."""
    image_array = _preprocess_image(image, input_details[0]['shape'])
    interpreter.set_tensor(input_details[0]['index'], image_array)
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_details[0]['index'])
    # A single-image batch was fed in, so row 0 holds this image's scores;
    # int() turns the numpy integer into a plain index.
    predicted_class_index = int(np.argmax(prediction[0]))
    return labels[predicted_class_index]


def _explain_label(label_name):
    """Ask Gemini for an explanation of *label_name*; return its text (may be None)."""
    # NOTE(review): the prompt says an image is given, but only the label name
    # is sent — Gemini never sees the uploaded image.
    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[
            "You are a medical encyclopedia. You are given a skin cancer image and you need to provide a detailed explanation of the disease and its treatment.",
            f"The predicted class is {label_name}. Provide a detailed explanation of the disease, its treatment, and prevention.",
            "Start with the title: Classification of Skin Cancer",
        ],
    )
    return response.text


def main():
    """Render the Skin Cancer Classifier page.

    The user uploads a JPG/PNG image, which is displayed right away. Pressing
    "Classify" runs the TFLite model on it, maps the highest-scoring output to
    an entry of ``labels``, and renders a Gemini-generated explanation of that
    class.
    """
    st.title("Skin Cancer Classifier")
    img_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if img_file is None:
        return
    # Display straight from the upload buffer. Copying it to a
    # NamedTemporaryFile(delete=False) left a file behind that was never
    # removed, and Streamlit re-executes this code on every interaction.
    st.image(img_file, use_container_width=True)
    if not st.button("Classify"):
        return
    with st.spinner("Processing..."):
        image = Image.open(io.BytesIO(img_file.getbuffer())).convert('RGB')
        predicted_label = _predict_label(image)
        explanation = _explain_label(predicted_label['name'])
    # response.text can be None (e.g. when no text candidate is returned);
    # st.markdown(None) would render the literal string "None".
    if explanation:
        st.markdown(explanation)
    else:
        st.warning(
            f"Predicted class: {predicted_label['name']}. "
            "Gemini did not return an explanation."
        )
# Render the page when this file is executed as the main script.
if __name__ == "__main__":
    main()