Mediocreatmybest committed
Commit 9f77cf7 · 1 Parent(s): 593f239

Update app.py

Files changed (1)
  1. app.py +13 -16
app.py CHANGED
@@ -1,6 +1,6 @@
-import torch
 import gradio as gr
 from transformers import pipeline
+import torch
 
 CAPTION_MODELS = {
     'blip-base': 'Salesforce/blip-image-captioning-base',
@@ -14,14 +14,11 @@ CAPTION_MODELS = {
 # Create a dictionary to store loaded models
 loaded_models = {}
 
-# Simple caption creation
-def caption_image(model_choice, image_input, url_input, load_in_8bit):
-    if image_input is not None:
-        input_data = image_input
-    else:
-        input_data = url_input
+# Modify caption_image to accept and process lists of images
+def caption_image(model_choice, images_input, urls_input, load_in_8bit, device):
+    input_data = images_input if all(i is not None for i in images_input) else urls_input
 
-    model_key = (model_choice, load_in_8bit)  # Create a tuple to represent the unique combination of model and 8bit loading
+    model_key = (model_choice, load_in_8bit, device)  # Update the model key to include the device
 
     # Check if the model is already loaded
     if model_key in loaded_models:
@@ -32,7 +29,7 @@ def caption_image(model_choice, image_input, url_input, load_in_8bit):
         captioner = pipeline(task="image-to-text",
                              model=CAPTION_MODELS[model_choice],
                              max_new_tokens=30,
-                             device='cpu',  # Set the device as CPU
+                             device=device,  # Set the device as selected by the user
                              model_kwargs=model_kwargs,
                              torch_dtype=dtype,  # Set the floating point
                              use_fast=True
@@ -40,16 +37,16 @@ def caption_image(model_choice, image_input, url_input, load_in_8bit):
         # Store the loaded model
         loaded_models[model_key] = captioner
 
-    caption = captioner(input_data)[0]['generated_text']
-    return str(caption).strip()
+    captions = captioner(input_data)  # Run the model on the batch of images
+    results = [str(caption['generated_text']).strip() for caption in captions]  # Extract the captions from the outputs
 
-def launch(model_choice, image_input, url_input, load_in_8bit):
-    return caption_image(model_choice, image_input, url_input, load_in_8bit)
+    return results
 
 model_dropdown = gr.Dropdown(choices=list(CAPTION_MODELS.keys()), label='Select Caption Model')
-image_input = gr.Image(type="pil", label="Input Image")
-url_input = gr.Text(label="Input URL")
+image_input = gr.Image(type="pil", label="Input Image")  # Now takes multiple images
+url_input = gr.Text(label="Input URL")  # Now takes multiple URLs
 load_in_8bit = gr.Checkbox(label="Load model in 8bit")
+device = gr.Radio(choices=['cpu', 'cuda'], label='Device')  # Radio button for device selection
 
-iface = gr.Interface(launch, inputs=[model_dropdown, image_input, url_input, load_in_8bit], outputs="text")
+iface = gr.Interface(caption_image, inputs=[model_dropdown, image_input, url_input, load_in_8bit, device], outputs=gr.outputs.Textbox(type="auto", label="Caption"))
 iface.launch()
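
For reference, the caching pattern these hunks converge on can be sketched in isolation: pipelines are memoized in a dictionary keyed on the (model, 8-bit, device) triple, so a configuration that has been built once is reused on later calls instead of reloading weights. This is a minimal sketch, not the Space's code: get_captioner is a hypothetical helper name, and the model_kwargs/dtype setup is an assumption, since those lines sit in context the diff elides between hunks.

import torch
from transformers import pipeline

CAPTION_MODELS = {
    'blip-base': 'Salesforce/blip-image-captioning-base',
    # further models elided in the diff
}

loaded_models = {}  # cache: (model_choice, load_in_8bit, device) -> pipeline

def get_captioner(model_choice, load_in_8bit, device):  # hypothetical helper
    model_key = (model_choice, load_in_8bit, device)
    if model_key not in loaded_models:
        # Assumed setup: 8-bit weights via bitsandbytes with float16
        # activations, plain float32 otherwise. Note that 8-bit loading is
        # dispatched by accelerate and can conflict with an explicit device.
        model_kwargs = {'load_in_8bit': True} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32
        loaded_models[model_key] = pipeline(task='image-to-text',
                                            model=CAPTION_MODELS[model_choice],
                                            max_new_tokens=30,
                                            device=device,
                                            model_kwargs=model_kwargs,
                                            torch_dtype=dtype)
    return loaded_models[model_key]

# Usage: the pipeline accepts a single PIL image or URL as well as a list.
# captioner = get_captioner('blip-base', False, 'cpu')
# out = captioner('http://images.cocodataset.org/val2017/000000039769.jpg')
# print(out[0]['generated_text'])

Keying the cache on the full triple matters: an 8-bit model on 'cuda' and a float32 model on 'cpu' are distinct objects, and a cache keyed on the model name alone would hand back whichever configuration happened to load first.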