VishalD1234 committed
Commit 01e9f83 · verified · Parent: e1b6daf

Update app.py

Files changed (1):
  1. app.py +37 -88
app.py CHANGED
@@ -6,12 +6,11 @@ from decord import cpu, VideoReader, bridge
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import BitsAndBytesConfig
 
-MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
+MODEL_PATH = "THUDM/cogvlm2-video-llama3-chat"
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
 
 
-
 def get_step_info(step_number):
     """Returns detailed information about a manufacturing step."""
     step_details = {
@@ -126,89 +125,6 @@ def get_step_info(step_number):
 
     return step_details.get(step_number, {"Error": "Invalid step number. Please provide a valid step number."})
 
-
-
-def load_video(video_data, strategy='chat'):
-    """Loads and processes video data into a format suitable for model input."""
-    bridge.set_bridge('torch')
-    num_frames = 24
-
-    if isinstance(video_data, str):
-        decord_vr = VideoReader(video_data, ctx=cpu(0))
-    else:
-        decord_vr = VideoReader(io.BytesIO(video_data), ctx=cpu(0))
-
-    frame_id_list = []
-    total_frames = len(decord_vr)
-    timestamps = [i[0] for i in decord_vr.get_frame_timestamp(np.arange(total_frames))]
-    max_second = round(max(timestamps)) + 1
-
-    for second in range(max_second):
-        closest_num = min(timestamps, key=lambda x: abs(x - second))
-        index = timestamps.index(closest_num)
-        frame_id_list.append(index)
-        if len(frame_id_list) >= num_frames:
-            break
-
-    video_data = decord_vr.get_batch(frame_id_list)
-    video_data = video_data.permute(3, 0, 1, 2)
-    return video_data
-
-def load_model():
-    """Loads the pre-trained model and tokenizer with quantization configurations."""
-    quantization_config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_compute_dtype=TORCH_TYPE,
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4"
-    )
-
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=TORCH_TYPE,
-        trust_remote_code=True,
-        quantization_config=quantization_config,
-        device_map="auto"
-    ).eval()
-
-    return model, tokenizer
-
-def predict(prompt, video_data, temperature, model, tokenizer):
-    """Generates predictions based on the video and textual prompt."""
-    video = load_video(video_data, strategy='chat')
-
-    inputs = model.build_conversation_input_ids(
-        tokenizer=tokenizer,
-        query=prompt,
-        images=[video],
-        history=[],
-        template_version='chat'
-    )
-
-    inputs = {
-        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
-        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
-        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
-        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
-    }
-
-    gen_kwargs = {
-        "max_new_tokens": 2048,
-        "pad_token_id": 128002,
-        "top_k": 1,
-        "do_sample": False,
-        "top_p": 0.1,
-        "temperature": temperature,
-    }
-
-    with torch.no_grad():
-        outputs = model.generate(**inputs, **gen_kwargs)
-        outputs = outputs[:, inputs['input_ids'].shape[1]:]
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return response
-
 def get_analysis_prompt(step_number):
     """Constructs the prompt for analyzing delay reasons based on the selected step."""
     step_info = get_step_info(step_number)
@@ -254,10 +170,43 @@ Output:
 No person available to collect tire
 """
 
-
-
 model, tokenizer = load_model()
 
+def predict(prompt, video_data, temperature, model, tokenizer):
+    """Generates predictions based on the video and textual prompt."""
+    video = load_video(video_data, strategy='chat')
+
+    inputs = model.build_conversation_input_ids(
+        tokenizer=tokenizer,
+        query=prompt,
+        images=[video],
+        history=[],
+        template_version='chat'
+    )
+
+    inputs = {
+        'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
+        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
+        'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
+        'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
+    }
+
+    gen_kwargs = {
+        "max_new_tokens": 2048,
+        "pad_token_id": 128002,
+        "top_k": 1,
+        "do_sample": False,
+        "top_p": 0.1,
+        "temperature": temperature,
+    }
+
+    with torch.no_grad():
+        outputs = model.generate(**inputs, **gen_kwargs)
+        outputs = outputs[:, inputs['input_ids'].shape[1]:]
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    return response
+
 def inference(video, step_number):
     """Analyzes video to predict possible issues based on the manufacturing step."""
     try:
@@ -315,4 +264,4 @@ def create_interface():
 
 if __name__ == "__main__":
     demo = create_interface()
-    demo.queue().launch(share=True)
+    demo.queue().launch(share=True)
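Below is a minimal usage sketch of the `predict()` entry point as it stands after this commit. It is illustrative only, not part of the change: the clip path and prompt are hypothetical, and because this diff deletes the definitions of `load_video()` and `load_model()` while the new file still calls both, the sketch assumes those helpers are restored from the parent revision (e1b6daf) or imported from elsewhere.

```python
# Usage sketch (not part of the commit). Assumes load_model() and
# load_video() are back in scope; as committed, app.py removes their
# definitions but keeps the module-level load_model() call.
model, tokenizer = load_model()  # 4-bit NF4-quantized CogVLM2 + tokenizer

# load_video() accepts a file path (str) or raw bytes; bytes are used here.
with open("sample_step_video.mp4", "rb") as f:  # hypothetical clip
    video_bytes = f.read()

answer = predict(
    prompt="Describe any delay you observe in this manufacturing step.",  # example query
    video_data=video_bytes,
    temperature=0.1,
    model=model,
    tokenizer=tokenizer,
)
print(answer)
```

One behavioural note on the relocated `predict()`: `gen_kwargs` sets `do_sample=False`, so decoding is greedy and the `temperature`, `top_p`, and `top_k` values passed alongside it have no effect on the generated text.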