arjunanand13 committed on
Commit
114aae4
·
verified ·
1 Parent(s): 106a698

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -43
app.py CHANGED
@@ -1,51 +1,144 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- from process import inference
 
 
4
 
5
- def clickit(video, prompt):
6
- return inference(
7
- video,
8
- prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- with gr.Blocks() as blok:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  with gr.Row():
13
  with gr.Column():
14
- video = gr.Video(
15
- label="video input",
16
- value="delay_tire.mp4",
17
- )
18
- prompt = gr.Text(
19
- label="Prompt",
20
- value="""You are an expert AI model trained to analyze and interpret manufacturing processes. The task is to evaluate video footage of specific steps in a tire manufacturing process. The process has 8 total steps, but only delayed steps are provided for analysis.
21
-
22
- **Your Goal:**
23
- 1. Analyze the provided video.
24
- 2. Identify possible reasons for the delay in the manufacturing step shown in the video.
25
- 3. Provide a clear explanation of the delay based on observed factors, such as machinery issues, material handling delays, operator inefficiency, or other possible causes.
26
-
27
- **Input:**
28
- - The attached video shows a single step from the tire manufacturing process.
29
- - Context: Tire manufacturing involves 8 steps, and delays may occur due to machinery faults, raw material availability, labor efficiency, or unexpected disruptions.
30
-
31
- **Output:**
32
- Explain why the delay occurred in this step. Include specific observations and their connection to the delay.
33
-
34
- **Example Output:**
35
- "The delay in this step seems to be caused by improper alignment of the conveyor belt, which slowed down the transfer of materials. Additionally, manual intervention by operators was observed, which could indicate a malfunctioning automated system."
36
-
37
- **Important:**
38
- Focus on observations from the video and provide reasons that align with the context of tire manufacturing.
39
- """
40
- )
41
  with gr.Column():
42
- button = gr.Button("Reason for delay", variant="primary")
43
- text = gr.Text(label="Output")
44
-
45
- button.click(
46
- fn=clickit,
47
- inputs=[video, prompt],
48
- outputs=[text]
49
- )
50
-
51
- blok.launch()
 
 
1
  import gradio as gr
2
+ import io
3
+ import numpy as np
4
+ import torch
5
+ from decord import cpu, VideoReader, bridge
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from transformers import BitsAndBytesConfig
8
+ import json
9
 
10
# Model identifier handed to the `from_pretrained` loaders.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

_CUDA_AVAILABLE = torch.cuda.is_available()

# Device that input tensors are moved to before generation.
DEVICE = 'cuda' if _CUDA_AVAILABLE else 'cpu'

# bfloat16 is chosen only on CUDA devices whose major compute capability is
# at least 8; every other configuration falls back to float16.
if _CUDA_AVAILABLE and torch.cuda.get_device_capability()[0] >= 8:
    TORCH_TYPE = torch.bfloat16
else:
    TORCH_TYPE = torch.float16
13
 
14
# Candidate delay explanations per manufacturing step, keyed "step1".."step8".
# Steps 2-7 share the same two reasons; each step still gets its own list so
# mutating one entry never affects another.
_SHARED_MID_STEP_REASONS = ("Person repatching the tire", "Lack of raw material")

DELAY_REASONS = {
    "step1": {"reasons": ["No raw material available", "Person repatching the tire"]},
    **{
        f"step{step}": {"reasons": list(_SHARED_MID_STEP_REASONS)}
        for step in range(2, 8)
    },
    "step8": {"reasons": ["No person available to collect tire", "Person repatching the tire"]},
}

# Persist the table next to the app on import (same file name and layout as before).
with open('delay_reasons.json', 'w') as reasons_file:
    reasons_file.write(json.dumps(DELAY_REASONS, indent=4))
27
+
28
def _select_frame_indices(timestamps, max_frames):
    """Return, for each whole second, the index of the frame starting closest to it.

    At most ``max_frames`` indices are produced (seconds 0, 1, 2, ...). Ties
    resolve to the earliest frame, matching ``min(...)`` + ``list.index``.
    """
    timestamps = np.asarray(timestamps)
    last_second = round(float(timestamps.max())) + 1
    return [
        int(np.argmin(np.abs(timestamps - second)))
        for second in range(min(last_second, max_frames))
    ]


def load_video(video_data, strategy='chat'):
    """Decode a video and sample about one frame per second (24 frames max).

    Args:
        video_data: The encoded video as ``bytes``/``bytearray``/``memoryview``,
            or anything decord's ``VideoReader`` can open itself (a file path
            string or a binary file-like object).
        strategy: Not used by the sampling logic; kept so existing callers
            that pass it keep working.

    Returns:
        The batch returned by ``VideoReader.get_batch`` after
        ``permute(3, 0, 1, 2)`` (presumably channels-first ``(C, T, H, W)``
        given decord's frame layout; TODO confirm).

    Raises:
        ValueError: If the video contains no frames.
    """
    bridge.set_bridge('torch')
    num_frames = 24

    # In-memory payloads need a file-like wrapper; paths and streams are
    # passed through unchanged.
    if isinstance(video_data, (bytes, bytearray, memoryview)):
        source = io.BytesIO(video_data)
    else:
        source = video_data
    decord_vr = VideoReader(source, ctx=cpu(0))

    total_frames = len(decord_vr)
    if total_frames == 0:
        # Without this guard the failure surfaces as an opaque error from max().
        raise ValueError("Video contains no decodable frames.")

    # get_frame_timestamp yields one row per frame; column 0 is read as the
    # frame's start time, as in the original per-row `i[0]` lookup.
    start_times = np.asarray(
        decord_vr.get_frame_timestamp(np.arange(total_frames))
    )[:, 0]
    frame_id_list = _select_frame_indices(start_times, num_frames)

    frames = decord_vr.get_batch(frame_id_list)
    return frames.permute(3, 0, 1, 2)
48
+
49
# Holds the (model, tokenizer) pair after the first successful load so that
# repeated calls do not reload the checkpoint.
_MODEL_CACHE = None


def load_model():
    """Load the 4-bit quantized model and its tokenizer, reusing them after the first call.

    The checkpoint at ``MODEL_PATH`` is loaded with NF4 4-bit quantization
    (double quantization enabled, compute dtype ``TORCH_TYPE``) and
    ``device_map="auto"``, then switched to eval mode. Loading happens once per
    process; later calls return the same objects, which matters because this
    function is invoked from the per-request handler.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    global _MODEL_CACHE
    if _MODEL_CACHE is not None:
        return _MODEL_CACHE

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=TORCH_TYPE,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto",
    ).eval()

    # Only cache after both loads succeeded so a failed attempt can be retried.
    _MODEL_CACHE = (model, tokenizer)
    return _MODEL_CACHE
65
+
66
def predict(prompt, video_data, temperature, model, tokenizer):
    """Generate the model's text answer for ``prompt`` about the given video.

    Args:
        prompt: Query text passed to ``build_conversation_input_ids``.
        video_data: Video payload forwarded to ``load_video``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.
        model: Model exposing ``build_conversation_input_ids`` and ``generate``.
        tokenizer: Tokenizer used to build the inputs and decode the output.

    Returns:
        str: The decoded continuation, with special tokens skipped.
    """
    template = 'chat'
    frames = load_video(video_data, strategy=template)

    encoded = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query=prompt,
        images=[frames],
        history=[],
        template_version=template,
    )

    # Add a leading batch dimension to each text tensor and move it to DEVICE.
    model_inputs = {
        key: encoded[key].unsqueeze(0).to(DEVICE)
        for key in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    # The video tensor is additionally cast to TORCH_TYPE and nested as [[...]].
    model_inputs['images'] = [[encoded['images'][0].to(DEVICE).to(TORCH_TYPE)]]

    generation_options = dict(
        max_new_tokens=2048,
        pad_token_id=128002,
        top_k=1,
        do_sample=False,
        top_p=0.1,
        temperature=temperature,
    )

    with torch.no_grad():
        generated = model.generate(**model_inputs, **generation_options)
        # Drop the prompt tokens so only the newly generated part is decoded.
        prompt_length = model_inputs['input_ids'].shape[1]
        completion = generated[:, prompt_length:]
        return tokenizer.decode(completion[0], skip_special_tokens=True)
96
 
97
def get_base_prompt():
    """Return the fixed instruction text that precedes the step-specific details."""
    prompt_lines = (
        "You are an expert AI model trained to analyze and interpret manufacturing processes.",
        "The task is to evaluate video footage of specific steps in a tire manufacturing process.",
        "The process has 8 total steps, but only delayed steps are provided for analysis.",
        "",
        "**Your Goal:**",
        "1. Analyze the provided video.",
        "2. Identify possible reasons for the delay in the manufacturing step shown in the video.",
        "3. Provide a clear explanation of the delay based on observed factors.",
        "",
        "**Context:**",
        "Tire manufacturing involves 8 steps, and delays may occur due to machinery faults,",
        "raw material availability, labor efficiency, or unexpected disruptions.",
        "",
        "**Output:**",
        "Explain why the delay occurred in this step. Include specific observations",
        "and their connection to the delay.",
    )
    return "\n".join(prompt_lines)
114
+
115
def inference(video, step_number, selected_reason):
    """Run the delay analysis for an uploaded video and return the model's answer.

    Args:
        video: The uploaded video. Gradio's Video component supplies a file
            path string; an already-open binary file object is also accepted.
        step_number: The selected step, e.g. ``"Step 3"`` (a bare ``3`` or
            ``"3"`` works too).
        selected_reason: The candidate delay reason to include in the prompt.

    Returns:
        str: The model's response, or a short notice when no video was given.
    """
    if not video:
        return "Please upload a video first."

    # Read the raw bytes up front so a bad upload fails before the model loads.
    if hasattr(video, "read"):
        video_data = video.read()
    else:
        with open(video, "rb") as video_file:
            video_data = video_file.read()

    # The dropdown value already reads "Step N"; strip the prefix so the prompt
    # says "Analyzing Step N" instead of "Analyzing Step Step N".
    step_label = str(step_number).strip()
    if step_label.lower().startswith("step"):
        step_label = step_label[len("step"):].strip()

    model, tokenizer = load_model()
    base_prompt = get_base_prompt()
    full_prompt = f"{base_prompt}\n\nAnalyzing Step {step_label}\nPossible reason: {selected_reason}"
    temperature = 0.8
    return predict(full_prompt, video_data, temperature, model, tokenizer)
125
+
126
# UI: the left column collects the video, the step and a candidate reason;
# the right column shows the model's analysis.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Gradio 4 takes `sources=[...]`; `source=` and `cache_examples=`
            # are not Video constructor arguments there and fail at start-up.
            video = gr.Video(label="Video Input", sources=["upload"])
            step_number = gr.Dropdown(
                choices=[f"Step {i}" for i in range(1, 9)],
                label="Manufacturing Step",
                value="Step 1",
            )
            reason = gr.Dropdown(
                choices=DELAY_REASONS["step1"]["reasons"],
                label="Possible Delay Reason",
                value=DELAY_REASONS["step1"]["reasons"][0],
            )
            analyze_btn = gr.Button("Analyze Delay", variant="primary")
        with gr.Column():
            output = gr.Textbox(label="Analysis Result")

    def update_reasons(step):
        """Return a Dropdown update listing the reasons for the chosen step.

        ``step`` is the step dropdown's value (e.g. ``"Step 3"``), mapped to
        the ``"step3"`` key of ``DELAY_REASONS``. The selected value is reset
        to the first reason so it is always one of the new choices; an unknown
        or cleared step yields an empty list instead of raising ``KeyError``.
        """
        step_key = (step or "").lower().replace(" ", "")
        reasons = DELAY_REASONS.get(step_key, {}).get("reasons", [])
        return gr.Dropdown(choices=reasons, value=reasons[0] if reasons else None)

    step_number.change(fn=update_reasons, inputs=[step_number], outputs=[reason])
    analyze_btn.click(fn=inference, inputs=[video, step_number, reason], outputs=[output])

if __name__ == "__main__":
    demo.launch()