micole66 yuvalkirstain committed on
Commit
1dbe11d
·
0 Parent(s):

Duplicate from yuvalkirstain/PickScore

Browse files

Co-authored-by: Yuval Kirstain <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PickScore
3
+ emoji: 😻
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.28.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: yuvalkirstain/PickScore
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from PIL import Image
3
+ import gradio as gr
4
+ from glob import glob
5
+ import torch
6
+ from transformers import AutoModel, AutoProcessor
7
+
8
# Directory holding the prompt and the image pair shown when the demo first loads.
DEFAULT_EXAMPLE_PATH = 'examples/example_0'

# Use the GPU with bfloat16 weights when CUDA is available; otherwise fall
# back to the CPU with float32 weights.
if torch.cuda.is_available():
    device = "cuda"
    weight_dtype = torch.bfloat16
else:
    device = "cpu"
    weight_dtype = torch.float32
print(f"Using device: {device} ({weight_dtype})")

# Load the processor and the model once at import time so every request reuses them.
print("Loading model...")
model_pretrained_name_or_path = "yuvalkirstain/PickScore_v1"
processor = AutoProcessor.from_pretrained(model_pretrained_name_or_path)
model = AutoModel.from_pretrained(model_pretrained_name_or_path, torch_dtype=weight_dtype)
model = model.eval().to(device)
print("Model loaded.")
18
+
19
+
20
def calc_probs(prompt, images):
    """Score images against a prompt and return softmax probabilities.

    Args:
        prompt: Text passed to the processor as ``text``.
        images: Images passed to the processor as ``images``.

    Returns:
        A list of floats: the softmax over the scaled similarity scores
        between the first text embedding and every image embedding.
    """
    print("Processing inputs...")
    # Both processor calls share the same padding/truncation settings.
    shared_kwargs = dict(
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )

    pixel_batch = processor(images=images, **shared_kwargs).to(device)
    # Cast the image tensors to the dtype the model weights were loaded with.
    pixel_batch = {name: tensor.to(weight_dtype) for name, tensor in pixel_batch.items()}

    token_batch = processor(text=prompt, **shared_kwargs).to(device)

    with torch.no_grad():
        print("Embedding images and text...")
        # Divide each embedding by its L2 norm along the last dimension.
        image_features = model.get_image_features(**pixel_batch)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        text_features = model.get_text_features(**token_batch)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        print("Calculating scores...")
        # Similarities are computed in float32; row 0 is the first text embedding.
        similarity = text_features.float() @ image_features.float().T
        scores = model.logit_scale.exp() * similarity[0]

        print("Calculating probabilities...")
        probabilities = torch.softmax(scores, dim=-1)

    return probabilities.cpu().tolist()
55
+
56
+
57
def predict(prompt, image_1, image_2):
    """Run PickScore on two images for one prompt and format the result for the UI.

    Args:
        prompt: Text prompt to score the images against.
        image_1: First image from the UI; ``None`` when the component is empty.
        image_2: Second image from the UI; ``None`` when the component is empty.

    Returns:
        A 2-tuple of strings: the probabilities of ``image_1`` and ``image_2``,
        each rounded to three decimals.

    Raises:
        gr.Error: If either image is missing, so the UI shows a readable
            message instead of an opaque processor failure.
    """
    # An emptied image component yields None, which the processor cannot handle.
    if image_1 is None or image_2 is None:
        raise gr.Error("Please provide both images before pressing run.")

    print(f"Starting prediction for prompt: {prompt}")
    start_time = time.time()
    probs = calc_probs(prompt, [image_1, image_2])
    print(f"Prediction: {probs} ({time.time() - start_time:.2f} seconds, ) ")

    if device == "cuda":
        # Report peak allocated memory against total device memory, in GiB.
        bytes_per_gib = 1024 ** 3
        used_gib = round(torch.cuda.max_memory_allocated(device) / bytes_per_gib, 2)
        total_gib = round(torch.cuda.get_device_properties(device).total_memory / bytes_per_gib, 2)
        print(f"GPU mem used: {used_gib}/{total_gib} GB")

    return str(round(probs[0], 3)), str(round(probs[1], 3))
