Thouph committed
Commit b5d466c · verified · 1 parent: 552a612

Create app.py

Files changed (1)
  app.py +337 -0
app.py ADDED
@@ -0,0 +1,337 @@
import json
from collections import defaultdict
import safetensors.torch
import timm
from transformers import AutoProcessor
import gradio as gr
import torch
import time
from florence2_implementation.modeling_florence2 import Florence2ForConditionalGeneration
from torchvision.transforms import InterpolationMode
from PIL import Image
import torchvision.transforms.functional as TF
from torchvision.transforms import transforms
import random
import csv
import os

torch.set_grad_enabled(False)

# HF now (Feb 20, 2025) imposes a 1 GB storage limit, so the JTP tagger weights have to be pulled from elsewhere.
os.system("wget -nv https://huggingface.co/spaces/RedRocket/JointTaggerProject-Inference-Beta/resolve/main/JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")
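
# A possible alternative to shelling out to wget, assuming huggingface_hub is
# available in the Space (repo id and filename taken from the URL above):
# from huggingface_hub import hf_hub_download
# hf_hub_download(
#     repo_id="RedRocket/JointTaggerProject-Inference-Beta",
#     repo_type="space",
#     filename="JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors",
#     local_dir=".",
# )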

category_id_to_str = {
    "0": "general",
    # 3 copyright
    "4": "character",
    "5": "species",
    "7": "meta",
    "8": "lore",
    "1": "artist",
}


class Pruner:
    def __init__(self, path_to_tag_list_csv):
        species_tags = set()
        allowed_tags = set()
        with open(path_to_tag_list_csv, "r") as f:
            reader = csv.reader(f)
            header = next(reader)
            name_index = header.index("name")
            category_index = header.index("category")
            post_count_index = header.index("post_count")
            for row in reader:
                if int(row[post_count_index]) > 20:
                    category = row[category_index]
                    name = row[name_index]
                    if category == "5":
                        species_tags.add(name)
                        allowed_tags.add(name)
                    elif category == "0":
                        allowed_tags.add(name)
                    elif category == "7":
                        allowed_tags.add(name)

        self.species_tags = species_tags
        self.allowed_tags = allowed_tags

    def _prune_not_allowed_tags(self, raw_tags):
        this_allowed_tags = set()
        for tag in raw_tags:
            if tag in self.allowed_tags:
                this_allowed_tags.add(tag)
        return this_allowed_tags

    def _find_and_format_species_tags(self, tag_set):
        this_specie_tags = []
        for tag in tag_set:
            if tag in self.species_tags:
                this_specie_tags.append(tag)

        formatted_tags = f"species: {' '.join(this_specie_tags)}\n"
        return formatted_tags, this_specie_tags

    def prompt_construction_pipeline_florence2(self, tags, length):
        if isinstance(tags, str):
            tags = tags.split(" ")
        random.shuffle(tags)
        tags = self._prune_not_allowed_tags(tags)
        formatted_species_tags, this_specie_tags = self._find_and_format_species_tags(tags)
        non_species_tags = [t for t in tags if t not in this_specie_tags]
        prompt = f"{' '.join(non_species_tags)}\n{formatted_species_tags}\nlength: {length}\n\nSTYLE1 FURRY CAPTION:"
        return prompt
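
# Illustrative only: for a hypothetical tag list such as
# ["canine", "male", "solo", "smiling"] and length=100, the prompt built by
# prompt_construction_pipeline_florence2 looks roughly like (tag order varies
# because the tags are shuffled and pruned into a set):
#
#   male solo smiling
#   species: canine
#
#   length: 100
#
#   STYLE1 FURRY CAPTION: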


class Fit(torch.nn.Module):
    def __init__(
        self,
        bounds: tuple[int, int] | int,
        interpolation=InterpolationMode.LANCZOS,
        grow: bool = True,
        pad: float | None = None,
    ):
        super().__init__()

        self.bounds = (bounds, bounds) if isinstance(bounds, int) else bounds
        self.interpolation = interpolation
        self.grow = grow
        self.pad = pad

    def forward(self, img: Image.Image) -> Image.Image:
        wimg, himg = img.size
        hbound, wbound = self.bounds

        hscale = hbound / himg
        wscale = wbound / wimg

        if not self.grow:
            hscale = min(hscale, 1.0)
            wscale = min(wscale, 1.0)

        scale = min(hscale, wscale)
        if scale == 1.0:
            return img

        hnew = min(round(himg * scale), hbound)
        wnew = min(round(wimg * scale), wbound)

        img = TF.resize(img, (hnew, wnew), self.interpolation)

        if self.pad is None:
            return img

        hpad = hbound - hnew
        wpad = wbound - wnew

        tpad = hpad // 2
        bpad = hpad - tpad

        lpad = wpad // 2
        rpad = wpad - lpad

        return TF.pad(img, (lpad, tpad, rpad, bpad), self.pad)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"bounds={self.bounds}, "
            f"interpolation={self.interpolation.value}, "
            f"grow={self.grow}, "
            f"pad={self.pad})"
        )
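
# Sketch of the intended behaviour (numbers illustrative): a 1000x500 image
# passed through Fit((384, 384)) is scaled by min(384/1000, 384/500) = 0.384
# down to 384x192, preserving aspect ratio; with a pad value set, it would
# then be padded out to the full 384x384 bounds.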


class CompositeAlpha(torch.nn.Module):
    def __init__(
        self,
        background: tuple[float, float, float] | float,
    ):
        super().__init__()

        self.background = (background, background, background) if isinstance(background, float) else background
        self.background = torch.tensor(self.background).unsqueeze(1).unsqueeze(2)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        if img.shape[-3] == 3:
            return img

        alpha = img[..., 3, None, :, :]

        img[..., :3, :, :] *= alpha

        background = self.background.expand(-1, img.shape[-2], img.shape[-1])
        if background.ndim == 1:
            background = background[:, None, None]
        elif background.ndim == 2:
            background = background[None, :, :]

        img[..., :3, :, :] += (1.0 - alpha) * background
        return img[..., :3, :, :]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"background={self.background})"
        )
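
# CompositeAlpha flattens RGBA onto a solid background with the standard
# "over" blend: out = rgb * alpha + background * (1 - alpha); with
# background=0.5 (as used below) transparent regions become mid-gray.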


class GatedHead(torch.nn.Module):
    def __init__(
        self,
        num_features: int,
        num_classes: int,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.linear = torch.nn.Linear(num_features, num_classes * 2)

        self.act = torch.nn.Sigmoid()
        self.gate = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        x = self.act(x[:, :self.num_classes]) * self.gate(x[:, self.num_classes:])
        return x
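
# For a batch of pooled features of shape (B, num_features), the head projects
# to (B, 2 * num_classes) and multiplies sigmoid(logits) by sigmoid(gates),
# so each of the 9083 tag scores lands in [0, 1] before thresholding below.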


model_id = "lodestone-horizon/furrence2-large"
model = Florence2ForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained("./florence2_implementation/", trust_remote_code=True)


# Tag implication tree: maps each consequent tag to the tags that imply it.
tree = defaultdict(list)
with open('tag_implications-2024-05-05.csv', 'rt') as csvfile:
    reader = csv.DictReader(csvfile)
    for row in reader:
        if row["status"] == "active":
            tree[row["consequent_name"]].append(row["antecedent_name"])


title = """<h1 align="center">Furrence2 Captioner Demo</h1>"""
description = (
    """<br> The captioner is prompted by the JTP PILOT2 tagger. You may use hand-curated tags to get better results.
    <br> This demo is running on CPU. For faster inference without waiting in the queue, you may duplicate the Space and upgrade to a GPU in the settings."""
)
tagger_transform = transforms.Compose([
    Fit((384, 384)),
    transforms.ToTensor(),
    CompositeAlpha(0.5),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    transforms.CenterCrop((384, 384)),
])

THRESHOLD = 0.2
tagger_model = timm.create_model(
    "vit_so400m_patch14_siglip_384.webli",
    pretrained=False,
    num_classes=9083,
)  # type: VisionTransformer
# min(weight.shape) recovers the backbone feature width (the Linear weight is
# (num_classes, num_features)), which GatedHead needs as its input size.
tagger_model.head = GatedHead(min(tagger_model.head.weight.shape), 9083)
safetensors.torch.load_model(tagger_model, "JTP_PILOT2-2-e3-vit_so400m_patch14_siglip_384.safetensors")

tagger_model.eval()

with open("JTP_PILOT2_tags.json", "r") as file:
    tags = json.load(file)  # type: dict
allowed_tags = list(tags.keys())

pruner = Pruner("tags-2024-05-05.csv")


def generate_prompt(image, expected_caption_length):
    global THRESHOLD, tree, model, tagger_model, tagger_transform
    tagger_input = tagger_transform(image.convert('RGBA')).unsqueeze(0)
    probabilities = tagger_model(tagger_input)
    for prob in probabilities:
        indices = torch.where(prob > THRESHOLD)[0]
        sorted_indices = torch.argsort(prob[indices], descending=True)
        final_tags = []
        for i in sorted_indices:
            final_tags.append(allowed_tags[indices[i]])

    final_tags = " ".join(final_tags)
    task_prompt = pruner.prompt_construction_pipeline_florence2(final_tags, expected_caption_length)
    return task_prompt


def inference_caption(image, expected_caption_length, seq_len=512):
    start_time = time.time()
    prompt_input = generate_prompt(image, expected_caption_length)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Finished tagging in {execution_time:.3f} seconds")
    try:
        pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"]
        encoder_inputs = processor.tokenizer(
            text=prompt_input,
            return_tensors="pt",
            # padding="max_length", truncation=True, and max_length=256 are
            # deliberately not set; they cause problems during inference.
        )
        start_time = time.time()
        generated_ids = model.generate(
            input_ids=encoder_inputs["input_ids"],
            attention_mask=encoder_inputs["attention_mask"],
            pixel_values=pixel_values,
            max_new_tokens=seq_len,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        end_time = time.time()
        execution_time = end_time - start_time
        print(f"Finished captioning in {execution_time:.3f} seconds")
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
    except Exception as e:
        print("error message:", e)
        return "An error occurred."


def main():
    with gr.Blocks() as iface:
        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil")

                seq_len = gr.Number(
                    value=512, label="Output Cutoff Length", precision=0,
                    interactive=True,
                )

                expected_length = gr.Number(
                    minimum=50, maximum=200,
                    value=100, label="Expected Caption Length", precision=0,
                    interactive=True,
                )

            with gr.Column(scale=1):
                with gr.Column():
                    caption_button = gr.Button(
                        value="Caption it!", interactive=True, variant="primary",
                    )

                    caption_output = gr.Textbox(lines=1, label="Caption Output")
                    caption_button.click(
                        inference_caption,
                        [
                            image_input,
                            expected_length,
                            seq_len,
                        ],
                        [caption_output],
                    )

    iface.launch(share=False)


if __name__ == "__main__":
    main()