import gradio as gr
from gradio_client import Client
from huggingface_hub import InferenceClient
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from PIL import Image
from threading import Thread

# Initialize model and processor
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
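# The 0.5B checkpoint is small enough for CPU inference; swap "cpu" for
# "cuda" below if a GPU is available (an optional tweak, not required here).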
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")

# Initialize inference client for text-only chat
client_mistral = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Functions
def llava(inputs, history):
    """Processes image + text input with the Llava processor."""
    image = Image.open(inputs["files"][0]).convert("RGB")
    # llava-interleave-qwen uses the Qwen chat template; the trailing
    # assistant tag cues the model to begin its reply.
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|><|im_start|>assistant"
    return processor(images=image, text=prompt, return_tensors="pt").to("cpu")

def respond(message, history):
    """Generate a response for text or image input, yielding updated chat history."""
    history = history or []
    if message.get("files"):
        # Image + text: run the local Llava model and stream tokens as they arrive
        inputs = llava(message, history)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=model.generate, kwargs={**inputs, "max_new_tokens": 512, "streamer": streamer})
        thread.start()
        history.append([message["text"], ""])
        for new_text in streamer:
            history[-1][1] += new_text
            yield history
        thread.join()
    else:
        # Text only: delegate to the hosted Mistral endpoint
        user_message = message["text"]
        history.append([user_message, None])  # Append user message to history
        messages = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
        response = client_mistral.chat_completion(messages, max_tokens=200)
        bot_message = response["choices"][0]["message"]["content"]
        history[-1][1] = bot_message  # Update history with bot's response
        yield history

def generate_image(prompt):
    """Generates an image by calling a hosted Gradio Space."""
    # Spaces are driven with gradio_client.Client, not huggingface_hub's InferenceClient
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
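# Quick manual check (illustrative prompt): generate_image("a watercolor fox")
# returns whatever the Space's /image_gen_pro endpoint yields (typically an image filepath).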

# Gradio app setup with a sidebar and tabbed sections
with gr.Blocks(title="AI Chat & Tools", theme="soft") as demo:
    with gr.Sidebar():
        gr.Markdown("## AI Assistant Sidebar")
        gr.Markdown("Navigate through features and try them out.")
        chat_nav = gr.Button("Open Chat")
        image_nav = gr.Button("Generate Image")

    with gr.Tabs() as tabs:
        with gr.Tab("Chat Interface", id="chat"):
            chatbot = gr.Chatbot(label="Chat with AI Assistant", show_label=False)
            with gr.Row():
                text_input = gr.Textbox(placeholder="Enter your message...", lines=2, show_label=False)
                file_input = gr.File(label="Upload an image", file_types=["image"])

            def handle_text(text, history):
                """Stream a text-only response into the chatbot."""
                yield from respond({"text": text}, history or [])

            def handle_file(file, history):
                """Describe an uploaded image, streaming into the chatbot."""
                if file is None:  # the change event also fires when the upload is cleared
                    yield history or []
                    return
                yield from respond({"files": [file], "text": "Describe this image."}, history or [])

            # Connect callbacks
            text_input.submit(handle_text, [text_input, chatbot], [chatbot])
            file_input.change(handle_file, [file_input, chatbot], [chatbot])

        with gr.Tab("Generate Image", id="image"):
            gr.Markdown("### Image Generator")
            image_prompt = gr.Textbox(placeholder="Describe the image to generate", show_label=False)
            image_output = gr.Image(label="Generated Image")

            image_prompt.submit(generate_image, [image_prompt], [image_output])

    # Sidebar buttons switch the active tab by returning an updated gr.Tabs selection
    chat_nav.click(lambda: gr.Tabs(selected="chat"), None, tabs)
    image_nav.click(lambda: gr.Tabs(selected="image"), None, tabs)

# Launch Gradio app
if __name__ == "__main__":
    demo.launch()
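# Optional smoke test, run from a separate process once the app is up
# (assumes the default local URL):
#
#   from gradio_client import Client
#   remote = Client("http://127.0.0.1:7860")
#   remote.view_api()  # lists the callable endpoints and their parameters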