VishalD1234 commited on
Commit
3f60308
·
verified ·
1 Parent(s): c33b13b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -10,6 +10,13 @@ MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
10
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
12
 
 
 
 
 
 
 
 
13
  # Delay Reasons for Each Manufacturing Step
14
  DELAY_REASONS = {
15
  "Step 1": ["Delay in Bead Insertion", "Lack of raw material"],
@@ -100,15 +107,14 @@ def get_step_info(step_number):
100
  def load_video(video_data, strategy='chat'):
101
  """Loads and processes video data into a format suitable for model input."""
102
  bridge.set_bridge('torch')
103
- num_frames = 24
104
-
105
  if isinstance(video_data, str):
106
  decord_vr = VideoReader(video_data, ctx=cpu(0))
107
  else:
108
  decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
109
 
110
  total_frames = len(decord_vr)
111
- if total_frames < num_frames:
112
  raise ValueError("Uploaded video is too short for meaningful analysis.")
113
 
114
  timestamps = [i[0] for i in decord_vr.get_frame_timestamp(np.arange(total_frames))]
@@ -119,7 +125,7 @@ def load_video(video_data, strategy='chat'):
119
  closest_num = min(timestamps, key=lambda x: abs(x - second))
120
  index = timestamps.index(closest_num)
121
  frame_id_list.append(index)
122
- if len(frame_id_list) >= num_frames:
123
  break
124
 
125
  video_data = decord_vr.get_batch(frame_id_list)
@@ -148,7 +154,10 @@ def load_model():
148
 
149
  def predict(prompt, video_data, temperature, model, tokenizer):
150
  """Generates predictions based on the video and textual prompt."""
151
- video = load_video(video_data, strategy='chat')
 
 
 
152
 
153
  inputs = model.build_conversation_input_ids(
154
  tokenizer=tokenizer,
@@ -166,12 +175,12 @@ def predict(prompt, video_data, temperature, model, tokenizer):
166
  }
167
 
168
  gen_kwargs = {
169
- "max_new_tokens": 2048,
170
  "pad_token_id": tokenizer.pad_token_id,
171
- "top_k": 1,
172
  "do_sample": False,
173
- "top_p": 0.1,
174
- "temperature": 0.3,
175
  }
176
 
177
  with torch.no_grad():
@@ -208,5 +217,5 @@ Potential Delay Reasons:
208
 
209
  Task: Analyze the provided video to identify the delay reason. Use the following format:
210
  1. **Selected Reason:** [Choose the most likely reason from the list above]
211
- 2. **Visual Evidence:** [Describe specific visual cues from the
212
  """
 
10
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
12
 
13
+ # Configurable constants
14
+ NUM_FRAMES = 24 # Default number of frames to extract
15
+ MAX_NEW_TOKENS = 2048
16
+ TOP_K = 1
17
+ TOP_P = 0.1
18
+ DEFAULT_TEMPERATURE = 1.0
19
+
20
  # Delay Reasons for Each Manufacturing Step
21
  DELAY_REASONS = {
22
  "Step 1": ["Delay in Bead Insertion", "Lack of raw material"],
 
107
  def load_video(video_data, strategy='chat'):
108
  """Loads and processes video data into a format suitable for model input."""
109
  bridge.set_bridge('torch')
110
+
 
111
  if isinstance(video_data, str):
112
  decord_vr = VideoReader(video_data, ctx=cpu(0))
113
  else:
114
  decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
115
 
116
  total_frames = len(decord_vr)
117
+ if total_frames < NUM_FRAMES:
118
  raise ValueError("Uploaded video is too short for meaningful analysis.")
119
 
120
  timestamps = [i[0] for i in decord_vr.get_frame_timestamp(np.arange(total_frames))]
 
125
  closest_num = min(timestamps, key=lambda x: abs(x - second))
126
  index = timestamps.index(closest_num)
127
  frame_id_list.append(index)
128
+ if len(frame_id_list) >= NUM_FRAMES:
129
  break
130
 
131
  video_data = decord_vr.get_batch(frame_id_list)
 
154
 
155
  def predict(prompt, video_data, temperature, model, tokenizer):
156
  """Generates predictions based on the video and textual prompt."""
157
+ try:
158
+ video = load_video(video_data, strategy='chat')
159
+ except ValueError as e:
160
+ return f"Error loading video: {str(e)}"
161
 
162
  inputs = model.build_conversation_input_ids(
163
  tokenizer=tokenizer,
 
175
  }
176
 
177
  gen_kwargs = {
178
+ "max_new_tokens": MAX_NEW_TOKENS,
179
  "pad_token_id": tokenizer.pad_token_id,
180
+ "top_k": TOP_K,
181
  "do_sample": False,
182
+ "top_p": TOP_P,
183
+ "temperature": temperature or DEFAULT_TEMPERATURE,
184
  }
185
 
186
  with torch.no_grad():
 
217
 
218
  Task: Analyze the provided video to identify the delay reason. Use the following format:
219
  1. **Selected Reason:** [Choose the most likely reason from the list above]
220
+ 2. **Visual Evidence:** [Describe specific visual cues from the video that support your analysis.]
221
  """