prithivMLmods commited on
Commit
3a77e9a
·
verified ·
1 Parent(s): bf6d8a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py CHANGED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import AutoProcessor, AutoModelForCausalLM
6
+
7
+ # Attempt to install flash-attn
8
+ try:
9
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, check=True, shell=True)
10
+ except subprocess.CalledProcessError as e:
11
+ print(f"Error installing flash-attn: {e}")
12
+ print("Continuing without flash-attn.")
13
+
14
+ # Determine the device to use
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Load the base model and processor
18
+ try:
19
+ vision_language_model_base = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
20
+ vision_language_processor_base = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
21
+ except Exception as e:
22
+ print(f"Error loading base model: {e}")
23
+ vision_language_model_base = None
24
+ vision_language_processor_base = None
25
+
26
+ # Load the large model and processor
27
+ try:
28
+ vision_language_model_large = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True).to(device).eval()
29
+ vision_language_processor_large = AutoProcessor.from_pretrained('microsoft/Florence-2-large', trust_remote_code=True)
30
+ except Exception as e:
31
+ print(f"Error loading large model: {e}")
32
+ vision_language_model_large = None
33
+ vision_language_processor_large = None
34
+
35
+ def describe_image(uploaded_image, model_choice):
36
+ """
37
+ Generates a detailed description of the input image using the selected model.
38
+
39
+ Args:
40
+ uploaded_image (PIL.Image.Image): The image to describe.
41
+ model_choice (str): The model to use, either "Base" or "Large".
42
+
43
+ Returns:
44
+ str: A detailed textual description of the image or an error message.
45
+ """
46
+ if uploaded_image is None:
47
+ return "Please upload an image."
48
+
49
+ if model_choice == "Base":
50
+ if vision_language_model_base is None:
51
+ return "Base model failed to load."
52
+ model = vision_language_model_base
53
+ processor = vision_language_processor_base
54
+ elif model_choice == "Large":
55
+ if vision_language_model_large is None:
56
+ return "Large model failed to load."
57
+ model = vision_language_model_large
58
+ processor = vision_language_processor_large
59
+ else:
60
+ return "Invalid model choice."
61
+
62
+ if not isinstance(uploaded_image, Image.Image):
63
+ uploaded_image = Image.fromarray(uploaded_image)
64
+
65
+ inputs = processor(text="<MORE_DETAILED_CAPTION>", images=uploaded_image, return_tensors="pt").to(device)
66
+ with torch.no_grad():
67
+ generated_ids = model.generate(
68
+ input_ids=inputs["input_ids"],
69
+ pixel_values=inputs["pixel_values"],
70
+ max_new_tokens=1024,
71
+ early_stopping=False,
72
+ do_sample=False,
73
+ num_beams=3,
74
+ )
75
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
76
+ processed_description = processor.post_process_generation(
77
+ generated_text,
78
+ task="<MORE_DETAILED_CAPTION>",
79
+ image_size=(uploaded_image.width, uploaded_image.height)
80
+ )
81
+ image_description = processed_description["<MORE_DETAILED_CAPTION>"]
82
+ print("\nImage description generated!:", image_description)
83
+ return image_description
84
+
85
+ # Description for the interface
86
+ description = "Select the model to use for generating the image description. 'Base' is smaller and faster, while 'Large' is more accurate but slower."
87
+ if device == "cpu":
88
+ description += " Note: Running on CPU, which may be slow for large models."
89
+
90
+ # Create the Gradio interface
91
+ image_description_interface = gr.Interface(
92
+ fn=describe_image,
93
+ inputs=[
94
+ gr.Image(label="Upload Image", type="pil"),
95
+ gr.Radio(["Base", "Large"], label="Model Choice", value="Base")
96
+ ],
97
+ outputs=gr.Textbox(label="Generated Caption", lines=4, show_copy_button=True),
98
+ live=False,
99
+ title="# **[Florence-2 Models Image Captions](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**",
100
+ theme=="bethecloud/storj_theme",
101
+ description=description
102
+ )
103
+
104
+ # Launch the interface
105
+ image_description_interface.launch(debug=True, ssr_mode=False)