File size: 3,413 Bytes
5a6aade
935d5c9
5a6aade
 
 
e8c1089
 
 
5a6aade
 
 
e8c1089
 
5a6aade
25835df
5a6aade
 
 
 
 
 
11666df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a6aade
 
 
39a7b7a
 
5a6aade
39a7b7a
 
 
e8c1089
39a7b7a
 
e8c1089
 
39a7b7a
caae7d4
 
 
3c0f272
 
 
 
 
 
caae7d4
39a7b7a
 
 
 
 
 
3c0f272
39a7b7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import pandas
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText



# Fine-tuned PaliGemma checkpoint used for skin-condition identification.
model_id = "brucewayne0459/paligemma_derm"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_paligemma(checkpoint: str, target_device: str):
    """Load the PaliGemma processor and model for *checkpoint* exactly once.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at plain module level would reload the weights on every rerun.
    ``st.cache_resource`` keeps the returned objects alive and reuses them
    for later calls with the same arguments.

    Args:
        checkpoint: Hugging Face Hub id (or local path) of the checkpoint.
        target_device: torch device string the model is moved to.

    Returns:
        A ``(processor, model)`` tuple; the model is in eval mode on
        *target_device*.
    """
    loaded_processor = AutoProcessor.from_pretrained(checkpoint)
    loaded_model = PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
    loaded_model.eval()  # inference only: disable dropout etc.
    loaded_model.to(target_device)
    return loaded_processor, loaded_model


processor, model = _load_paligemma(model_id, device)

# Page header: a centred Hugging Face logo (raw HTML/CSS, hence
# unsafe_allow_html on that one call), then the app title and a usage hint.
_LOGO_HTML = """
    <style>
    .huggingface-logo {
        display: flex;
        justify-content: center;
        margin-bottom: 20px;
    }
    .huggingface-logo img {
        width: 150px;
    }
    </style>
    <div class="huggingface-logo">
        <img src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" alt="Hugging Face Logo">
    </div>
    """
st.markdown(_LOGO_HTML, unsafe_allow_html=True)

st.title("Skin Condition Identifier")
st.write("Upload an image and provide a text prompt to identify the skin condition.")

# Two columns with a 3:2 width ratio: inputs on the left, image preview and
# analysis results on the right (col2 is filled further down).
col1, col2 = st.columns([3, 2])

# Pre-filled question for the prompt box; whatever the user types replaces it.
prompt = "Identify the skin condition?"

with col1:
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "jpeg", "png"],
    )
    input_text = st.text_input("Enter your prompt:", value=prompt)
    
with col2:
    # Preview the uploaded image (if any). The 300x300 copy is for display
    # only; the full-size RGB image is kept in input_image for inference.
    if uploaded_file:
        input_image = Image.open(uploaded_file).convert("RGB")
        resized_image = input_image.resize((300, 300))
        st.image(resized_image, caption="Uploaded Image (300x300)", use_container_width=True)

    # st.button returns True only on the rerun triggered by the click, so the
    # model runs once per press and only when an image has been uploaded.
    if uploaded_file and st.button("Analyze"):
        if not input_text.strip():
            st.error("Please provide a valid prompt!")
        else:
            try:
                # Resize image for processing (512x512 pixels); note this does
                # not preserve the aspect ratio.
                input_image = input_image.resize((512, 512))

                with st.spinner("Processing..."):
                    inputs = processor(
                        text=input_text,
                        images=input_image,
                        return_tensors="pt",
                        padding="longest",
                    ).to(device)
                    # generate() returns the prompt tokens followed by the new
                    # tokens; record where the prompt ends so only the model's
                    # answer is decoded (otherwise the prompt is echoed back).
                    prompt_len = inputs["input_ids"].shape[-1]

                    max_new_tokens = 50
                    with torch.no_grad():
                        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

                    generated_ids = outputs[0][prompt_len:]
                    decoded_output = processor.decode(
                        generated_ids, skip_special_tokens=True
                    ).strip()

                st.success("Analysis Complete!")
                st.write("**Model Output:**", decoded_output)
            except Exception as e:  # UI boundary: show any failure to the user
                st.error(f"Error: {str(e)}")