Mediocreatmybest committed on
Commit bd9b9f2 · 1 Parent(s): 9f77cf7

Update app.py

Files changed (1):
app.py +15 -14
app.py CHANGED
@@ -1,6 +1,6 @@
+import torch
 import gradio as gr
 from transformers import pipeline
-import torch
 
 CAPTION_MODELS = {
     'blip-base': 'Salesforce/blip-image-captioning-base',
@@ -14,11 +14,14 @@ CAPTION_MODELS = {
 # Create a dictionary to store loaded models
 loaded_models = {}
 
-# Modify caption_image to accept and process lists of images
-def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
-    input_data = images_input if all(i is not None for i in images_input) else urls_input
+# Simple caption creation
+def caption_image(model_choice, image_input, url_input, load_in_8bit, device):
+    if image_input is not None:
+        input_data = image_input
+    else:
+        input_data = url_input
 
-    model_key = (model_choice, load_in_8bit, device)  # Update the model key to include the device
+    model_key = (model_choice, load_in_8bit)  # Create a tuple to represent the unique combination of model and 8bit loading
 
     # Check if the model is already loaded
     if model_key in loaded_models:
@@ -29,7 +32,7 @@ def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
         captioner = pipeline(task="image-to-text",
                              model=CAPTION_MODELS[model_choice],
                              max_new_tokens=30,
-                             device=device,  # Set the device as selected by the user
+                             device=device,  # Set the device as selected
                              model_kwargs=model_kwargs,
                              torch_dtype=dtype,  # Set the floating point
                              use_fast=True
@@ -37,16 +40,14 @@ def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
         # Store the loaded model
         loaded_models[model_key] = captioner
 
-    captions = captioner(input_data)  # Run the model on the batch of images
-    results = [str(caption['generated_text']).strip() for caption in captions]  # Extract the captions from the outputs
-
-    return results
+    caption = captioner(input_data)
+    return [str(c['generated_text']).strip() for c in caption]
 
 model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
-image_input = gr.Image(type="pil", label="Input Image")  # Now takes multiple images
-url_input = gr.Text(label="Input URL")  # Now takes multiple URLs
+image_input = gr.Image(type="pil", label="Input Image", multiple=True)  # Enable multiple inputs
+url_input = gr.Text(label="Input URL")
 load_in_8bit = gr.Checkbox(label="Load model in 8bit")
-device = gr.Radio(choices=['cpu', 'cuda'], label='Device')  # Radio button for device selection
+device = gr.Radio(['cpu', 'cuda'], label='Select device', default='cpu')
 
-iface = gr.Interface(caption_image, inputs=[model_dropdown, image_input, url_input, load_in_8bit, device], outputs=gr.outputs.Textbox(type="auto", label="Caption"))
+iface = gr.Interface(caption_image, inputs=[model_dropdown, image_input, url_input, load_in_8bit, device], outputs=gr.interfaces.outputs.Textbox(type="text", label="Caption"))
 iface.launch()
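For context, here is the caching pattern this commit lands on, as a self-contained sketch. The cache-hit branch, the model_kwargs/dtype setup, and every CAPTION_MODELS entry past the first sit in lines the diff elides, so those parts below are assumptions, not the committed code:

    import torch
    from transformers import pipeline

    CAPTION_MODELS = {
        'blip-base': 'Salesforce/blip-image-captioning-base',
        # further entries elided by the diff
    }

    # Cache of loaded pipelines, keyed by (model_choice, load_in_8bit)
    loaded_models = {}

    def caption_image(model_choice, image_input, url_input, load_in_8bit, device):
        # The image-to-text pipeline accepts a PIL image or an image URL string
        input_data = image_input if image_input is not None else url_input

        model_key = (model_choice, load_in_8bit)
        if model_key in loaded_models:
            captioner = loaded_models[model_key]  # assumed cache-hit branch (elided by the diff)
        else:
            # Assumed mapping from the 8bit checkbox; 8-bit loading also needs CUDA and bitsandbytes
            model_kwargs = {'load_in_8bit': True} if load_in_8bit else {}
            dtype = torch.float16 if load_in_8bit else torch.float32
            captioner = pipeline(task="image-to-text",
                                 model=CAPTION_MODELS[model_choice],
                                 max_new_tokens=30,
                                 device=device,
                                 model_kwargs=model_kwargs,
                                 torch_dtype=dtype,
                                 use_fast=True)
            loaded_models[model_key] = captioner

        caption = captioner(input_data)
        return [str(c['generated_text']).strip() for c in caption]

One design consequence: dropping device from model_key means a pipeline first loaded for 'cpu' is reused even after the user switches the radio button to 'cuda'; the previous revision's key of (model_choice, load_in_8bit, device) avoided serving a cached model on the wrong device.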
 
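A caveat on the new UI wiring: gr.Image accepts no multiple= argument, gr.Radio in Gradio 3.x takes value= rather than default=, and gr.interfaces.outputs is not a Gradio namespace (the 3.x spelling is gr.outputs.Textbox, or simply gr.Textbox). A minimal sketch of equivalent wiring that should run under Gradio 3.x:

    import gradio as gr

    model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
    image_input = gr.Image(type="pil", label="Input Image")  # single image; no multiple= parameter exists
    url_input = gr.Text(label="Input URL")
    load_in_8bit = gr.Checkbox(label="Load model in 8bit")
    device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')  # value=, not default=, in 3.x

    iface = gr.Interface(caption_image,
                         inputs=[model_dropdown, image_input, url_input, load_in_8bit, device],
                         outputs=gr.Textbox(label="Caption"))
    iface.launch()

Since caption_image returns a list, the single Textbox will display the list's string form; returning caption[0]['generated_text'].strip() instead would show a clean caption.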