Prathamesh1420 commited on
Commit
6975a6c
·
verified ·
1 Parent(s): 7b6396d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -48
app.py CHANGED
@@ -1,63 +1,105 @@
1
  import streamlit as st
2
  import cv2
3
- import tempfile
4
- import os
5
  import numpy as np
 
 
 
 
6
  from ultralytics import YOLO
7
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- def process_video(video_path, model):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  cap = cv2.VideoCapture(video_path)
11
- temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
12
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
13
- out = cv2.VideoWriter(temp_output.name, fourcc, cap.get(cv2.CAP_PROP_FPS),
14
- (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))))
15
 
16
  while cap.isOpened():
17
  ret, frame = cap.read()
18
  if not ret:
19
  break
20
-
21
- results = model(frame)
22
- for result in results:
23
- for box in result.boxes:
24
- x1, y1, x2, y2 = map(int, box.xyxy[0])
25
- label = result.names[int(box.cls[0])]
26
- conf = float(box.conf[0])
27
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
28
- cv2.putText(frame, f'{label}: {conf:.2f}', (x1, y1 - 10),
29
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
30
-
 
31
  out.write(frame)
32
-
33
  cap.release()
34
  out.release()
35
- return temp_output.name
36
 
37
- def main():
38
- st.set_page_config(page_title="Bottle Label Checker", page_icon="🍾")
39
- st.title("Bottle Label Checking System using YOLO & Gemini")
40
-
41
- uploaded_video = st.file_uploader("Upload a video", type=["mp4", "avi", "mov", "mkv"])
42
-
43
- if uploaded_video is not None:
44
- temp_video_path = os.path.join(tempfile.gettempdir(), uploaded_video.name)
45
- with open(temp_video_path, "wb") as f:
46
- f.write(uploaded_video.read())
47
-
48
- st.video(temp_video_path)
49
-
50
- model = YOLO("yolov8n.pt") # Load YOLO model
51
-
52
- if st.button("Process Video"):
53
- st.write("Processing video... This may take some time.")
54
- output_path = process_video(temp_video_path, model)
55
-
56
- st.video(output_path)
57
- st.success("Processing complete!")
58
-
59
- with open(output_path, "rb") as file:
60
- st.download_button("Download Processed Video", file, file_name="processed_video.mp4", mime="video/mp4")
61
-
62
- if __name__ == "__main__":
63
- main()
 
1
  import streamlit as st
2
  import cv2
 
 
3
  import numpy as np
4
+ import os
5
+ import time
6
+ import threading
7
+ import base64
8
  from ultralytics import YOLO
9
+ from langchain_core.messages import HumanMessage
10
+ from langchain_google_genai import ChatGoogleGenerativeAI
11
+
12
+ # Set up Google API Key
13
+ os.environ["GOOGLE_API_KEY"] = "" # Replace with your API Key
14
+ gemini_model = ChatGoogleGenerativeAI(model="gemini-1.5-flash")
15
+
16
+ # Load YOLO model
17
+ yolo_model = YOLO("best.pt")
18
+ names = yolo_model.names
19
+
20
+ # Constants for ROI detection
21
+ cx1 = 491
22
+ offset = 8
23
+ current_date = time.strftime("%Y-%m-%d")
24
+ crop_folder = f"crop_{current_date}"
25
+ if not os.path.exists(crop_folder):
26
+ os.makedirs(crop_folder)
27
+ processed_track_ids = set()
28
+
29
+ def encode_image_to_base64(image):
30
+ _, img_buffer = cv2.imencode('.jpg', image)
31
+ return base64.b64encode(img_buffer).decode('utf-8')
32
+
33
+ def analyze_image_with_gemini(current_image):
34
+ if current_image is None:
35
+ return "No image available for analysis."
36
+ current_image_data = encode_image_to_base64(current_image)
37
+ message = HumanMessage(
38
+ content=[
39
+ {"type": "text", "text": "Analyze this image and check if the label is present on the bottle. Return results in a structured format."},
40
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{current_image_data}"}, "description": "Detected product"}
41
+ ]
42
+ )
43
+ try:
44
+ response = gemini_model.invoke([message])
45
+ return response.content
46
+ except Exception as e:
47
+ return f"Error processing image: {e}"
48
 
49
+ def save_crop_image(crop, track_id):
50
+ filename = f"{crop_folder}/{track_id}.jpg"
51
+ cv2.imwrite(filename, crop)
52
+ return filename
53
+
54
+ def process_crop_image(crop, track_id):
55
+ response = analyze_image_with_gemini(crop)
56
+ st.session_state["responses"].append((track_id, response))
57
+
58
+ def process_video(uploaded_file):
59
+ if not uploaded_file:
60
+ return None
61
+
62
+ video_bytes = uploaded_file.read()
63
+ video_path = "uploaded_video.mp4"
64
+ with open(video_path, "wb") as f:
65
+ f.write(video_bytes)
66
+
67
  cap = cv2.VideoCapture(video_path)
68
+ output_path = "output_video.mp4"
69
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
70
+ out = cv2.VideoWriter(output_path, fourcc, 20.0, (1020, 500))
 
71
 
72
  while cap.isOpened():
73
  ret, frame = cap.read()
74
  if not ret:
75
  break
76
+ frame = cv2.resize(frame, (1020, 500))
77
+ results = yolo_model.track(frame, persist=True)
78
+ if results[0].boxes is not None:
79
+ boxes = results[0].boxes.xyxy.int().cpu().tolist()
80
+ track_ids = results[0].boxes.id.int().cpu().tolist() if results[0].boxes.id is not None else [-1] * len(boxes)
81
+ for box, track_id in zip(boxes, track_ids):
82
+ if track_id not in processed_track_ids:
83
+ x1, y1, x2, y2 = box
84
+ crop = frame[y1:y2, x1:x2]
85
+ save_crop_image(crop, track_id)
86
+ threading.Thread(target=process_crop_image, args=(crop, track_id)).start()
87
+ processed_track_ids.add(track_id)
88
  out.write(frame)
 
89
  cap.release()
90
  out.release()
91
+ return output_path
92
 
93
+ st.title("Bottle Label Checking using YOLO & Gemini AI")
94
+ st.sidebar.header("Upload a video")
95
+ uploaded_file = st.sidebar.file_uploader("Choose a video file", type=["mp4", "avi", "mov"])
96
+ if "responses" not in st.session_state:
97
+ st.session_state["responses"] = []
98
+ if uploaded_file:
99
+ st.sidebar.write("Processing...")
100
+ output_video_path = process_video(uploaded_file)
101
+ st.sidebar.success("Processing completed!")
102
+ st.video(output_video_path)
103
+ st.subheader("AI Analysis Results")
104
+ for track_id, response in st.session_state["responses"]:
105
+ st.write(f"**Track ID {track_id}:** {response}")