bsoupy committed on
Commit 1fc786c · verified · 1 Parent(s): 5b7ac37

Upload 5 files

Files changed (6)
  1. .gitattributes +1 -0
  2. README.md +13 -3
  3. captions.json +0 -0
  4. faiss_index.idx +3 -0
  5. inference.py +110 -0
  6. requirements.txt +11 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ faiss_index.idx filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,13 @@
- ---
- license: apache-2.0
- ---
+
+ # 🏛️ RAG Image Captioning with Landmark Location
+
+ This model generates captions for monument/landmark images using a retrieval-augmented generation (RAG) approach.
+
+ ## How it works
+ - Uses CLIP to extract image embeddings.
+ - Retrieves the top-k most similar captions via FAISS.
+ - Generates a detailed caption with the landmark name and location using T5.
+
+ ## Example
+ Input: 🏰 Image of the Taj Mahal
+ Output: _"The place might be: Agra. The Taj Mahal is a white marble mausoleum located in Agra, India."_
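
The pipeline the README describes (CLIP embedding → FAISS retrieval → T5 generation) is implemented by the `inference.py` file added in this commit. A minimal usage sketch, assuming the repo files (`faiss_index.idx`, `captions.json`, `inference.py`) are available locally and importable, and noting that the spaCy model `en_core_web_sm` must be installed separately (`python -m spacy download en_core_web_sm`) since it is not listed in `requirements.txt`; the image path is a placeholder:

```python
# Hypothetical usage sketch; "taj_mahal.jpg" is a placeholder image path.
from PIL import Image
from inference import predict  # predict() is defined in this commit's inference.py

result = predict(Image.open("taj_mahal.jpg"))
print(result["caption"])
# e.g. "The place might be: Agra. The Taj Mahal is a white marble mausoleum located in Agra, India."
```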
captions.json ADDED
The diff for this file is too large to render. See raw diff
 
faiss_index.idx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a43f9410efe919810cd35354d77d7396cdee594a5d3998aadb6bc03606274332
+ size 3446829
inference.py ADDED
@@ -0,0 +1,110 @@
+
+ from PIL import Image
+ import torch
+ from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ import json
+ import spacy
+
+ # Load models and resources
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
+ tokenizer = T5Tokenizer.from_pretrained("t5-small")
+ generator = T5ForConditionalGeneration.from_pretrained("t5-small")
+ nlp = spacy.load("en_core_web_sm")
+
+ # Load FAISS index and captions
+ faiss_index = faiss.read_index("./faiss_index.idx")
+ with open("./captions.json", "r", encoding="utf-8") as f:
+     captions = json.load(f)
+
+ def extract_image_features(image):
+     """
+     Extract image features using CLIP model.
+     Input: PIL Image or image path (str).
+     Output: Normalized image embedding (numpy array).
+     """
+     try:
+         # Handle both PIL Image and file path
+         if isinstance(image, str):
+             image = Image.open(image).convert("RGB")
+         else:
+             image = image.convert("RGB")
+         inputs = clip_processor(images=image, return_tensors="pt")
+         with torch.no_grad():
+             features = clip_model.get_image_features(**inputs)
+         features = torch.nn.functional.normalize(features, p=2, dim=-1)
+         return features.squeeze(0).cpu().numpy().astype("float32")
+     except Exception as e:
+         print(f"Error extracting features: {e}")
+         return None
+
+ def retrieve_similar_captions(image_embedding, k=5):
+     """
+     Retrieve k most similar captions using FAISS index.
+     Input: Image embedding (numpy array).
+     Output: List of captions.
+     """
+     if image_embedding.ndim == 1:
+         image_embedding = image_embedding.reshape(1, -1)
+     D, I = faiss_index.search(image_embedding, k)
+     return [captions[i] for i in I[0]]
+
+ def extract_location_names(texts):
+     """
+     Extract location names from captions using spaCy.
+     Input: List of captions.
+     Output: List of unique location names.
+     """
+     names = []
+     for text in texts:
+         doc = nlp(text)
+         for ent in doc.ents:
+             if ent.label_ in ["GPE", "LOC", "FAC"]:
+                 names.append(ent.text)
+     return list(set(names))
+
+ def generate_caption_from_retrieved(retrieved_captions):
+     """
+     Generate a caption from retrieved captions using T5.
+     Input: List of retrieved captions.
+     Output: Generated caption (str).
+     """
+     locations = extract_location_names(retrieved_captions)
+     location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
+     prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
+     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
+     outputs = generator.generate(
+         input_ids=inputs.input_ids,
+         attention_mask=inputs.attention_mask,
+         max_length=300,
+         num_beams=5,
+         early_stopping=True
+     )
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ def generate_rag_caption(image):
+     """
+     Generate a RAG-based caption for an image.
+     Input: PIL Image or image path (str).
+     Output: Caption (str).
+     """
+     embedding = extract_image_features(image)
+     if embedding is None:
+         return "Failed to process image."
+     retrieved = retrieve_similar_captions(embedding, k=5)
+     if not retrieved:
+         return "No similar captions found."
+     return generate_caption_from_retrieved(retrieved)
+
+ def predict(image):
+     """
+     API-compatible function for inference.
+     Input: PIL Image or image file path.
+     Output: Dictionary with caption.
+     """
+     caption = generate_rag_caption(image)
+     return {"caption": caption}
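
`inference.py` expects `faiss_index.idx` and `captions.json` to be aligned: row i of the FAISS index must correspond to `captions[i]`, and the index is searched with L2-normalized 512-dimensional CLIP image embeddings. The commit does not include the script that built these files; the following is a minimal sketch of how a compatible pair could be produced, under the assumption of an inner-product `IndexFlatIP` over normalized CLIP embeddings (file names and the image/caption lists are placeholders):

```python
# Hypothetical index-building sketch (not part of this commit).
import json
import faiss
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Placeholder data: image paths and their aligned captions.
image_paths = ["images/taj_mahal.jpg", "images/eiffel_tower.jpg"]
captions = [
    "The Taj Mahal is a white marble mausoleum located in Agra, India.",
    "The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
]

embeddings = []
for path in image_paths:
    image = Image.open(path).convert("RGB")
    inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        feats = clip_model.get_image_features(**inputs)
    feats = torch.nn.functional.normalize(feats, p=2, dim=-1)
    embeddings.append(feats.squeeze(0).cpu().numpy().astype("float32"))

matrix = np.stack(embeddings)                # shape: (num_images, 512)
index = faiss.IndexFlatIP(matrix.shape[1])   # inner product == cosine on normalized vectors
index.add(matrix)
faiss.write_index(index, "faiss_index.idx")

# captions.json is stored in the same row order as the FAISS index.
with open("captions.json", "w", encoding="utf-8") as f:
    json.dump(captions, f)
```

Because the embeddings are L2-normalized on both the indexing and query side, inner-product search behaves like cosine similarity, matching the normalization done in `extract_image_features`.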
requirements.txt ADDED
@@ -0,0 +1,11 @@
+
+ transformers>=4.30.0
+ sentence-transformers>=2.2.0
+ faiss-cpu>=1.7.0
+ torch>=2.0.0
+ torchvision>=0.15.0
+ torchaudio>=2.0.0
+ pillow>=9.0.0
+ spacy>=3.5.0
+ langchain>=0.0.200
+ huggingface_hub>=0.15.0