selamw committed on
Commit
3f22c56
·
verified ·
1 Parent(s): facd6c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -5
app.py CHANGED
@@ -5,9 +5,12 @@ import spaces
5
  import torch
6
  import os
7
 
 
 
 
8
 
9
  access_token = os.getenv('HF_token')
10
- model_id = "selamw/BirdWatcher"
11
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
12
 
13
 
@@ -41,11 +44,18 @@ def convert_to_markdown(input_text):
41
 
42
  @spaces.GPU
43
  def infer_fin_pali(image, question):
44
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
- model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
47
- processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
 
48
 
 
 
 
 
 
 
49
 
50
  inputs = processor(images=image, text=question, return_tensors="pt").to(device)
51
 
@@ -110,4 +120,118 @@ with gr.Blocks(css=css) as demo:
110
  label='Examples 👇'
111
  )
112
 
113
- demo.launch(debug=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch
6
  import os
7
 
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+
10
+
11
 
12
  access_token = os.getenv('HF_token')
13
+ model_id = "selamw/BirdWatcher2"
14
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
15
 
16
 
 
44
 
45
  @spaces.GPU
46
  def infer_fin_pali(image, question):
47
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
49
+ # model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
50
+ # processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
51
+
52
 
53
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
55
+
56
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, trust_remote_code=True, quantization_config=bnb_config,token=access_token).to(device)
57
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, token=access_token)
58
+ ###
59
 
60
  inputs = processor(images=image, text=question, return_tensors="pt").to(device)
61
 
 
120
  label='Examples 👇'
121
  )
122
 
123
+ demo.launch(debug=True, share=True)
124
+
125
+ # import gradio as gr
126
+ # from PIL import Image
127
+ # from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor
128
+ # import spaces
129
+ # import torch
130
+ # import os
131
+
132
+
133
+ # access_token = os.getenv('HF_token')
134
+ # model_id = "selamw/BirdWatcher"
135
+ # bnb_config = BitsAndBytesConfig(load_in_8bit=True)
136
+
137
+
138
+ # def convert_to_markdown(input_text):
139
+ # """Converts bird information text to Markdown format,
140
+ # making specific keywords bold and adding headings.
141
+ # Args:
142
+ # input_text (str): The input text containing bird information.
143
+ # Returns:
144
+ # str: The formatted Markdown text.
145
+ # """
146
+
147
+ # bold_words = ['Look:', 'Cool Fact!:', 'Habitat:', 'Food:', 'Birdie Behaviors:']
148
+
149
+ # # Split into title and content based on the first ":", handling extra whitespace
150
+ # if ":" in input_text:
151
+ # title, content = map(str.strip, input_text.split(":", 1))
152
+ # else:
153
+ # title = input_text
154
+ # content = ""
155
+
156
+ # # Bold the keywords
157
+ # for word in bold_words:
158
+ # content = content.replace(word, f'\n\n**{word}')
159
+
160
+ # # Construct the Markdown output with headings
161
+ # formatted_output = f"**{title}**{content}"
162
+
163
+ # return formatted_output.strip()
164
+
165
+
166
+ # @spaces.GPU
167
+ # def infer_fin_pali(image, question):
168
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
169
+
170
+ # model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
171
+ # processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
172
+
173
+
174
+ # inputs = processor(images=image, text=question, return_tensors="pt").to(device)
175
+
176
+ # predictions = model.generate(**inputs, max_new_tokens=512)
177
+ # decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
178
+
179
+ # # Ensure proper Markdown formatting
180
+ # formatted_output = convert_to_markdown(decoded_output)
181
+
182
+ # return formatted_output
183
+
184
+
185
+ # css = """
186
+ # #mkd {
187
+ # height: 500px;
188
+ # overflow: auto;
189
+ # border: 1px solid #ccc;
190
+ # }
191
+ # h1 {
192
+ # text-align: center;
193
+ # }
194
+ # h3 {
195
+ # text-align: center;
196
+ # }
197
+ # h2 {
198
+ # text-align: center;
199
+ # }
200
+ # span.gray-text {
201
+ # color: gray;
202
+ # }
203
+ # """
204
+
205
+ # with gr.Blocks(css=css) as demo:
206
+ # gr.HTML("<h1>🦩 BirdWatcher 🦜</h1>")
207
+ # gr.HTML("<h3>[Powered by Fine-tuned PaliGemma]</h3>")
208
+ # gr.HTML("<h3>Upload an image of a bird, and the model will generate a detailed description of its species.</h3>")
209
+ # gr.HTML("<p style='text-align: center;'>(There are over 11,000 bird species in the world, and this model was fine-tuned with over 500)</p>")
210
+
211
+ # with gr.Tab(label="Bird Identification"):
212
+ # with gr.Row():
213
+ # input_img = gr.Image(label="Input Bird Image")
214
+ # with gr.Column():
215
+ # with gr.Row():
216
+ # question = gr.Text(label="Default Prompt", value="Describe this bird species", elem_id="default-prompt", interactive=True)
217
+ # with gr.Row():
218
+ # submit_btn = gr.Button(value="Run")
219
+ # with gr.Row():
220
+ # output = gr.Markdown(label="Response") # Use Markdown component to display output
221
+
222
+ # submit_btn.click(infer_fin_pali, [input_img, question], [output])
223
+
224
+ # gr.Examples(
225
+ # [["01.jpg", "Describe this bird species"],
226
+ # ["02.jpg", "Describe this bird species"],
227
+ # ["03.jpg", "Describe this bird species"],
228
+ # ["04.jpg", "Describe this bird species"],
229
+ # ["05.jpg", "Describe this bird species"],
230
+ # ["06.jpg", "Describe this bird species"]],
231
+ # inputs=[input_img, question],
232
+ # outputs=[output],
233
+ # fn=infer_fin_pali,
234
+ # label='Examples 👇'
235
+ # )
236
+
237
+ # demo.launch(debug=True, share=True)