65
+
66
+
67
def _read_prompt(example_dir):
    """Return the first line of ``prompt.txt`` inside ``example_dir``."""
    # The context manager closes the file instead of leaking the handle.
    with open(f'{example_dir}/prompt.txt', encoding="utf-8") as prompt_file:
        return prompt_file.readline()


# Build the demo UI: a prompt box, two image inputs, two probability outputs,
# a run button, and a list of bundled examples.
with gr.Blocks(title="PickScore v1") as demo:
    gr.Markdown("# PickScore v1")
    gr.Markdown(
        "This is a demo for the PickScore model - see [paper](https://arxiv.org/abs/2305.01569), [code](https://github.com/yuvalkirstain/PickScore), [dataset](https://huggingface.co/datasets/pickapic-anonymous/pickapic_v1), and [model](https://huggingface.co/yuvalkirstain/PickScore_v1).")
    gr.Markdown("## Instructions")
    gr.Markdown("Write a prompt, place two images, and press run to get their PickScore!")
    with gr.Row():
        # gr.Textbox replaces the deprecated gr.inputs.Textbox (value= instead of default=).
        prompt = gr.Textbox(lines=1, label="Prompt",
                            value=_read_prompt(DEFAULT_EXAMPLE_PATH))
    with gr.Row():
        image_1 = gr.components.Image(type="pil", label="image 1",
                                      value=Image.open(f'{DEFAULT_EXAMPLE_PATH}/image_1.png'))
        image_2 = gr.components.Image(type="pil", label="image 2",
                                      value=Image.open(f'{DEFAULT_EXAMPLE_PATH}/image_2.png'))
    with gr.Row():
        # gr.Textbox replaces the deprecated gr.outputs.Textbox.
        pred_1 = gr.Textbox(label="Probability 1")
        pred_2 = gr.Textbox(label="Probability 2")

    btn = gr.Button("Run")
    btn.click(fn=predict, inputs=[prompt, image_1, image_2], outputs=[pred_1, pred_2])
    # Clear stale probabilities as soon as the prompt text changes.
    prompt.change(lambda: ("", ""), inputs=[], outputs=[pred_1, pred_2])

    gr.Examples(
        # Sort the example directories so the list order does not depend on
        # the order the filesystem returns them in.
        [[_read_prompt(path), f'{path}/image_1.png', f'{path}/image_2.png']
         for path in sorted(glob('examples/*'))],
        [prompt, image_1, image_2],
        [pred_1, pred_2],
        predict,
    )

demo.queue(concurrency_count=5).launch()
examples/example_0/image_1.png ADDED

Git LFS Details

  • SHA256: b447ee5abf70a31b20433aff393add09f9770bd6963dacc0f497f54ff1003f13
  • Pointer size: 131 Bytes
  • Size of remote file: 427 kB
examples/example_0/image_2.png ADDED

Git LFS Details

  • SHA256: c497ec4b927d7a38df6f59ab7dab0ea907ca2eb90ac9fb08b811acc6bb78e04a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
examples/example_0/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A sign that says PICK A PIC
examples/example_1/image_1.png ADDED

Git LFS Details

  • SHA256: 6afec84a65f687c76704f6dabd73718d346f0d093f655b625a590f4229ee6be5
  • Pointer size: 131 Bytes
  • Size of remote file: 841 kB
examples/example_1/image_2.png ADDED

Git LFS Details

  • SHA256: aca819d9f57e35eddd3cc971002719869b709ea88e02e19db0323d8d59a46a6b
  • Pointer size: 131 Bytes
  • Size of remote file: 934 kB
examples/example_1/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A bee devouring the world
examples/example_2/image_1.png ADDED

Git LFS Details

  • SHA256: ab1ef7d55c7d04827e8b0e01f1c197c7a82644b6ba672bf71c4e7e26f6a36c5b
  • Pointer size: 131 Bytes
  • Size of remote file: 992 kB
examples/example_2/image_2.png ADDED

Git LFS Details

  • SHA256: 73b3764e591b45eb4f19cf24b0c0d92b1d60644c711652ba3ec9bf2737253d42
  • Pointer size: 131 Bytes
  • Size of remote file: 857 kB
examples/example_2/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Crazy frog on one wheel
examples/example_3/image_1.png ADDED

Git LFS Details

  • SHA256: 27a539671b2b928e6a57c11452991150bea83ce572d500db68dbf92826ebd733
  • Pointer size: 131 Bytes
  • Size of remote file: 740 kB
examples/example_3/image_2.png ADDED

Git LFS Details

  • SHA256: 37bfc68bebe837e432d64725439519cf85a06c80ba8ecd31f3738102fa80badc
  • Pointer size: 131 Bytes
  • Size of remote file: 760 kB
examples/example_3/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A cute fluffy easter bunny singing
examples/example_4/image_1.png ADDED

Git LFS Details

  • SHA256: d5d00769883a61b32c9589913c956acddb81181ebd2cc56ffc0de3a0dc72fe84
  • Pointer size: 131 Bytes
  • Size of remote file: 971 kB
examples/example_4/image_2.png ADDED

Git LFS Details

  • SHA256: b38d23aaf87595ac3aa6f74258fbdfc5b7342cb92f8edf38e96eda0376726d36
  • Pointer size: 131 Bytes
  • Size of remote file: 928 kB
examples/example_4/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Silver furred Lion man hybrid with few leather clothes
examples/example_5/image_1.png ADDED

Git LFS Details

  • SHA256: f38c1605bccdd08783c0b362a1948998fdb0dfb7f269ee3f4a9a782782264b75
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
examples/example_5/image_2.png ADDED

Git LFS Details

  • SHA256: b3fadbd6c75ee681588c6155a159770b9cfcc12ac6b7217c52548dd1ead8e12f
  • Pointer size: 131 Bytes
  • Size of remote file: 999 kB
examples/example_5/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Superman's corpse found
examples/example_6/image_1.png ADDED

Git LFS Details

  • SHA256: 73ddbe2000310f1bb6f8712a5f08c28c60788cbf107b29e4c2f549fc206dc520
  • Pointer size: 131 Bytes
  • Size of remote file: 877 kB
examples/example_6/image_2.png ADDED

Git LFS Details

  • SHA256: ec2bb1b16de70f8ff009deb043f1ea9835497c6d3421fa9d6a2179808711b8d2
  • Pointer size: 131 Bytes
  • Size of remote file: 717 kB
examples/example_6/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a girl
examples/example_7/image_1.png ADDED

Git LFS Details

  • SHA256: cb8290da1816eb9ea701154677a83afbfcb977978e2d7473dd6bef1944852fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
examples/example_7/image_2.png ADDED

Git LFS Details

  • SHA256: 07f8b3657fee644b8ef129f60d5c0d7d2caad3a97992f075e55c675b6df70b1a
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
examples/example_7/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A gloomy rabbit drinks wine
examples/example_8/image_1.png ADDED

Git LFS Details

  • SHA256: 83bf47ebdff85b39fe67cb34ecd413832ec1d492eef823c4395a73b04412fd55
  • Pointer size: 132 Bytes
  • Size of remote file: 1.76 MB
examples/example_8/image_2.png ADDED

Git LFS Details

  • SHA256: 9952429fcac76f9a116ac043615668528817b51174f5bb2eea4e8bc9529dad12
  • Pointer size: 132 Bytes
  • Size of remote file: 1.77 MB
examples/example_8/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a beautiful landscape
examples/example_9/image_1.png ADDED

Git LFS Details

  • SHA256: cb8290da1816eb9ea701154677a83afbfcb977978e2d7473dd6bef1944852fcf
  • Pointer size: 131 Bytes
  • Size of remote file: 846 kB
examples/example_9/image_2.png ADDED

Git LFS Details

  • SHA256: 07f8b3657fee644b8ef129f60d5c0d7d2caad3a97992f075e55c675b6df70b1a
  • Pointer size: 131 Bytes
  • Size of remote file: 385 kB
examples/example_9/prompt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ A gloomy rabbit drinks wine
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ transformers