File size: 5,259 Bytes
dc90582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import base64
import io
import gradio as gr
from groq import Groq
from PIL import Image
import requests

# Pillow modes the JPEG encoder can write directly; anything else (RGBA, LA,
# P, I, F, ...) makes ``save(format="JPEG")`` raise OSError.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})


def encode_image(image):
    """Return *image* encoded as JPEG and then as a base64 ASCII string.

    Args:
        image: A PIL image (anything exposing ``mode``, ``convert`` and
            ``save`` in the Pillow style).

    Returns:
        The base64 text of the JPEG bytes, suitable for a
        ``data:image/jpeg;base64,...`` URL.
    """
    # JPEG has no alpha channel or palette; flatten such images to RGB so the
    # encoder does not reject them. Writable modes are passed through as-is.
    if image.mode not in _JPEG_WRITABLE_MODES:
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def analyze_image(image, prompt, api_key, is_url=False, model="llava-v1.5-7b-4096-preview"):
    """Ask a Groq vision model to respond to *prompt* about *image*.

    Args:
        image: A PIL image, or (when ``is_url`` is true) an image URL string
            that is forwarded to the API unchanged.
        prompt: The text instruction sent alongside the image.
        api_key: Groq API key used to build the client.
        is_url: Treat *image* as a URL instead of a PIL image to embed.
        model: Groq chat model id. NOTE(review): Groq retires preview model
            ids over time; confirm the default is still served.

    Returns:
        The model's reply text, or ``"Error: <message>"`` if anything fails
        (client construction, image encoding, or the API call). Errors are
        returned rather than raised because the result goes straight into a
        UI textbox.
    """
    try:
        # Everything that can fail is inside the try so callers always get a
        # string back, never an exception.
        client = Groq(api_key=api_key)
        if is_url:
            image_url = image
        else:
            image_url = f"data:image/jpeg;base64,{encode_image(image)}"

        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": image_url}},
                    ],
                }
            ],
            model=model,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"

def check_content_safety(image_description, api_key, model="llama-guard-3-8b"):
    """Run a text-only safety classification over an image description.

    Args:
        image_description: Text to classify (normally the output of
            ``analyze_image``).
        api_key: Groq API key used to build the client.
        model: Groq model id of the classifier.

    Returns:
        The classifier's raw reply text (its exact format depends on the
        model — TODO confirm against Groq's Llama Guard docs), or
        ``"Error: <message>"`` if client construction or the API call fails.
    """
    try:
        # Client construction is inside the try so a bad key/config is
        # reported in the UI like any other API failure.
        client = Groq(api_key=api_key)
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a content safety classifier. Analyze the given text and determine if it contains any unsafe or inappropriate content."},
                {"role": "user", "content": f"Please analyze this image description for any unsafe or inappropriate content: {image_description}"}
            ],
            model=model,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"

def process_image(image, url, prompt, api_key):
    """Gradio callback: describe the chosen image, then safety-check that text.

    An uploaded image takes precedence over a URL. For a URL, the image is
    downloaded only to verify that it decodes as an image; the URL itself is
    what gets sent to the vision model.

    Args:
        image: PIL image from the upload widget, or None.
        url: Image URL typed by the user (may be empty or None).
        prompt: Instruction for the vision model.
        api_key: Groq API key.

    Returns:
        ``(analysis_text, safety_text)`` for the two output textboxes.
    """
    url = (url or "").strip()

    if image is not None:
        description = analyze_image(image, prompt, api_key)
    elif url:
        try:
            # Bounded wait so a slow/unresponsive host cannot hang the UI.
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            # Only validates that the payload is a decodable image.
            Image.open(io.BytesIO(response.content))
        except (requests.RequestException, OSError, Image.DecompressionBombError):
            return "Invalid image URL. Please provide a direct link to an image.", ""
        description = analyze_image(url, prompt, api_key, is_url=True)
    else:
        return "Please provide an image to analyze.", ""

    # Call the vision model once and classify the exact text shown to the
    # user (a second call could return a different description).
    return description, check_content_safety(description, api_key)

def launch(**launch_kwargs):
    """Build the Groq image-analysis Gradio UI and start serving it.

    The page has an API-key field, an analysis prompt, an image upload and an
    image-URL field, and two output boxes (image description and safety
    check). Clicking the button runs ``process_image``.

    Args:
        **launch_kwargs: Forwarded unchanged to ``demo.launch()`` (for example
            ``share=True`` or ``server_port=7861``). Calling ``launch()`` with
            no arguments behaves as before.
    """
    with gr.Blocks(
        theme=gr.themes.Default(primary_hue="orange"),
        css="""
        #app-container { max-width: 1000px; margin: auto; padding: 10px; }
        #title { text-align: center; margin-bottom: 10px; font-size: 24px; }
        #groq-badge { text-align: center; margin-top: 10px; }
        .gr-button { border-radius: 15px; }
        .gr-input, .gr-box { border-radius: 10px; }
        .gr-form { gap: 5px; }
        .gr-block.gr-box { padding: 10px; }
        .gr-paddle { height: auto; }
        """
    ) as demo:
        with gr.Column(elem_id="app-container"):
            gr.Markdown("# 🖼️ Groq x Gradio Image Analysis and Content Safety Check", elem_id="title")

            # Credentials and prompt side by side.
            with gr.Row():
                api_key = gr.Textbox(label="Groq API Key:", type="password", scale=2)
                prompt = gr.Textbox(
                    label="Image Analysis Prompt:",
                    value="Describe the image content.",
                    scale=3
                )

            # Two alternative image sources; process_image prefers the upload.
            with gr.Row():
                with gr.Column(scale=1):
                    image_input = gr.Image(type="pil", label="Upload Image:", height=200, sources=["upload"])
                with gr.Column(scale=1):
                    url_input = gr.Textbox(label="Or Paste Image URL:", lines=1)
                    analyze_button = gr.Button("🚀 Analyze Image", variant="primary")

            with gr.Row():
                with gr.Column():
                    analysis_output = gr.Textbox(label="Image Analysis with LLaVA 1.5 7B:", lines=6)
                with gr.Column():
                    safety_output = gr.Textbox(label="Safety Check with Llama Guard 3 8B:", lines=6)

            analyze_button.click(
                fn=process_image,
                inputs=[image_input, url_input, prompt, api_key],
                outputs=[analysis_output, safety_output]
            )

            with gr.Row():
                with gr.Column():
                    gr.HTML("""
                    <div id="groq-badge">
                        <div style="color: #f55036; font-weight: bold; font-size: 1em;">⚡ POWERED BY GROQ ⚡</div>
                    </div>
                    """)
                with gr.Column():
                    gr.Markdown("""
                    **How to use this app:** 
                    1. Enter your [Groq API Key](https://console.groq.com/keys) in the provided field.
                    2. Upload an image file or paste an image URL.
                    3. Use the default prompt or enter a custom prompt for image analysis.
                    4. Click "Analyze Image" to get the image description and its content safety check.
                    """)

    demo.launch(**launch_kwargs)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    launch()