Spaces:
Running
Running
File size: 5,532 Bytes
18d1852 e46d486 18d1852 753c201 1d51bf5 18d1852 a696a1a 18d1852 7d71f1b 2152f1f c03044f 1c8b5fe 18d1852 2152f1f 41a01e5 c03044f 2152f1f 2957e90 2152f1f f51ceea 18d1852 d80fd56 18d1852 2957e90 18d1852 d0a09f4 18d1852 d0a09f4 18d1852 d0e9fe6 753c201 d0e9fe6 d9364fd d0e9fe6 753c201 18d1852 105e89e 18d1852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import pandas as pd
import copy
import streamlit as st
from my_model.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """Manage Streamlit session state for the KB-VQA app.

    Persistent values live in ``st.session_state`` so they survive Streamlit
    reruns:

    * ``'method'``, ``'detection_model'``, ``'confidence_level'`` - widget keys
      for the model settings.
    * ``'kbvqa'`` - the loaded KBVQA model, or ``None`` before loading.
    * ``'images_data'`` - dict mapping an image key to its analysis record
      (image, caption, detected objects, Q&A history, analysis flag).
    """

    def initialize_state(self):
        """Create default session-state entries and the settings widgets.

        Each settings widget is rendered only when its key is not yet present
        in ``st.session_state``. Always sets ``self.detection_model`` (the
        current detection-model choice) and ``self.default_confidence``
        (0.2 for ``"yolov5"``, otherwise 0.4).
        """
        if 'images_data' not in st.session_state:
            st.session_state['images_data'] = {}
        if 'method' not in st.session_state:
            st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
        if 'detection_model' not in st.session_state:
            self.detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
        else:
            # The selectbox above is skipped once its key exists. Read the stored
            # choice instead; otherwise self.detection_model would be undefined
            # and the default-confidence line below would raise AttributeError.
            self.detection_model = st.session_state['detection_model']
        if 'kbvqa' not in st.session_state:
            st.session_state['kbvqa'] = None
        self.default_confidence = 0.2 if self.detection_model == "yolov5" else 0.4
        if 'confidence_level' not in st.session_state:
            self.set_slider_value(text="Select minimum detection confidence level",
                                  min_value=0.1,
                                  max_value=0.9,
                                  value=self.default_confidence,
                                  step=0.1,
                                  slider_key_name='confidence_level')

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
        """Render a slider bound to ``slider_key_name`` and return its value."""
        return st.slider(text, min_value, max_value, value, step, key=slider_key_name)

    def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
        """Return True if the given settings differ from the recorded ones.

        Compares against ``st.session_state['model_settings']`` (keys
        ``'detection_model'``, ``'confidence_level'``, ``'selected_method'``).
        If no settings have been recorded, or a key is missing, the settings
        count as changed.
        """
        saved = st.session_state.get('model_settings')
        if saved is None:
            # This class never writes 'model_settings' (the former
            # update_model_settings calls were commented out), so a missing
            # record previously raised KeyError here.
            return True
        return (saved.get('detection_model') != current_detection_model or
                saved.get('confidence_level') != current_confidence_level or
                saved.get('selected_method') != current_selected_method)

    def display_model_settings(self):
        """Show a table of the model-related session-state entries."""
        st.write("### Current Model Settings:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa']]
        st.table(pd.DataFrame(data))

    def display_session_state(self):
        """Show a table of every session-state entry, values stringified."""
        st.write("### Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self):
        """Load the KBVQA model with the settings stored in session state.

        Frees GPU resources before and after loading. Any exception is shown
        in the UI via ``st.error`` instead of propagating.
        """
        try:
            free_gpu_resources()
            st.text("Loading the model, this should take no more than a few minutes, please wait...")
            st.session_state['kbvqa'] = prepare_kbvqa_model(st.session_state.detection_model)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            st.text("Model is ready for inference.")
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def get_model(self):
        """Retrieve the KBVQA model from the session state (None if absent)."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return True if a non-None KBVQA model is stored in session state."""
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None

    def reload_detection_model(self, detection_model, confidence_level):
        """Reload only the detection model and update the detection confidence.

        Does nothing beyond freeing GPU resources when no model is loaded.
        Errors are reported via ``st.error``.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                # NOTE(review): the return value is discarded here, whereas
                # load_model stores it in session state. Confirm that
                # prepare_kbvqa_model(..., only_reload_detection_model=True)
                # updates the existing model in place.
                prepare_kbvqa_model(detection_model, only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = confidence_level
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register an image under ``image_key`` with an empty analysis record.

        Existing entries are left untouched. ``kbvqa`` is currently unused.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """Caption the image and detect objects in it.

        Works on a deep copy so the caller's image is never modified.

        Returns:
            Tuple ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append a (question, answer) pair to a registered image's history.

        Silently does nothing if ``image_key`` is not registered.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the dict of per-image analysis records."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results for a registered image (no-op if unknown)."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })
|