ktllc commited on
Commit
ad91d57
·
1 Parent(s): 2b19e7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -11
app.py CHANGED
@@ -23,24 +23,22 @@ def find_similarity(base64_image, text_input):
23
  # Preprocess the image
24
  image = preprocess(image).unsqueeze(0).to(device)
25
 
26
- # Tokenize the text input
27
  text_tokens = clip.tokenize([text_input]).to(device)
28
 
29
  # Encode image and text features
30
- with torch no grad():
31
- image_features = model.encode_image(image)
32
- text_features = model.encode_text(text_tokens)
33
 
34
- # Calculate cosine similarity
35
- similarity = (image_features @ text_features.T).squeeze(0).cpu().numpy()
36
 
37
- # Convert each element in the similarity array to Decimal
38
- similarity_decimal = [Decimal(float(score)) for score in similarity]
 
39
 
40
- # Format Decimal values as floats with specific precision (e.g., 4 decimal places)
41
- formatted_similarity = [f'{float(score):.5f}' for score in similarity_decimal]
 
 
42
 
43
- return formatted_similarity
44
 
45
  # Create a Gradio interface
46
  iface = gr.Interface(
 
23
  # Preprocess the image
24
  image = preprocess(image).unsqueeze(0).to(device)
25
 
26
+ # Prepare input text
27
  text_tokens = clip.tokenize([text_input]).to(device)
28
 
29
  # Encode image and text features
 
 
 
30
 
 
 
31
 
32
+ with torch.no_grad():
33
+ image_features = model.encode_image(image)
34
+ text_features = model.encode_text(text_tokens)
35
 
36
+ # Normalize features and calculate similarity
37
+ image_features /= image_features.norm(dim=-1, keepdim=True)
38
+ text_features /= text_features.norm(dim=-1, keepdim=True)
39
+ similarity = (text_features @ image_features.T).squeeze(0).cpu().numpy()
40
 
41
+ return similarity[0, 0]
42
 
43
  # Create a Gradio interface
44
  iface = gr.Interface(