File size: 7,268 Bytes
ade9ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b7868
ade9ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
667e965
ade9ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66ece16
de5e541
 
 
 
 
ade9ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import streamlit as st
import chromadb
from chromadb.config import Settings
from transformers import CLIPProcessor, CLIPModel
import cv2
from PIL import Image
import torch
import logging
import uuid
import tempfile
import os
import requests
import json
from dotenv import load_dotenv
import shutil

load_dotenv()
# Hugging Face access token read from the .env file (key name: 'hf_token');
# None if absent — from_pretrained then falls back to anonymous access.
HF_TOKEN = os.getenv('hf_token')
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The entire app body runs inside one broad try so any startup or runtime
# failure is logged by the `except` at the bottom of the file.
try:



    @st.cache_resource
    def load_model():
        """Load the CLIP ViT-L/14 processor and model (cached by Streamlit).

        Returns:
            tuple: ``(CLIPProcessor, CLIPModel)`` with the model in eval
            mode and moved to CPU.
        """
        checkpoint = "openai/clip-vit-large-patch14"
        target_device = 'cpu'
        clip_processor = CLIPProcessor.from_pretrained(checkpoint, token=HF_TOKEN)
        clip_model = CLIPModel.from_pretrained(checkpoint, token=HF_TOKEN)
        # eval() disables dropout etc.; .to() returns the same model object.
        clip_model = clip_model.eval().to(target_device)
        return clip_processor, clip_model

    @st.cache_resource
    def load_chromadb():
        """Open (or create) the persistent ChromaDB store under ./Data.

        Returns:
            tuple: ``(PersistentClient, Collection)`` where the collection is
            the 'images' collection, created on first use.
        """
        db_settings = Settings(anonymized_telemetry=False)
        client = chromadb.PersistentClient(path='Data', settings=db_settings)
        images_collection = client.get_or_create_collection(name='images')
        return client, images_collection

    def resize_image(image_path, size=(224, 224)):
        """Open an image and return an RGB copy resized with Lanczos resampling.

        Args:
            image_path: Filesystem path or file-like object accepted by
                ``PIL.Image.open``.
            size: Target ``(width, height)``; defaults to CLIP's 224x224.

        Returns:
            PIL.Image.Image: The resized RGB image.
        """
        # Image.open accepts both paths and file objects, so the original
        # isinstance(str) branch was dead code — both arms were identical.
        img = Image.open(image_path).convert("RGB")
        return img.resize(size, Image.LANCZOS)

    def get_image_embedding(image, model, preprocess, device='cpu'):
        """Compute an L2-normalized CLIP embedding for one image file.

        Args:
            image: Path to the image file.
            model: CLIPModel providing ``get_image_features``.
            preprocess: CLIPProcessor used to build the pixel tensor.
            device: Torch device string; defaults to 'cpu'.

        Returns:
            torch.Tensor: Shape ``(1, embed_dim)``, unit L2 norm per row.
        """
        # Open the path directly; the original wrapped it in a no-op f-string.
        img = Image.open(image).convert('RGB')
        pixel_values = preprocess(
            images=[img], return_tensors='pt')['pixel_values'].to(device)
        with torch.no_grad():  # inference only — no gradients needed
            embedding = model.get_image_features(pixel_values=pixel_values)
        # Normalize so cosine similarity reduces to a plain dot product.
        return torch.nn.functional.normalize(embedding, p=2, dim=1)

    def extract_frames(v_path, frame_interval=30):
        """Sample every *frame_interval*-th frame of a video to JPEG files.

        Frames are written into the module-level ``temp_dir`` with a UUID in
        the name so runs never collide.

        Args:
            v_path: Path to a video file readable by OpenCV.
            frame_interval: Keep one frame out of every this many.
        """
        cap = cv2.VideoCapture(v_path)
        if not cap.isOpened():
            # Fail loudly in the log instead of looping on a dead capture.
            logger.error("Could not open video %s", v_path)
            return
        # NOTE: the original also computed frame_count // fps, which was
        # unused and raised ZeroDivisionError when FPS reported as 0.
        frame_idx = 0
        saved_frames = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_interval == 0:
                unique_image_id = str(uuid.uuid4())
                frame_name = f"{temp_dir}/frame_{unique_image_id}_{saved_frames}.jpg"
                cv2.imwrite(frame_name, frame)
                saved_frames += 1
            frame_idx += 1
        cap.release()
        logger.info("Frames extracted")

    def insert_into_db(collection, dir):
        """Embed every image file in *dir* and add the results to *collection*.

        Relies on the module-level ``model`` and ``processor`` globals.
        The parameter name ``dir`` shadows the builtin but is kept for
        interface compatibility.

        Args:
            collection: ChromaDB collection to insert into.
            dir: Directory whose files are all treated as images.
        """
        embedding_list = []
        file_names = []
        ids = []
        with st.status("Generating embedding... ⏳", expanded=True) as status:
            for i in os.listdir(dir):
                embedding = get_image_embedding(
                    f"{dir}/{i}", model, processor)
                embedding_list.append(
                    embedding.squeeze(0).numpy().tolist())
                file_names.append(
                    {'path': f"{dir}/{i}", 'type': 'photo'})
                ids.append(str(uuid.uuid4()))
            status.update(label="Embedding generation complete",
                          state="complete")

        # Guard: an empty directory would otherwise hand chromadb empty
        # id/embedding lists, which add() rejects.
        if not ids:
            logger.warning("No images found in %s; nothing inserted", dir)
            return
        collection.add(
            embeddings=embedding_list,
            ids=ids,
            metadatas=file_names
        )
        logger.info("Data inserted into DB")

    # Eagerly load the cached resources so failures surface at startup.
    processor, model = load_model()
    logger.info("Model and processor loaded")
    client, collection = load_chromadb()
    logger.info("ChromaDB loaded")
    logger.info(
        f"Connected to ChromaDB collection images with {collection.count()} items")

    # One-time per-session cleanup: wipe the scratch frame folder and delete
    # any stale DB entries whose path points into it from a previous run.
    temp_dir = 'temp_folder'
    if 'cleaned_temp' not in st.session_state:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.makedirs(temp_dir, exist_ok=True)
        st.session_state.cleaned_temp = True
        results=collection.get(include=["metadatas"])
        ids_to_delete=[
            _id for _id,metadata in zip(results["ids"],results['metadatas']) if metadata.get("path","").startswith("temp")
        ]
        if ids_to_delete:
            collection.delete(ids=ids_to_delete)

    st.title("Extract frames from video using text")
    # Upload section
    st.sidebar.subheader("Upload video")
    video_file = st.sidebar.file_uploader(
        "Upload videos", type=["mp4", "webm", "avi", "mov"], accept_multiple_files=False
    )
    num_images = st.sidebar.slider(
        "Number of images to  be shown", min_value=1, max_value=10, value=3)
    if video_file:
        # Persist the upload to a real file so OpenCV can read it by path.
        # NOTE(review): delete=False means the temp file is never removed —
        # confirm whether cleanup is needed across sessions.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmpfile:
            tmpfile.write(video_file.read())
            video_path = tmpfile.name
        st.video(video_path)
        st.sidebar.subheader("Add uploaded videos to collection")
        if st.sidebar.button("Add uploaded video"):
            # Sample frames into temp_dir, then embed and index them.
            extract_frames(video_path)
            insert_into_db(collection, temp_dir)
    else:
        # Fall back to the bundled demo video when nothing is uploaded.
        video_path = 'Videos/Video.mp4'
        st.video(video_path)
        st.write(
            f"Video credits: https://www.kaggle.com/datasets/icebearisin/raw-skates")

    st.write("Enter the description of image to be  extracted from the video")
    text_input = st.text_input("Description", "Flying Skater")
    if st.button("Search"):
        if text_input.strip():
            # The text embedding comes from an external HF Space API, so the
            # app only runs the image tower locally.
            params = {'text': text_input.strip()}
            response = requests.get(
                'https://ashish-001-text-embedding-api.hf.space/embedding', params=params)
            if response.status_code == 200:
                logger.info("Embedding returned by API successfully")
                data = json.loads(response.content)
                embedding = data['embedding']
                # Nearest-neighbor lookup over the stored frame embeddings.
                results = collection.query(
                    query_embeddings=[embedding],
                    n_results=num_images
                )
                images = [results['metadatas'][0][i]['path']
                          for i in range(len(results['metadatas'][0]))]
                # distances is computed but currently unused in the UI.
                distances = [results['distances'][0][i]
                             for i in range(len(results['metadatas'][0]))]
                if images:
                    # Lay the hits out in a 3-wide grid, row by row.
                    cols_per_row = 3
                    rows = (len(images)+cols_per_row-1)//cols_per_row
                    for row in range(rows):
                        cols = st.columns(cols_per_row)
                        for col_idx, col in enumerate(cols):
                            img_idx = row*cols_per_row+col_idx
                            if img_idx < len(images):
                                resized_img = resize_image(
                                    images[img_idx])
                                col.image(resized_img,
                                          caption=f"Image {img_idx+1}", use_container_width=True)
                else:
                    st.write("No image found")
            else:
                st.write("Please try again later")
                logger.info(f"status code {response.status_code} returned")
        else:
            st.write("Please enter text in the text area")

except Exception as e:
    # Top-level boundary: log the full traceback for any failure above.
    logger.exception(f"Exception occured, {e}")