TuringsSolutions commited on
Commit
8cd9c33
·
verified ·
1 Parent(s): 9ede33d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -38
app.py CHANGED
@@ -9,19 +9,20 @@ model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
9
  processor = LlavaProcessor.from_pretrained(model_id)
10
  model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")
11
 
 
12
  client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
13
 
14
- # Functions for chat and image handling
15
  def llava(inputs, history):
16
- """Processes image + text input with Llava."""
17
  image = Image.open(inputs["files"][0]).convert("RGB")
18
  prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
19
  processed = processor(prompt, image, return_tensors="pt").to("cpu")
20
  return processed
21
 
22
  def respond(message, history):
23
- """Generate a response for input."""
24
  if "files" in message and message["files"]:
 
25
  inputs = llava(message, history)
26
  streamer = TextIteratorStreamer(skip_prompt=True, skip_special_tokens=True)
27
  thread = Thread(target=model.generate, kwargs=dict(inputs=inputs, max_new_tokens=512, streamer=streamer))
@@ -31,53 +32,45 @@ def respond(message, history):
31
  buffer += new_text
32
  yield buffer
33
  else:
 
34
  user_message = message["text"]
35
- history.append([user_message, None])
 
 
36
  prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
37
  response = client_gemma.chat_completion(prompt, max_tokens=200)
 
 
38
  bot_message = response["choices"][0]["message"]["content"]
39
- history[-1][1] = bot_message
40
  yield history
41
 
42
  def generate_image(prompt):
43
- """Generates an image."""
44
  client = InferenceClient("KingNish/Image-Gen-Pro")
45
  return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
46
 
47
- # State management to control visibility
48
- def show_page(page, state):
49
- """Updates the state to show the selected page."""
50
- return {"chat_visible": page == "chat", "image_visible": page == "image"}
51
-
52
- # Gradio app setup
53
- with gr.Blocks(title="AI Chat & Tools") as demo:
54
- state = gr.State({"chat_visible": True, "image_visible": False})
55
-
56
  with gr.Row():
57
- with gr.Column(scale=1, min_width=200):
58
- gr.Markdown("## Navigation")
59
- chat_button = gr.Button("Chat Interface")
60
- image_button = gr.Button("Image Generation")
61
-
62
- with gr.Column(scale=3):
63
- with gr.Row(visible=lambda state: state["chat_visible"], interactive=True):
64
- gr.Markdown("## Chat with AI Assistant")
65
- chatbot = gr.Chatbot(label="Chat", show_label=False)
66
- text_input = gr.Textbox(placeholder="Enter your message...", lines=2, show_label=False)
67
- file_input = gr.File(label="Upload an image", file_types=["image/*"])
68
- text_input.submit(respond, [text_input, chatbot], [chatbot])
69
- file_input.change(respond, [file_input, chatbot], [chatbot])
70
 
71
- with gr.Row(visible=lambda state: state["image_visible"], interactive=True):
72
- gr.Markdown("## Image Generator")
73
- image_prompt = gr.Textbox(placeholder="Describe the image to generate", show_label=False)
74
- image_output = gr.Image(label="Generated Image")
75
- image_prompt.submit(generate_image, [image_prompt], [image_output])
76
 
77
- # Button actions to switch between pages
78
- chat_button.click(lambda: show_page("chat", state.value), None, state)
79
- image_button.click(lambda: show_page("image", state.value), None, state)
80
 
81
- # Launch the app
82
- demo.launch()
 
83
 
 
 
 
9
  processor = LlavaProcessor.from_pretrained(model_id)
10
  model = LlavaForConditionalGeneration.from_pretrained(model_id).to("cpu")
11
 
12
+ # Initialize inference clients
13
  client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
14
 
 
15
  def llava(inputs, history):
16
+ """Processes an image and text input using Llava."""
17
  image = Image.open(inputs["files"][0]).convert("RGB")
18
  prompt = f"<|im_start|>user <image>\n{inputs['text']}<|im_end|>"
19
  processed = processor(prompt, image, return_tensors="pt").to("cpu")
20
  return processed
21
 
22
  def respond(message, history):
23
+ """Generate a response based on text or image input."""
24
  if "files" in message and message["files"]:
25
+ # Handle image + text input
26
  inputs = llava(message, history)
27
  streamer = TextIteratorStreamer(skip_prompt=True, skip_special_tokens=True)
28
  thread = Thread(target=model.generate, kwargs=dict(inputs=inputs, max_new_tokens=512, streamer=streamer))
 
32
  buffer += new_text
33
  yield buffer
34
  else:
35
+ # Handle text-only input
36
  user_message = message["text"]
37
+ history.append([user_message, None]) # Append user message to history
38
+
39
+ # Prepare prompt for the language model
40
  prompt = [{"role": "user", "content": msg[0]} for msg in history if msg[0]]
41
  response = client_gemma.chat_completion(prompt, max_tokens=200)
42
+
43
+ # Extract response and update history
44
  bot_message = response["choices"][0]["message"]["content"]
45
+ history[-1][1] = bot_message # Update the last entry with bot's response
46
  yield history
47
 
48
  def generate_image(prompt):
49
+ """Generates an image based on the user prompt."""
50
  client = InferenceClient("KingNish/Image-Gen-Pro")
51
  return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
52
 
53
+ # Set up Gradio interface
54
+ with gr.Blocks() as demo:
55
+ chatbot = gr.Chatbot()
 
 
 
 
 
 
56
  with gr.Row():
57
+ with gr.Column():
58
+ text_input = gr.Textbox(placeholder="Enter your message...")
59
+ file_input = gr.File(label="Upload an image")
60
+ with gr.Column():
61
+ output = gr.Image(label="Generated Image")
 
 
 
 
 
 
 
 
62
 
63
+ def handle_text(text, history=[]):
64
+ """Handle text input and generate responses."""
65
+ return respond({"text": text}, history), history
 
 
66
 
67
+ def handle_file_upload(files, history=[]):
68
+ """Handle file uploads and generate responses."""
69
+ return respond({"files": files, "text": "Describe this image."}, history), history
70
 
71
+ # Connect components to callbacks
72
+ text_input.submit(handle_text, [text_input, chatbot], [chatbot])
73
+ file_input.change(handle_file_upload, [file_input, chatbot], [chatbot])
74
 
75
+ # Launch the Gradio app
76
+ demo.launch()