import gradio as gr
from gradio_client import Client
from huggingface_hub import InferenceClient
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from PIL import Image
from threading import Thread
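
# This app pairs a small local LLaVA model (for image understanding) with a hosted
# Mistral chat model (for text-only turns) behind a single Gradio chat interface.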

# Initialize model and processor
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")

# Initialize the hosted text-generation client (Mistral served through the Inference API)
client_mistral = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

def llava(inputs):
    """Prepare image + text inputs for the local LLaVA model."""
    image = Image.open(inputs["files"][0]).convert("RGB")
    # Chat template for llava-interleave-qwen; the trailing assistant tag cues generation
    prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|><|im_start|>assistant"
    return processor(images=image, text=prompt, return_tensors="pt").to("cpu")

def respond(message, history):
    """Generate a response based on text or image input."""
    history = history or []
    # Add the user's message with a placeholder for the assistant reply
    history.append([message["text"], None])

    if message.get("files"):
        # Image + text input: run the local LLaVA model and stream tokens back
        inputs = llava(message)
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=512, streamer=streamer))
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            history[-1][1] = buffer  # Update the latest assistant message as tokens arrive
            yield history

    else:
        # Text-only input: rebuild the conversation (user and assistant turns) for the hosted chat model
        messages = []
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})

        response = client_mistral.chat_completion(messages, max_tokens=200)
        history[-1][1] = response.choices[0].message.content  # Fill in the assistant reply
        yield history

def generate_image(prompt):
    """Generate an image from a text prompt via the KingNish/Image-Gen-Pro Space (not wired into the UI below)."""
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")

# Set up Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(placeholder="Enter your message...")
            file_input = gr.File(label="Upload an image", type="filepath")

    def handle_text(text, history=None):
        """Handle text input and stream responses to the chatbot."""
        yield from respond({"text": text}, history or [])

    def handle_file_upload(file_path, history=None):
        """Handle an uploaded image and stream a description to the chatbot."""
        yield from respond({"files": [file_path], "text": "Describe this image."}, history or [])

    # Connect components to callbacks
    text_input.submit(handle_text, [text_input, chatbot], [chatbot])
    file_input.change(handle_file_upload, [file_input, chatbot], [chatbot])

# Launch the Gradio app
demo.launch()