from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy

# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')  # loaded here but not used in the retrieval path below
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
nlp = spacy.load("en_core_web_sm")

# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)
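
# NOTE: faiss_index.idx and captions.json are built offline. The helper below is
# an illustrative sketch (an assumption, not part of the original pipeline) of how
# such an index could be produced: encode reference images with the same CLIP
# model, L2-normalize the embeddings, and store them in an inner-product index so
# that the search in retrieve_similar_captions() behaves like cosine similarity.
# The function name and file paths are hypothetical.
def build_reference_index(image_paths, caption_texts,
                          index_path="./faiss_index.idx",
                          captions_path="./captions.json"):
    vectors = []
    for path in image_paths:
        img = Image.open(path).convert("RGB")
        inputs = clip_processor(images=img, return_tensors="pt")
        with torch.no_grad():
            feats = clip_model.get_image_features(**inputs)
        feats = torch.nn.functional.normalize(feats, p=2, dim=-1)
        vectors.append(feats.squeeze(0).cpu().numpy().astype("float32"))
    matrix = np.stack(vectors)                        # shape: (N, 512) for ViT-B/32
    index = faiss.IndexFlatIP(int(matrix.shape[1]))   # inner product == cosine on unit vectors
    index.add(matrix)
    faiss.write_index(index, index_path)
    with open(captions_path, "w", encoding="utf-8") as f_out:
        json.dump(caption_texts, f_out, ensure_ascii=False)
    return index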

def extract_image_features(image):
    """
    Extract image features using CLIP model.
    Input: PIL Image or image path (str).
    Output: Normalized image embedding (numpy array).
    """
    try:
        # Handle both PIL Image and file path
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None

def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve k most similar captions using FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors; skip those.
    return [captions[i] for i in I[0] if i != -1]

def extract_location_names(texts):
    """
    Extract location names from captions using spaCy.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))

def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image or image path (str).
    Output: Caption (str).
    """
    embedding = extract_image_features(image)
    if embedding is None:
        return "Failed to process image."
    retrieved = retrieve_similar_captions(embedding, k=5)
    if not retrieved:
        return "No similar captions found."
    return generate_caption_from_retrieved(retrieved)

def predict(image):
    """
    API-compatible function for inference.
    Input: PIL Image or image file path.
    Output: Dictionary with caption.
    """
    caption = generate_rag_caption(image)
    return {"caption": caption}