zamal commited on
Commit
0568dda
·
verified ·
1 Parent(s): efdbe7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -36
app.py CHANGED
@@ -1,46 +1,47 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
3
  from PIL import Image
4
  import torch
5
  import subprocess
 
6
 
7
- # Run pip install command
8
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
9
 
 
 
10
 
11
  # Define the repository for the quantized model
12
  repo_name = "cyan2k/molmo-7B-D-bnb-4bit"
13
- arguments = {"device_map": "auto", "torch_dtype": "auto", "trust_remote_code": True}
14
 
15
- # Load the processor and quantized model
16
- processor = AutoProcessor.from_pretrained(repo_name, **arguments)
17
- model = AutoModelForCausalLM.from_pretrained(repo_name, **arguments)
 
18
 
19
  def process_image_and_text(image, text):
20
- # Process the image and text
21
- inputs = processor.process(
22
- images=[Image.fromarray(image)],
23
- text=text
24
- )
25
-
26
- # Move inputs to the same device as the model (GPU) and make a batch of size 1
27
- inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
28
-
29
- # Generate output
30
- output = model.generate(
31
- **inputs,
32
- max_new_tokens=200
33
- )
34
-
35
- # Only get generated tokens; decode them to text
36
- generated_text = processor.decode(output, skip_special_tokens=True)
37
  return generated_text
38
 
39
  def chatbot(image, text, history):
 
40
  if image is None:
41
  return history + [("Please upload an image first.", None)]
42
 
 
43
  response = process_image_and_text(image, text)
 
 
44
  history.append((text, response))
45
  return history
46
 
@@ -57,16 +58,9 @@ with gr.Blocks() as demo:
57
 
58
  state = gr.State([])
59
 
60
- submit_button.click(
61
- chatbot,
62
- inputs=[image_input, text_input, state],
63
- outputs=[chatbot_output]
64
- )
65
-
66
- text_input.submit(
67
- chatbot,
68
- inputs=[image_input, text_input, state],
69
- outputs=[chatbot_output]
70
- )
71
 
72
- demo.launch()
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoProcessor
3
  from PIL import Image
4
  import torch
5
  import subprocess
6
+ import os
7
 
8
+ # Set environment variable to skip CUDA build for flash-attn
9
+ os.environ["FLASH_ATTENTION_SKIP_CUDA_BUILD"] = "TRUE"
10
 
11
+ # Install flash-attn package
12
+ subprocess.run('pip install flash-attn --no-build-isolation', shell=True)
13
 
14
  # Define the repository for the quantized model
15
  repo_name = "cyan2k/molmo-7B-D-bnb-4bit"
 
16
 
17
+ # Load processor and model with GPU optimization
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ processor = AutoProcessor.from_pretrained(repo_name, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
20
+ model = AutoModelForCausalLM.from_pretrained(repo_name, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
21
 
22
  def process_image_and_text(image, text):
23
+ # Convert numpy image to PIL format
24
+ pil_image = Image.fromarray(image)
25
+
26
+ # Process image and text with processor
27
+ inputs = processor(images=[pil_image], text=text, return_tensors="pt").to(device)
28
+
29
+ # Generate output using the model
30
+ output = model.generate(**inputs, max_new_tokens=200)
31
+
32
+ # Decode the generated output
33
+ generated_text = processor.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
34
  return generated_text
35
 
36
  def chatbot(image, text, history):
37
+ # Check if the image is uploaded
38
  if image is None:
39
  return history + [("Please upload an image first.", None)]
40
 
41
+ # Get response by processing the image and text
42
  response = process_image_and_text(image, text)
43
+
44
+ # Append question and response to the chat history
45
  history.append((text, response))
46
  return history
47
 
 
58
 
59
  state = gr.State([])
60
 
61
+ # Connect the submit button and textbox to the chatbot function
62
+ submit_button.click(fn=chatbot, inputs=[image_input, text_input, state], outputs=chatbot_output)
63
+ text_input.submit(fn=chatbot, inputs=[image_input, text_input, state], outputs=chatbot_output)
 
 
 
 
 
 
 
 
64
 
65
+ # Launch the Gradio app with GPU
66
+ demo.launch(share=True)