SmilingWolf commited on
Commit
0a8bf1b
·
1 Parent(s): 5b445fa

Feature parity with the other space

Browse files
Files changed (1) hide show
  1. app.py +39 -11
app.py CHANGED
@@ -55,7 +55,15 @@ class Predictor:
55
  config = json.loads(open("index/cosine_infos.json").read())["index_param"]
56
  faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
57
 
58
- def predict(self, positive_tags, negative_tags, n_neighbours=5):
 
 
 
 
 
 
 
 
59
  tags_df = self.tags_df
60
  model = self.model
61
 
@@ -99,12 +107,9 @@ class Predictor:
99
  for image_id, dist in zip(neighbours_ids, dists[0]):
100
  current_url = danbooru_id_to_url(
101
  image_id,
102
- [
103
- "General",
104
- "Sensitive",
105
- "Questionable",
106
- "Explicit",
107
- ],
108
  )
109
  if current_url is not None:
110
  image_urls.append(current_url)
@@ -117,16 +122,39 @@ def main():
117
 
118
  with gr.Blocks() as demo:
119
  with gr.Row():
120
- positive_tags = gr.Textbox(label="Positive tags")
121
- negative_tags = gr.Textbox(label="Negative tags")
122
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  find_btn = gr.Button("Find similar images")
124
 
125
  similar_images = gr.Gallery(label="Similar images", columns=[5])
126
 
127
  find_btn.click(
128
  fn=predictor.predict,
129
- inputs=[positive_tags, negative_tags],
 
 
 
 
 
 
 
130
  outputs=[similar_images],
131
  )
132
 
 
55
  config = json.loads(open("index/cosine_infos.json").read())["index_param"]
56
  faiss.ParameterSpace().set_index_parameters(self.knn_index, config)
57
 
58
+ def predict(
59
+ self,
60
+ positive_tags,
61
+ negative_tags,
62
+ selected_ratings,
63
+ n_neighbours,
64
+ api_username,
65
+ api_key,
66
+ ):
67
  tags_df = self.tags_df
68
  model = self.model
69
 
 
107
  for image_id, dist in zip(neighbours_ids, dists[0]):
108
  current_url = danbooru_id_to_url(
109
  image_id,
110
+ selected_ratings,
111
+ api_username,
112
+ api_key,
 
 
 
113
  )
114
  if current_url is not None:
115
  image_urls.append(current_url)
 
122
 
123
  with gr.Blocks() as demo:
124
  with gr.Row():
125
+ with gr.Column():
126
+ positive_tags = gr.Textbox(label="Positive tags")
127
+ negative_tags = gr.Textbox(label="Negative tags")
128
+ n_neighbours = gr.Slider(
129
+ minimum=1,
130
+ maximum=20,
131
+ value=5,
132
+ step=1,
133
+ label="# of images",
134
+ )
135
+
136
+ with gr.Column():
137
+ api_username = gr.Textbox(label="Danbooru API Username")
138
+ api_key = gr.Textbox(label="Danbooru API Key")
139
+ selected_ratings = gr.CheckboxGroup(
140
+ choices=["General", "Sensitive", "Questionable", "Explicit"],
141
+ value=["General", "Sensitive"],
142
+ label="Ratings",
143
+ )
144
  find_btn = gr.Button("Find similar images")
145
 
146
  similar_images = gr.Gallery(label="Similar images", columns=[5])
147
 
148
  find_btn.click(
149
  fn=predictor.predict,
150
+ inputs=[
151
+ positive_tags,
152
+ negative_tags,
153
+ selected_ratings,
154
+ n_neighbours,
155
+ api_username,
156
+ api_key,
157
+ ],
158
  outputs=[similar_images],
159
  )
160