import io
import json
import os
import tempfile

import numpy as np
import streamlit as st
import tensorflow as tf
from google import genai
from PIL import Image

MODEL_ID = "gemini-2.0-flash"
model_id = MODEL_ID  # Alias kept for backward compatibility with external references.

client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

# Load labels from JSON file.
# NOTE(review): the code below only requires that each entry has a "name" key.
# The container may be a list indexed by class, or an object keyed by the class
# index as a string -- confirm against model/labels.json.
with open("model/labels.json", "r", encoding="utf-8") as f:
    labels = json.load(f)


def load_tflite_model():
    """Load the TensorFlow Lite model from disk and allocate its tensors."""
    interpreter = tf.lite.Interpreter(model_path="model/model.tflite")
    interpreter.allocate_tensors()
    return interpreter


# Get input and output details.
# NOTE: Streamlit re-executes this script on every interaction, so the model is
# reloaded each time. Caching it (st.cache_resource) would share one interpreter
# across sessions, which is not safe without a lock -- left as-is deliberately.
interpreter = load_tflite_model()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


def _preprocess_image(image, input_shape):
    """Resize an RGB PIL image and scale it into a float32 batch of one.

    Args:
        image: RGB ``PIL.Image``.
        input_shape: the model's input tensor shape. Assumed to be NHWC,
            i.e. [batch, height, width, channels] -- TODO confirm for the model.

    Returns:
        ``np.ndarray`` of shape (1, height, width, channels) with values in [0, 1].
    """
    height, width = int(input_shape[1]), int(input_shape[2])
    # PIL's resize takes (width, height) -- the reverse of NHWC's (height, width).
    resized = image.resize((width, height))
    pixels = np.array(resized, dtype=np.float32) / 255.0
    return np.expand_dims(pixels, axis=0)


def _lookup_label(index):
    """Return the label entry for a predicted class index."""
    index = int(index)
    if isinstance(labels, dict):
        # JSON object keys are always strings.
        return labels[str(index)]
    return labels[index]


def _classify(image_bytes):
    """Run the TFLite model on raw image bytes and return the predicted label entry."""
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    input_array = _preprocess_image(image, input_details[0]["shape"])
    interpreter.set_tensor(input_details[0]["index"], input_array)
    interpreter.invoke()
    prediction = interpreter.get_tensor(output_details[0]["index"])
    return _lookup_label(np.argmax(prediction[0]))


def _explain(label_name):
    """Ask Gemini to explain the predicted condition; return its text (may be None)."""
    response = client.models.generate_content(
        model=MODEL_ID,
        contents=[
            # Only the predicted class name is sent to Gemini, not the image itself.
            "You are a medical encyclopedia. You are given the predicted class of a skin lesion image and you need to provide a detailed explanation of the disease and its treatment.",
            f"The predicted class is {label_name}. Provide a detailed explanation of the disease, its treatment, and prevention.",
            "Start with the title: Classification of Skin Cancer",
        ],
    )
    return response.text


def main():
    """Render the app: upload an image, classify it, and show a Gemini explanation."""
    st.title("Skin Cancer Classifier")
    st.write("Class demo app that classifies skin cancer images into different categories. It is using a TensorFlow Lite model trained on the Skin Cancer MNIST dataset. The result is then used to generate a detailed explanation of the disease and its treatment using the Gemini API.")

    # Upload an image.
    img_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if img_file is None:
        return

    image_bytes = img_file.getvalue()
    # Display straight from memory instead of writing a temp file that would
    # never be deleted.
    st.image(image_bytes, use_container_width=True)

    if not st.button("Classify"):
        return

    with st.spinner("Processing..."):
        predicted_label = _classify(image_bytes)
        st.subheader(f"Predicted class: {predicted_label['name']}")
        try:
            explanation = _explain(predicted_label["name"])
        except Exception as err:  # UI boundary: auth/network/quota errors from Gemini
            st.error(f"Could not generate an explanation: {err}")
            return

    # Display Gemini response; .text can be None (e.g. blocked or empty reply).
    if explanation:
        st.markdown(explanation)
    else:
        st.warning("The Gemini API returned no text for this request.")


if __name__ == "__main__":
    main()