Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas | |
| import torch | |
| from transformers import AutoProcessor, PaliGemmaForConditionalGeneration | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
# Load the models and processors.
# Streamlit re-executes this entire script on every widget interaction, so the
# loads are wrapped in st.cache_resource: weights are read from disk once per
# server process instead of on every click.

# Hugging Face checkpoint for the dermatology-tuned PaliGemma model.
model_id = "brucewayne0459/paligemma_derm"
# Hugging Face checkpoint loaded as `processor_chat` / `model_chat`.
chat_model_id = "Qwen/Qwen2-VL-7B-Instruct"

# Use the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_chat_model(checkpoint: str):
    """Load and cache the image-text-to-text chat processor and model.

    Returns a ``(processor, model)`` pair. The model is left on its default
    device, as in the original uncached code.
    """
    chat_proc = AutoProcessor.from_pretrained(checkpoint)
    chat_mdl = AutoModelForImageTextToText.from_pretrained(checkpoint)
    return chat_proc, chat_mdl


@st.cache_resource
def _load_paligemma(checkpoint: str, target_device: str):
    """Load and cache the PaliGemma processor and model.

    The model is put in eval mode and moved to ``target_device``.
    Returns a ``(processor, model)`` pair.
    """
    proc = AutoProcessor.from_pretrained(checkpoint)
    mdl = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    mdl.eval()  # inference only: disables dropout and similar training behavior
    mdl.to(target_device)
    return proc, mdl


# NOTE(review): `processor_chat` / `model_chat` are never used in the code
# visible in this file. They are kept (now cached) so the names stay defined,
# but a 7B checkpoint dominates memory and startup time — consider removing it.
processor_chat, model_chat = _load_chat_model(chat_model_id)
processor, model = _load_paligemma(model_id, device)
# Centered Hugging Face logo shown above the page title.
# The CSS rules and the markup are one raw-HTML snippet. st.markdown only
# passes raw HTML through when unsafe_allow_html=True is set.
_LOGO_HTML = """
<style>
.huggingface-logo {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
.huggingface-logo img {
width: 150px;
}
</style>
<div class="huggingface-logo">
<img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Hugging Face Logo">
</div>
"""

st.markdown(_LOGO_HTML, unsafe_allow_html=True)
# Page header and a one-line usage hint.
st.title("Skin Condition Identifier")
st.write("Upload an image and provide a text prompt to identify the skin condition.")

# Two columns with a 3:2 width ratio: inputs on the left, preview on the right.
col1, col2 = st.columns([3, 2])

# Left column: image upload plus an editable prompt pre-filled with a default.
uploaded_file = col1.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
prompt = "Identify the skin condition?"
input_text = col1.text_input("Enter your prompt:", prompt)

# Right column: preview the uploaded image, if any. `input_image` (the full
# RGB image) is only bound when a file was uploaded; later code reads it.
if uploaded_file:
    input_image = Image.open(uploaded_file).convert("RGB")
    # NOTE(review): a fixed 300x300 resize ignores the aspect ratio, and
    # use_container_width=True then rescales it to the column width.
    preview = input_image.resize((300, 300))
    col2.image(preview, caption="Uploaded Image (300x300)", use_container_width=True)
# Run inference when an image is uploaded and the "Analyze" button is clicked.
if uploaded_file and st.button("Analyze"):
    if not input_text.strip():
        st.error("Please provide a valid prompt!")
    else:
        try:
            # Resize image for processing (512x512 pixels).
            # NOTE(review): a fixed-size resize does not preserve aspect ratio.
            max_size = (512, 512)
            input_image = input_image.resize(max_size)
            with st.spinner("Processing..."):
                # Tokenize the prompt and preprocess the image into one batch.
                inputs = processor(
                    text=input_text,
                    images=input_image,
                    return_tensors="pt",
                    padding="longest",
                ).to(device)
                # Cap the length of the generated answer.
                default_max_tokens = 50
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=default_max_tokens)
                # generate() returns the prompt tokens followed by the newly
                # generated ones. Drop the prompt so only the model's answer
                # is shown instead of echoing the user's text back.
                prompt_len = inputs["input_ids"].shape[-1]
                decoded_output = processor.decode(
                    outputs[0][prompt_len:], skip_special_tokens=True
                )
            st.success("Analysis Complete!")
            st.write("**Model Output:**", decoded_output)
        except Exception as e:
            # Top-level UI boundary: report any failure to the user rather
            # than crashing the Streamlit page.
            st.error(f"Error: {str(e)}")