ktllc commited on
Commit
d4c665a
·
1 Parent(s): 7413961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -14,11 +14,11 @@ print(device)
14
  # Define the Business Listing variable
15
  Business_Listing = "Air Guide"
16
 
17
- def find_similarity(images, text_input):
18
  image_features = []
19
 
20
- # Preprocess and encode multiple images
21
- for image in images:
22
  image = preprocess(image).unsqueeze(0).to(device)
23
  with torch.no_grad():
24
  image_feature = model.encode_image(image).float()
@@ -39,15 +39,19 @@ def find_similarity(images, text_input):
39
  similarity = (text_features @ image_feature.T).cpu().numpy()
40
  similarities.append(similarity[0, 0])
41
 
42
- # Find the index of the image with the highest similarity
43
- best_match_index = np.argmax(similarities)
44
 
45
  return similarities, best_match_index
46
 
47
  # Define a Gradio interface
48
  iface = gr.Interface(
49
  fn=find_similarity,
50
- inputs=[gr.Image(type="pil", label="Image 1"), gr.Image(type="pil", label="Image 2"), "text"],
 
 
 
 
51
  outputs=["text", "number"],
52
  live=True,
53
  interpretation="default",
 
14
  # Define the Business Listing variable
15
  Business_Listing = "Air Guide"
16
 
17
+ def find_similarity(image1, image2, text_input):
18
  image_features = []
19
 
20
+ # Preprocess and encode the two images
21
+ for image in [image1, image2]:
22
  image = preprocess(image).unsqueeze(0).to(device)
23
  with torch.no_grad():
24
  image_feature = model.encode_image(image).float()
 
39
  similarity = (text_features @ image_feature.T).cpu().numpy()
40
  similarities.append(similarity[0, 0])
41
 
42
+ # Determine which image has a higher similarity to the text
43
+ best_match_index = 0 if similarities[0] > similarities[1] else 1
44
 
45
  return similarities, best_match_index
46
 
47
  # Define a Gradio interface
48
  iface = gr.Interface(
49
  fn=find_similarity,
50
+ inputs=[
51
+ gr.Image(type="pil", label="Image 1"),
52
+ gr.Image(type="pil", label="Image 2"),
53
+ "text"
54
+ ],
55
  outputs=["text", "number"],
56
  live=True,
57
  interpretation="default",