File size: 3,680 Bytes
b6ab738
1afc2d5
 
72073e1
1afc2d5
 
 
 
 
 
 
 
 
9b4d9ff
1afc2d5
9b4d9ff
1afc2d5
 
 
 
 
 
9ede33d
1afc2d5
 
 
 
 
 
 
 
 
b6ab738
72073e1
9b4d9ff
72073e1
1afc2d5
72073e1
9b4d9ff
72073e1
b6ab738
1afc2d5
9b4d9ff
1afc2d5
 
b6ab738
9ede33d
 
 
 
 
9b4d9ff
 
9ede33d
 
9b4d9ff
 
 
9ede33d
 
 
9b4d9ff
9ede33d
9b4d9ff
 
9ede33d
 
9b4d9ff
 
9ede33d
 
9b4d9ff
 
 
 
2cb303a
9ede33d
 
 
 
9b4d9ff
 
9ede33d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from PIL import Image
from threading import Thread

# Initialize model and processor
# Llava-interleave 0.5B checkpoint, loaded once at import time and pinned
# to CPU (no GPU assumed anywhere in this file).
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")

# Remote text-chat backend. NOTE(review): the variable is named `client_gemma`
# but points at a Mistral model — confirm which backend is intended.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Functions for chat and image handling
def llava(inputs, history):
    """Prepare Llava model inputs from an uploaded image plus a text prompt.

    Args:
        inputs: dict with "files" (list of image file paths; first is used)
            and "text" (the user's message).
        history: chat history (unused here; kept for interface parity with
            `respond`).

    Returns:
        Encoded tensors (`return_tensors="pt"`) ready for `model.generate`,
        moved to CPU to match where the model lives.
    """
    image = Image.open(inputs["files"][0]).convert("RGB")
    # End the user turn AND open the assistant turn, so generation produces
    # a reply rather than continuing the user's message.
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|><|im_start|>assistant"
    processed = processor(prompt, image, return_tensors="pt").to("cpu")
    return processed

def respond(message, history):
    """Stream a reply for a (possibly multimodal) chat message.

    Args:
        message: dict with "text" and optionally "files" (image uploads).
        history: list of [user, bot] message pairs; mutated in the text-only
            branch.

    Yields:
        Image branch: the progressively growing response string.
        Text-only branch: the updated history (original output contract kept).
    """
    if message.get("files"):
        inputs = llava(message, history)
        # TextIteratorStreamer requires the tokenizer as its first argument
        # to decode generated token ids into text.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # generate() expects the encoded tensors unpacked as keyword
        # arguments, not a single `inputs=` mapping.
        generation_kwargs = dict(**inputs, max_new_tokens=512, streamer=streamer)
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
        thread.join()  # don't leave the worker thread dangling
    else:
        user_message = message["text"]
        history.append([user_message, None])
        # Only user turns are forwarded; bot turns (msg[1]) are dropped.
        prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
        response = client_gemma.chat_completion(prompt, max_tokens=200)
        bot_message = response["choices"][0]["message"]["content"]
        history[-1][1] = bot_message
        yield history

def generate_image(prompt):
    """Generate an image from a text prompt via the KingNish/Image-Gen-Pro Space.

    Args:
        prompt: text description of the desired image.

    Returns:
        The Space's prediction result (image path/data as returned by the
        `/image_gen_pro` endpoint).
    """
    # huggingface_hub.InferenceClient has no `predict` method; a Gradio
    # Space endpoint must be called through gradio_client.Client
    # (installed as a dependency of the already-imported `gradio`).
    from gradio_client import Client

    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")

# State management to control visibility
def show_page(page, state):
    """Return a visibility map for the selected page.

    Args:
        page: which page to show — "chat" or "image".
        state: current state dict (unused; kept for interface parity).

    Returns:
        dict with boolean "chat_visible" and "image_visible" flags; both are
        False when `page` names neither page.
    """
    chat_selected = page == "chat"
    image_selected = page == "image"
    return {"chat_visible": chat_selected, "image_visible": image_selected}

# Gradio app setup.
# The original passed `visible=lambda ...` to gr.Row (visible takes a bool,
# not a callable), used a non-existent Row `interactive` kwarg, and the nav
# buttons only mutated gr.State without updating any component — so page
# switching never worked. Pages are now named containers toggled directly
# via gr.update(visible=...).
with gr.Blocks(title="AI Chat & Tools") as demo:
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("## Navigation")
            chat_button = gr.Button("Chat Interface")
            image_button = gr.Button("Image Generation")

        with gr.Column(scale=3):
            # Chat page: visible by default.
            with gr.Column(visible=True) as chat_page:
                gr.Markdown("## Chat with AI Assistant")
                chatbot = gr.Chatbot(label="Chat", show_label=False)
                text_input = gr.Textbox(placeholder="Enter your message...", lines=2, show_label=False)
                # Gradio expects extensions or the "image" shorthand, not
                # MIME globs like "image/*".
                file_input = gr.File(label="Upload an image", file_types=["image"])
                text_input.submit(respond, [text_input, chatbot], [chatbot])
                file_input.change(respond, [file_input, chatbot], [chatbot])

            # Image-generation page: hidden until selected.
            with gr.Column(visible=False) as image_page:
                gr.Markdown("## Image Generator")
                image_prompt = gr.Textbox(placeholder="Describe the image to generate", show_label=False)
                image_output = gr.Image(label="Generated Image")
                image_prompt.submit(generate_image, [image_prompt], [image_output])

    # Button actions toggle which page container is visible.
    chat_button.click(
        lambda: (gr.update(visible=True), gr.update(visible=False)),
        None,
        [chat_page, image_page],
    )
    image_button.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [chat_page, image_page],
    )

# Launch the app
demo.launch()