import streamlit as st
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image

# Load the fine-tuned PaliGemma model and its processor
model_id = "brucewayne0459/paligemma_derm"
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
model.eval()

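# Optional: on a CUDA device the model can be loaded in half precision to
# roughly halve memory use, e.g.
#   PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
# This is an optional optimization, not required for the app to work.
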
# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Add the Hugging Face logo at the top
st.markdown(
    """
    <style>
    .huggingface-logo {
        display: flex;
        justify-content: center;
        margin-bottom: 20px;
    }
    .huggingface-logo img {
        width: 150px;
    }
    </style>
    <div class="huggingface-logo">
        <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Hugging Face Logo">
    </div>
    """,
    unsafe_allow_html=True,
)

# Streamlit app title and instructions
st.title("Skin Condition Identifier")
st.write("Upload an image and provide a text prompt to identify the skin condition.")

# Column layout for input and display
col1, col2 = st.columns([3, 2])

with col1:
    # File uploader for the image
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for the prompt, pre-filled with a default question
    prompt = "Identify the skin condition?"
    input_text = st.text_input("Enter your prompt:", prompt)

with col2:
    # Display the uploaded image (if any)
    if uploaded_file:
        input_image = Image.open(uploaded_file).convert("RGB")

        # Resize image for display (300x300 pixels) and show it
        resized_image = input_image.resize((300, 300))
        st.image(resized_image, caption="Uploaded Image (300x300)", use_container_width=True)

# Process and display the result when the button is clicked
if uploaded_file and st.button("Analyze"):
    if not input_text.strip():
        st.error("Please provide a valid prompt!")
    else:
        try:
            # Resize image for processing (512x512 pixels)
            max_size = (512, 512)
            input_image = input_image.resize(max_size)

            # Prepare inputs
            with st.spinner("Processing..."):
                inputs = processor(
                    text=input_text,
                    images=input_image,
                    return_tensors="pt",
                    padding="longest",
                ).to(device)

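                # Note: generation below is greedy unless the model's generation
                # config says otherwise; for longer or sampled answers you could
                # pass a larger max_new_tokens or do_sample=True to model.generate().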
                # Generate output with a default max_new_tokens
                default_max_tokens = 50  # default budget of newly generated tokens
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)

                # Decode only the newly generated tokens; generate() also returns
                # the prompt tokens, so slice them off before decoding
                input_len = inputs["input_ids"].shape[-1]
                decoded_output = processor.decode(outputs[0][input_len:], skip_special_tokens=True)

            # Display result
            st.success("Analysis Complete!")
            st.write("**Model Output:**", decoded_output)
        except Exception as e:
            st.error(f"Error: {str(e)}")