# Vision-Derm / app.py
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
)
# Load the dermatology model and processor
model_id = "brucewayne0459/paligemma_derm"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model.eval()

# Chat model and processor (loaded here but not used anywhere in the current UI flow)
processor_chat = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model_chat = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Set device and move the dermatology model onto it
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
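# Optional refinement (a sketch, not part of the original app): Streamlit re-executes
# this script on every interaction, so the models above are reloaded on each rerun.
# Wrapping the load in a function decorated with st.cache_resource keeps a single copy
# in memory. The torch_dtype choice below is an assumption to reduce GPU memory use.
#
# @st.cache_resource
# def load_derm_model():
#     processor = AutoProcessor.from_pretrained(model_id)
#     model = PaliGemmaForConditionalGeneration.from_pretrained(
#         model_id,
#         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
#     )
#     model.eval()
#     return processor, model
#
# processor, model = load_derm_model()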
# Add Hugging Face logo at the top
st.markdown(
    """
    <style>
    .huggingface-logo {
        display: flex;
        justify-content: center;
        margin-bottom: 20px;
    }
    .huggingface-logo img {
        width: 150px;
    }
    </style>
    <div class="huggingface-logo">
        <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Hugging Face Logo">
    </div>
    """,
    unsafe_allow_html=True,
)
# Streamlit app title and instructions
st.title("Skin Condition Identifier")
st.write("Upload an image and provide a text prompt to identify the skin condition.")
# Column layout for input and display
col1, col2 = st.columns([3, 2])
with col1:
    # File uploader for image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for prompt (pre-filled with a default question)
    prompt = "Identify the skin condition?"
    input_text = st.text_input("Enter your prompt:", prompt)
with col2:
    # Display uploaded image (if any)
    if uploaded_file:
        input_image = Image.open(uploaded_file).convert("RGB")
        # Resize image for display (300x300 pixels)
        resized_image = input_image.resize((300, 300))
        # Display the resized image
        st.image(resized_image, caption="Uploaded Image (300x300)", use_container_width=True)
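# Note (an optional sketch, not in the original app): input_image.resize((300, 300))
# stretches non-square photos. PIL's thumbnail() keeps the aspect ratio instead:
#
# if uploaded_file:
#     display_image = input_image.copy()
#     display_image.thumbnail((300, 300))  # resizes in place, preserving aspect ratio
#     st.image(display_image, caption="Uploaded Image", use_container_width=True)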
# Process and display the result when the button is clicked
if uploaded_file and st.button("Analyze"):
    if not input_text.strip():
        st.error("Please provide a valid prompt!")
    else:
        try:
            # Resize image for processing (512x512 pixels)
            max_size = (512, 512)
            input_image = input_image.resize(max_size)

            # Prepare inputs and run generation while the spinner is shown
            with st.spinner("Processing..."):
                inputs = processor(
                    text=input_text,
                    images=input_image,
                    return_tensors="pt",
                    padding="longest",
                ).to(device)

                # Generate output with a default max_new_tokens budget
                default_max_tokens = 50  # default cap on generated tokens
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)

                # Decode output
                decoded_output = processor.decode(outputs[0], skip_special_tokens=True)
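                # Optional refinement (a sketch, not in the original code): generate()
                # for this model family typically returns the prompt tokens followed by
                # the newly generated tokens, so decoded_output may echo the prompt.
                # Slicing off the prompt length keeps only the answer (assumes inputs
                # carries an "input_ids" tensor, as it does for this processor):
                #
                # prompt_len = inputs["input_ids"].shape[-1]
                # decoded_output = processor.decode(
                #     outputs[0][prompt_len:], skip_special_tokens=True
                # )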
            # Display result
            st.success("Analysis Complete!")
            st.write("**Model Output:**", decoded_output)
        except Exception as e:
            st.error(f"Error: {str(e)}")