merve HF staff commited on
Commit
530cb47
·
1 Parent(s): f3032c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -1,22 +1,38 @@
1
  import torch
2
- from transformers import pipeline
 
3
  import gradio as gr
4
 
5
 
6
  siglip_checkpoint = "nielsr/siglip-base-patch16-224"
7
  clip_checkpoint = "openai/clip-vit-base-patch16"
8
- siglip_detector = pipeline(model=siglip_checkpoint, task="zero-shot-image-classification")
9
  clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
 
 
 
10
 
11
  def postprocess(output):
12
  return {out["label"]: float(out["score"]) for out in output}
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def infer(image, candidate_labels):
16
  candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
17
- siglip_out = siglip_detector(image, candidate_labels=candidate_labels)
18
  clip_out = clip_detector(image, candidate_labels=candidate_labels)
19
- return postprocess(clip_out), postprocess(siglip_out)
20
 
21
 
22
  with gr.Blocks() as demo:
 
1
  import torch
2
+ from transformers import pipeline, SiglipModel, AutoProcessor
3
+ import numpy as np
4
  import gradio as gr
5
 
6
 
7
  siglip_checkpoint = "nielsr/siglip-base-patch16-224"
8
  clip_checkpoint = "openai/clip-vit-base-patch16"
 
9
  clip_detector = pipeline(model=clip_checkpoint, task="zero-shot-image-classification")
10
+ siglip_model = SiglipModel.from_pretrained("nielsr/siglip-base-patch16-224")
11
+ siglip_processor = AutoProcessor.from_pretrained("nielsr/siglip-base-patch16-224")
12
+
13
 
14
  def postprocess(output):
15
  return {out["label"]: float(out["score"]) for out in output}
16
 
17
+ def postprocess_siglip(output, labels):
18
+ return {labels[i]: float(np.array(output[0])[i]) for i in range(len(labels))}
19
+
20
+ def siglip_detector(image, texts):
21
+ inputs = siglip_processor(text=texts, images=image, return_tensors="pt",
22
+ padding="max_length")
23
+
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+ logits_per_image = outputs.logits_per_image
27
+ probs = torch.sigmoid(logits_per_image)
28
+ return probs
29
+
30
 
31
  def infer(image, candidate_labels):
32
  candidate_labels = [label.lstrip(" ") for label in candidate_labels.split(",")]
33
+ siglip_out = siglip_detector(image, candidate_labels)
34
  clip_out = clip_detector(image, candidate_labels=candidate_labels)
35
+ return postprocess(clip_out), postprocess_siglip(siglip_out, labels=candidate_labels)
36
 
37
 
38
  with gr.Blocks() as demo: