Update my_model/tabs/run_inference.py
Browse files
    	
        my_model/tabs/run_inference.py
    CHANGED
    
    | @@ -56,8 +56,10 @@ class InferenceRunner(StateManager): | |
| 56 | 
             
                                    nested_col22.text("Please click 'Analyze Image'..")
         | 
| 57 | 
             
                                    with nested_col22:
         | 
| 58 | 
             
                                        if st.button('Analyze Image', key=f'analyze_{image_key}'):
         | 
|  | |
| 59 | 
             
                                            caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
         | 
| 60 | 
             
                                            self.update_image_data(image_key, caption, detected_objects_str, True)
         | 
|  | |
| 61 |  | 
| 62 | 
             
                                # Initialize qa_history for each image
         | 
| 63 | 
             
                                qa_history = image_data.get('qa_history', [])
         | 
| @@ -65,10 +67,13 @@ class InferenceRunner(StateManager): | |
| 65 | 
             
                                if image_data['analysis_done']:
         | 
| 66 | 
             
                                    question = nested_col22.text_input(f"Ask a question about this image ({image_key[-11:]}):", key=f'question_{image_key}')
         | 
| 67 | 
             
                                    if nested_col22.button('Get Answer', key=f'answer_{image_key}'):
         | 
|  | |
| 68 | 
             
                                        if question not in [q for q, _ in qa_history]:
         | 
| 69 | 
             
                                            answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
         | 
| 70 | 
             
                                            self.add_to_qa_history(image_key, question, answer)
         | 
| 71 | 
             
                                        else: nested_col22.warning("This questions has already been answered.")
         | 
|  | |
|  | |
| 72 |  | 
| 73 | 
             
                                # Display Q&A history for each image
         | 
| 74 | 
             
                                for q, a in qa_history:
         | 
| @@ -99,7 +104,7 @@ class InferenceRunner(StateManager): | |
| 99 |  | 
| 100 | 
             
                            with st.container():
         | 
| 101 | 
             
                                nested_col11, nested_col12 = st.columns([0.5, 0.5])
         | 
| 102 | 
            -
                                if nested_col11.button(st.session_state.button_label):
         | 
| 103 |  | 
| 104 | 
             
                                    if st.session_state.button_label == "Load Model":
         | 
| 105 | 
             
                                        st.session_state['load_button_clicked'] = True
         | 
| @@ -111,7 +116,7 @@ class InferenceRunner(StateManager): | |
| 111 | 
             
                                    else:
         | 
| 112 | 
             
                                        reload_detection_model = True
         | 
| 113 |  | 
| 114 | 
            -
                                if nested_col12.button("Force Reload"):
         | 
| 115 | 
             
                                    force_reload_full_model = True
         | 
| 116 |  | 
| 117 | 
             
                            if load_fine_tuned_model:
         | 
|  | |
| 56 | 
             
                                    nested_col22.text("Please click 'Analyze Image'..")
         | 
| 57 | 
             
                                    with nested_col22:
         | 
| 58 | 
             
                                        if st.button('Analyze Image', key=f'analyze_{image_key}'):
         | 
| 59 | 
            +
                                            st.session_state['loading_in_progress'] = True
         | 
| 60 | 
             
                                            caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
         | 
| 61 | 
             
                                            self.update_image_data(image_key, caption, detected_objects_str, True)
         | 
| 62 | 
            +
                                        st.session_state['loading_in_progress'] = False
         | 
| 63 |  | 
| 64 | 
             
                                # Initialize qa_history for each image
         | 
| 65 | 
             
                                qa_history = image_data.get('qa_history', [])
         | 
|  | |
| 67 | 
             
                                if image_data['analysis_done']:
         | 
| 68 | 
             
                                    question = nested_col22.text_input(f"Ask a question about this image ({image_key[-11:]}):", key=f'question_{image_key}')
         | 
| 69 | 
             
                                    if nested_col22.button('Get Answer', key=f'answer_{image_key}'):
         | 
| 70 | 
            +
                                        st.session_state['loading_in_progress'] = True
         | 
| 71 | 
             
                                        if question not in [q for q, _ in qa_history]:
         | 
| 72 | 
             
                                            answer = self.answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
         | 
| 73 | 
             
                                            self.add_to_qa_history(image_key, question, answer)
         | 
| 74 | 
             
                                        else: nested_col22.warning("This questions has already been answered.")
         | 
| 75 | 
            +
                                            
         | 
| 76 | 
            +
                                        st.session_state['loading_in_progress'] = False
         | 
| 77 |  | 
| 78 | 
             
                                # Display Q&A history for each image
         | 
| 79 | 
             
                                for q, a in qa_history:
         | 
|  | |
| 104 |  | 
| 105 | 
             
                            with st.container():
         | 
| 106 | 
             
                                nested_col11, nested_col12 = st.columns([0.5, 0.5])
         | 
| 107 | 
            +
                                if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
         | 
| 108 |  | 
| 109 | 
             
                                    if st.session_state.button_label == "Load Model":
         | 
| 110 | 
             
                                        st.session_state['load_button_clicked'] = True
         | 
|  | |
| 116 | 
             
                                    else:
         | 
| 117 | 
             
                                        reload_detection_model = True
         | 
| 118 |  | 
| 119 | 
            +
                                if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
         | 
| 120 | 
             
                                    force_reload_full_model = True
         | 
| 121 |  | 
| 122 | 
             
                            if load_fine_tuned_model:
         | 
