zamal committed on
Commit bcb49b1 · verified · 1 Parent(s): 2b7b95f

Update app.py

Files changed (1)
  app.py: +31 -27
app.py CHANGED
@@ -2,38 +2,49 @@ import gradio as gr
  from transformers import AutoModelForCausalLM, AutoProcessor
  from PIL import Image
  import torch
+ import os
+
+ # Set environment variable to skip the CUDA build for flash-attn
+ os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"

  # Define the repository for the quantized model
  repo_name = "cyan2k/molmo-7B-D-bnb-4bit"
- arguments = {"device_map": "auto", "torch_dtype": torch.float16, "trust_remote_code": True}

- # Load the processor and quantized model
- processor = AutoProcessor.from_pretrained(repo_name, **arguments)
- model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments)
+ # Load the processor and model with GPU optimization
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
+
+ # Load the model with 4-bit quantization
+ model = AutoModelForCausalLM.from_pretrained(repo_name,
+                                              device_map="auto",
+                                              torch_dtype=torch.float16,
+                                              load_in_4bit=True,
+                                              trust_remote_code=True)
+ # Note: device_map="auto" already places the model; .to() is not supported for 4-bit models

  def process_image_and_text(image, text):
-     # Process the image and text
-     inputs = processor(
-         images=[Image.fromarray(image)],
-         text=text,
-         return_tensors="pt"
-     )
-
-     # Move inputs to the same device as the model (GPU)
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-     # Generate output
+     # Convert the numpy image from Gradio to PIL format
+     pil_image = Image.fromarray(image)
+
+     # Process the image and text, moving the tensors to the device
+     inputs = processor(images=[pil_image], text=text, return_tensors="pt").to(device)
+
+     # Generate output using the model
      output = model.generate(**inputs, max_new_tokens=200)

-     # Only get generated tokens; decode them to text
-     generated_text = processor.batch_decode(output, skip_special_tokens=True)[0]
+     # Decode the generated output
+     generated_text = processor.decode(output[0], skip_special_tokens=True)
      return generated_text

  def chatbot(image, text, history):
+     # Check that an image has been uploaded
      if image is None:
          return history + [("Please upload an image first.", None)]

+     # Get a response by processing the image and text
      response = process_image_and_text(image, text)
+
+     # Append the question and response to the chat history
      history.append((text, response))
      return history

@@ -50,16 +61,9 @@ with gr.Blocks() as demo:
      state = gr.State([])

-     submit_button.click(
-         chatbot,
-         inputs=[image_input, text_input, state],
-         outputs=[chatbot_output]
-     )
-
-     text_input.submit(
-         chatbot,
-         inputs=[image_input, text_input, state],
-         outputs=[chatbot_output]
-     )
+     # Connect the submit button and the textbox to the chatbot function
+     submit_button.click(fn=chatbot, inputs=[image_input, text_input, state], outputs=chatbot_output)
+     text_input.submit(fn=chatbot, inputs=[image_input, text_input, state], outputs=chatbot_output)

- demo.launch()
+ # Launch the Gradio app with a public share link
+ demo.launch(share=True)
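
A note on the loading block in this commit: recent transformers releases deprecate the bare load_in_4bit=True flag in favor of an explicit BitsAndBytesConfig, and bitsandbytes-quantized models cannot be moved with .to() once device_map="auto" has placed them. A minimal sketch of an equivalent load under those assumptions (the repo name comes from the commit; the compute-dtype choice is an assumption):

    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

    repo_name = "cyan2k/molmo-7B-D-bnb-4bit"

    # Assumed quantization settings; not taken from the commit
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    processor = AutoProcessor.from_pretrained(repo_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        repo_name,
        device_map="auto",  # accelerate places the weights; no model.to() call needed
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

Since this checkpoint is already stored in bnb 4-bit form, the explicit quantization_config may be redundant; the loader can pick it up from the checkpoint's own config.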
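
Also, processor.decode(output[0], skip_special_tokens=True) decodes the whole sequence, prompt included, so the chat reply would echo the question. If only the completion is wanted, a common pattern (assuming the processor output carries an input_ids tensor, as the model.generate(**inputs, ...) call implies) is to slice off the prompt tokens first:

    # Number of prompt tokens fed to generate()
    prompt_len = inputs["input_ids"].shape[1]

    # Decode only the newly generated tokens
    generated_text = processor.decode(output[0][prompt_len:], skip_special_tokens=True)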
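
Finally, the event wiring relies on chatbot() mutating the gr.State list in place, and the textbox is never cleared after a submit. A sketch of a more explicit wiring, reusing the component names from the diff (returning the history to both the Chatbot and the State, plus an empty string to reset the input, is an assumption about the intended UX):

    def chatbot(image, text, history):
        if image is None:
            history = history + [(text, "Please upload an image first.")]
        else:
            history = history + [(text, process_image_and_text(image, text))]
        # Update the chat window and the stored history, and clear the textbox
        return history, history, ""

    submit_button.click(fn=chatbot,
                        inputs=[image_input, text_input, state],
                        outputs=[chatbot_output, state, text_input])
    text_input.submit(fn=chatbot,
                      inputs=[image_input, text_input, state],
                      outputs=[chatbot_output, state, text_input])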