|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer |
|
from PIL import Image |
|
from threading import Thread |
|
|
|
|
|
# Checkpoint used for the local image+text model; the processor and the model
# weights are both loaded from this same identifier.
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)
model = model.to("cpu")  # inference is pinned to the CPU

# Remote client used for text-only chat turns.
# NOTE: despite the variable name, the client targets a Mistral checkpoint.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
|
|
|
def llava(inputs):
    """Build the Llava model inputs for one image plus a text instruction.

    Args:
        inputs: Mapping with a ``"files"`` entry (either a single image
            path/file object or a list/tuple whose first element is one) and
            a ``"text"`` entry holding the user's instruction.

    Returns:
        The output of the module-level ``processor`` (PyTorch tensors) moved
        to the CPU.
    """
    files = inputs["files"]
    # A single-file upload arrives as one path rather than a list; indexing a
    # path string with [0] would yield its first character.
    source = files[0] if isinstance(files, (list, tuple)) else files

    # Close the underlying file as soon as the pixels are converted;
    # ``convert`` returns an independent in-memory image.
    with Image.open(source) as raw_image:
        image = raw_image.convert("RGB")

    # End with the assistant header so generation starts a fresh assistant
    # turn instead of continuing the user's message.
    prompt = (
        f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
        "<|im_start|>assistant"
    )

    # Keyword arguments keep the call independent of the processor's
    # positional (text, images) ordering.
    processed = processor(text=prompt, images=image, return_tensors="pt")
    return processed.to("cpu")
|
|
|
def _history_to_messages(history):
    """Convert ``[user, bot]`` history pairs into role-tagged chat messages."""
    messages = []
    for turn in history:
        user_text, bot_text = turn[0], turn[1]
        if not user_text:
            # Skip turns without a user message so roles keep alternating.
            continue
        messages.append({"role": "user", "content": user_text})
        if bot_text:
            messages.append({"role": "assistant", "content": bot_text})
    return messages


def respond(message, history):
    """Generate a chat reply for a text message or an uploaded image.

    This is a generator. When ``message`` carries files, the local Llava
    model is run in a background thread and the growing reply is yielded as
    it streams in. Otherwise the conversation is sent to the remote chat
    client and the finished reply is yielded once.

    Args:
        message: Mapping with an optional ``"files"`` entry and a ``"text"``
            entry.
        history: List of ``[user_text, bot_text]`` pairs; it is updated in
            place. ``None`` is treated as an empty conversation.

    Yields:
        ``(history, history)`` tuples, matching the two outputs the UI binds.
    """
    if history is None:
        history = []

    files = message.get("files")
    text = message.get("text") or ""

    if files:
        # Build the inputs first so a bad upload does not leave a dangling turn.
        model_inputs = llava(message)
        history.append([text, None])

        # The streamer needs the tokenizer to decode generated ids into text.
        streamer = TextIteratorStreamer(
            processor.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        # ``generate`` expects the processor output unpacked into keyword
        # tensors (input_ids, pixel_values, ...), not passed as one object.
        generation_kwargs = {
            **model_inputs,
            "max_new_tokens": 512,
            "streamer": streamer,
        }
        worker = Thread(target=model.generate, kwargs=generation_kwargs)
        worker.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            history[-1][1] = buffer
            yield history, history
        worker.join()
        return

    if not text.strip():
        # Nothing to answer: leave the conversation untouched.
        yield history, history
        return

    history.append([text, None])
    # Send both sides of the conversation so the remote model sees its own
    # earlier replies and the roles alternate user/assistant.
    response = client_gemma.chat_completion(
        _history_to_messages(history), max_tokens=200
    )
    history[-1][1] = response["choices"][0]["message"]["content"]
    yield history, history
|
|
|
def generate_image(prompt):
    """Request an image for ``prompt`` from the "KingNish/Image-Gen-Pro" endpoint.

    Args:
        prompt: Text description of the image to generate.

    Returns:
        Whatever the client's ``predict`` call returns for the
        ``/image_gen_pro`` API.
    """
    # NOTE(review): ``predict(..., api_name=...)`` looks like the
    # gradio_client ``Client`` interface; confirm that ``InferenceClient``
    # actually exposes it before relying on this function.
    image_client = InferenceClient("KingNish/Image-Gen-Pro")
    call_args = ("Image Generation", None, prompt)
    return image_client.predict(*call_args, api_name="/image_gen_pro")
|
|
|
|
|
# Chat UI: one chatbot pane fed by a text box and an image upload widget.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(placeholder="Enter your message...")
            file_input = gr.File(label="Upload an image")

    def handle_text(text, history=None):
        """Stream chatbot updates for a submitted text message.

        Written as a generator function (``yield from``) so that Gradio
        iterates the updates produced by ``respond`` instead of receiving an
        unconsumed generator object as the output value.
        """
        history = [] if history is None else history
        yield from respond({"text": text}, history)

    def handle_file_upload(files, history=None):
        """Stream chatbot updates describing a newly uploaded image."""
        history = [] if history is None else history
        if not files:
            # The change event also fires when the upload is cleared; there
            # is no image to describe, so keep the conversation as it is.
            yield history, history
            return
        # ``respond`` reads the image as the first element of a list, while a
        # single-file upload delivers just one value.
        file_list = list(files) if isinstance(files, (list, tuple)) else [files]
        yield from respond(
            {"files": file_list, "text": "Describe this image."}, history
        )

    # ``respond`` yields (history, history), hence the two chatbot outputs.
    text_input.submit(handle_text, [text_input, chatbot], [chatbot, chatbot])
    file_input.change(handle_file_upload, [file_input, chatbot], [chatbot, chatbot])

demo.launch()