ktllc commited on
Commit
d3bd556
·
1 Parent(s): 9f13edb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -5,11 +5,11 @@ import torch
5
  from PIL import Image
6
  import base64
7
  from io import BytesIO
8
- from decimal import Decimal # Import the Decimal module
9
 
10
  # Load the CLIP model
11
  model, preprocess = clip.load("ViT-B/32")
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
  model.to(device).eval()
14
 
15
  # Define a function to find similarity
@@ -34,10 +34,10 @@ def find_similarity(base64_image, text_input):
34
  # Calculate cosine similarity
35
  similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
36
 
37
- # Convert the similarity score to a Decimal
38
- similarity_decimal = Decimal(similarity)
39
 
40
- return similarity_decimal # Return the similarity score as a Decimal
41
 
42
  # Create a Gradio interface
43
  iface = gr.Interface(
 
5
  from PIL import Image
6
  import base64
7
  from io import BytesIO
8
+ from decimal import Decimal
9
 
10
  # Load the CLIP model
11
  model, preprocess = clip.load("ViT-B/32")
12
+ device = "cuda" if torch.cuda.isavailable() else "cpu"
13
  model.to(device).eval()
14
 
15
  # Define a function to find similarity
 
34
  # Calculate cosine similarity
35
  similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
36
 
37
+ # Convert each element in the similarity array to Decimal
38
+ similarity_decimal = [Decimal(score) for score in similarity]
39
 
40
+ return similarity_decimal
41
 
42
  # Create a Gradio interface
43
  iface = gr.Interface(