Spaces:
Running
Running
import re | |
from enum import Enum | |
import streamlit as st | |
import streamlit_tags as st_tags | |
from streamlit import session_state as ss | |
from streamlit.delta_generator import DeltaGenerator | |
from streamlit.runtime.scriptrunner_utils.exceptions import StopException | |
from form.form import build_form_data_from_answers, write_pdf_form, work_categories | |
from llm_manager.llm_parser import LlmParser | |
from local_storage.entities import PersonalDetails, LocationDetails, ContractorDetails, SavedDetails | |
from local_storage.ls_manager import LocalStorageManager | |
from prompts.prompts_manager import PromptsManager | |
from enums import Questions as Q, DetailsType | |
from repository.repository import build_repo_from_environment, get_repository | |
from repository import ModelRoles, Model | |
from utils.parsing_utils import check_for_missing_answers | |
# Introductory instructions for the request text area; UIManager appends the question list to it.
user_msg = ("Please describe what you need to do. "
            "To get the best results try to answer all the following questions:")
# "@tag" mentions in the user's free text; group 1 captures the tag (a run of non-whitespace).
find_tags_regex = re.compile(r"@(\S*)")
class Steps(Enum):
    """Stages of the permit-request wizard; UIManager persists ``member.value`` in session state."""

    INITIAL_STATE = 1
    PARSING_ANSWERS = 2
    ASK_AGAIN = 3
    FIND_CATEGORIES = 4
    VALIDATE_DATA = 5
    PARSING_ERROR = 6
    FORM_CREATED = 7

    def __eq__(self, other):
        # Equal only to members of this same enum class that carry the same value;
        # any other object (plain ints included) compares unequal.
        return isinstance(other, self.__class__) and self.value == other.value

    def __hash__(self):
        # Hash on the value so hashing stays consistent with __eq__ above.
        return hash(self.value)
class UIManager:
    """Render each step of the work-permit wizard with Streamlit.

    The current step is kept in session state under ``"step"`` (see
    ``get_current_step`` / ``set_current_step``); each transition is followed by
    ``st.rerun()`` so the page for the new step is drawn on the next run.
    """

    def __init__(self):
        self.pm: PromptsManager = PromptsManager(work_categories=work_categories)
        # Prefer the environment-configured repository; when that returns a falsy value,
        # fall back to the "testing" repository with a fake model.
        self.repository = (build_repo_from_environment(self.pm.system_prompt) or
                           get_repository("testing",
                                          Model("fakeModel",
                                                ModelRoles("a", "b", "c"))))
        # Re-entrancy guard for _update_user_prompt.
        self.update_in_progress = False
        # LocalStorageManager, created lazily by _build_base_ui.
        self.lsm = None

    @staticmethod
    def get_current_step() -> int:
        """Return the stored step value, or ``Steps.INITIAL_STATE.value`` when none is stored."""
        try:
            return ss.get("step") or Steps.INITIAL_STATE.value
        except StopException:
            # Session state could not be read: behave as if the wizard has not started.
            return Steps.INITIAL_STATE.value

    @staticmethod
    def set_current_step(step: Steps):
        """Store ``step.value`` as the current wizard step."""
        ss["step"] = step.value

    def _build_base_ui(self):
        """Draw the page title and the sidebar pickers for saved details."""
        st.markdown("## Dubai Asset Management red tape cutter")
        if not self.lsm:
            self.lsm = LocalStorageManager()
        with st.sidebar:
            st.markdown("### Personal details")
            self.build_details_checkboxes(DetailsType.PERSONAL_DETAILS)
            st.markdown("### Locations details")
            self.build_details_checkboxes(DetailsType.LOCATION_DETAILS)
            st.markdown("### Contractors details")
            self.build_details_checkboxes(DetailsType.CONTRACTOR_DETAILS)

    def build_ui_for_initial_state(self, user_message):
        """Show the free-text request form; on submit advance to PARSING_ANSWERS.

        :param user_message: introductory text shown before the list of questions.
        """
        # The newline puts the first question on its own line instead of gluing it to the intro.
        help_ = user_message + "\n" + "\n".join(self.pm.questions)
        self._build_base_ui()
        with st.form("Please describe your request"):
            st.text_area("Your input", height=700, label_visibility="hidden", placeholder=help_,
                         help=help_, key="user_input")
            signature = st.file_uploader("Your signature", key="file_upload")
            # Kept under its own key; _create_pdf_form reads it with ss.get("signature").
            ss["signature"] = signature
            if st.form_submit_button():
                self.set_current_step(Steps.PARSING_ANSWERS)
                st.rerun()

    def build_ui_for_parsing_answers(self):
        """Resolve @tags, ask the LLM to extract the answers and pick the next step."""
        self._build_base_ui()
        with st.status("parsing user input for tags"):
            tags = find_tags_regex.findall(ss["user_input"])
            details = [self.lsm.get_detail(t) for t in tags]
        with st.status("initialising LLM"):
            self.repository.init()
        with st.status("waiting for LLM"):
            answer = self.repository.send_prompt(self.pm.verify_user_input_prompt(ss["user_input"], details))
            st.write(f"answers from LLM: {answer['content']}")
        with st.status("Checking for missing answers"):
            answers = LlmParser.parse_verification_prompt_answers(answer['content'], details)
            ss["answers"] = answers
        if len(answers) != len(Q):
            # Not one answer per question: let the user correct everything by hand.
            self.set_current_step(Steps.PARSING_ERROR)
            st.rerun()
        ss["missing_answers"] = check_for_missing_answers(ss["answers"])
        if not ss.get("missing_answers"):
            self.set_current_step(Steps.FIND_CATEGORIES)
        else:
            self.set_current_step(Steps.ASK_AGAIN)
        st.rerun()

    def build_ui_for_ask_again(self):
        """Ask only the unanswered questions; on submit merge the replies into ss["answers"]."""
        self._build_base_ui()
        with st.form("form1"):
            for ma in ss["missing_answers"]:
                st.text_input(self.pm.questions[ma.value].lower(), key=ma)
            if st.form_submit_button("Submit answers"):
                for ma in ss["missing_answers"]:
                    ss["answers"][ma] = ss[ma]
                self.set_current_step(Steps.FIND_CATEGORIES)
                st.rerun()

    def build_ui_for_check_category(self):
        """Ask the LLM which work categories apply, store them, then advance to VALIDATE_DATA."""
        self._build_base_ui()
        # NOTE(review): self.repository.init() is only called in build_ui_for_parsing_answers;
        # if UIManager is re-created on every script rerun, confirm send_prompt works without it.
        with st.status("finding the work categories applicable to your work"):
            answer = self.repository.send_prompt(self.pm.get_work_category(ss["answers"][Q.WORK_TO_DO]))
            categories = LlmParser.parse_get_categories_answer(answer['content'])
            ss["categories"] = categories
        self.set_current_step(Steps.VALIDATE_DATA)
        st.rerun()

    def build_ui_for_form_created(self):
        """Offer the generated PDF for download and a button to restart the wizard."""
        self._build_base_ui()
        st.download_button("download form", ss["pdf_form"],
                           file_name=ss["pdf_form_filename"], mime="application/pdf")
        if st.button("Start over"):
            for state_key in ("step", "pdf_form", "pdf_form_filename", "signature"):
                ss.pop(state_key, None)
            st.rerun()

    def _apply_user_corrections(self):
        """Copy the validation-form fields into ss["answers"] and save any named detail sets."""
        for question in Q:
            field_key = f"fq_{question.name}"
            if field_key in ss:
                ss["answers"][question] = ss[field_key]
        for details_key, build_details in (("your_details", self._get_personal_details),
                                           ("location_details", self._get_location_details),
                                           ("contractor_details", self._get_contractor_details)):
            details = build_details(details_key)
            if details:
                # The "Save as" field holds the name under which this data is saved.
                self.lsm.save_details(details, ss[details_key])

    def _integrate_llm_answers_with_user_corrections(self):
        """Apply the user's corrections, then move on to work-category detection."""
        self._apply_user_corrections()
        self.set_current_step(Steps.FIND_CATEGORIES)
        st.rerun()

    def _create_pdf_form(self):
        """Apply the user's final corrections, build the PDF and advance to FORM_CREATED."""
        # Honour what the user changed in the confirmation form: text fields, "Save as"
        # names and the category checkboxes (keyed by the work_categories keys).
        self._apply_user_corrections()
        # NOTE(review): assumes build_form_data_from_answers accepts a list of
        # work_categories keys, as produced by parse_get_categories_answer — confirm.
        ss["categories"] = [k for k in work_categories if ss.get(k)]
        with st.status("categories found, creating PDF form"):
            form_data, filename = build_form_data_from_answers(ss["answers"], ss["categories"],
                                                               ss.get("signature"))
            ss["pdf_form"] = write_pdf_form(form_data)
            ss["pdf_form_filename"] = filename
        self.set_current_step(Steps.FORM_CREATED)
        st.rerun()

    def build_ui_for_validate_data_after_correction(self):
        """Editable summary without categories; submitting re-runs category detection."""
        self._build_validation_form(False, self._integrate_llm_answers_with_user_corrections,
                                    "Find work categories")

    def build_ui_to_confirm_form_data(self):
        """Editable summary with category checkboxes; submitting creates the PDF."""
        self._build_validation_form(True, self._create_pdf_form,
                                    "Create work permit request")

    @staticmethod
    def _get_personal_details(personal_details_key) -> PersonalDetails | None:
        """Build PersonalDetails from the validation form when a "Save as" name was given.

        :param personal_details_key: session-state key of the "Save as" field.
        :return: the details, or None when that field is empty.
        """
        if not ss.get(personal_details_key):
            return None
        return PersonalDetails(ss[f"fq_{Q.FULL_NAME.name}"], ss[f"fq_{Q.YOUR_EMAIL.name}"],
                               ss[f"fq_{Q.CONTACT_NUMBER.name}"])

    @staticmethod
    def _get_location_details(location_details_key) -> LocationDetails | None:
        """Build LocationDetails from the validation form when a "Save as" name was given."""
        if not ss.get(location_details_key):
            return None
        return LocationDetails(ss[f"fq_{Q.OWNER_OR_TENANT.name}"], ss[f"fq_{Q.COMMUNITY.name}"],
                               ss[f"fq_{Q.BUILDING.name}"], ss[f"fq_{Q.UNIT_APT_NUMBER.name}"])

    @staticmethod
    def _get_contractor_details(contractor_details_key) -> ContractorDetails | None:
        """Build ContractorDetails from the validation form when a "Save as" name was given."""
        if not ss.get(contractor_details_key):
            return None
        return ContractorDetails(ss[f"fq_{Q.COMPANY_NAME.name}"], ss[f"fq_{Q.COMPANY_NUMBER.name}"],
                                 ss[f"fq_{Q.COMPANY_EMAIL.name}"])

    def _build_validation_form(self, show_categories: bool, onsubmit, submit_button_label):
        """Render an editable two-column summary of all answers; call *onsubmit* on submit.

        :param show_categories: also render one checkbox per work category.
        :param onsubmit: zero-argument callable invoked after the form is submitted.
        :param submit_button_label: label of the form's submit button.
        """
        def build_form_fragment(col, title, add_save, *questions):
            # Write through the column object itself: calling form.<widget>() inside
            # `with col:` ignores the column and stacks every field in the form.
            col.text(title)
            for question in questions:
                col.text_input(field_labels[question], value=answers.get(question),
                               key=f"fq_{question.name}")
            if add_save:
                # Key becomes e.g. "your_details"; read back by _apply_user_corrections.
                col.text_input("Save as", key=title.replace(" ", "_"))

        self._build_base_ui()
        field_labels = self.pm.questions_to_field_labels()
        answers = ss.get("answers", {})
        f = st.form("Please check the following information and correct fix any inaccuracies")
        col1, col2 = f.columns(2)
        build_form_fragment(col1, "your details", True, Q.FULL_NAME, Q.CONTACT_NUMBER, Q.YOUR_EMAIL)
        build_form_fragment(col2, "work details", False, Q.WORK_TO_DO, Q.START_DATE, Q.END_DATE)
        build_form_fragment(col1, "location details", True, Q.COMMUNITY, Q.BUILDING, Q.UNIT_APT_NUMBER,
                            Q.OWNER_OR_TENANT)
        build_form_fragment(col2, "contractor details", True, Q.COMPANY_NAME, Q.COMPANY_NUMBER,
                            Q.COMPANY_EMAIL)
        if show_categories:
            for k, wc in work_categories.items():
                f.checkbox(label=wc, key=k, value=k in ss["categories"])
        if f.form_submit_button(label=submit_button_label):
            onsubmit()

    def build_details_checkboxes(self, dt: DetailsType):
        """Draw a select checkbox and a "default" toggle for each saved detail of type *dt*."""
        details = self.lsm.get_details(dt)
        with st.container(border=True):
            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"#### {dt.title()}")
            with col2:
                st.markdown("#### Default")
            for d in details:
                with col1:
                    # Key "<TYPE>_<detail key>" is what _update_user_prompt reads and resets.
                    st.checkbox(label=d.short_description(), key=f"{dt.name}_{d.key}",
                                on_change=self._update_user_prompt, args=[dt, d.key])
                with col2:
                    # NOTE(review): this toggle has no key and its value is never read here;
                    # all toggles of one type share the same initial value.
                    st.toggle(f"favourite_{d.key}", label_visibility="hidden",
                              value=ss.get(f"DEFAULT_{dt.name}"))
            if st.button(f"Add {dt.title()}"):
                self.add_new_detail_dialog(dt)

    @st.dialog("Add new details")
    def add_new_detail_dialog(self, type_: DetailsType):
        """Modal form for entering and saving a new detail record of *type_*.

        ``st.dialog`` keeps the form alive after the triggering "Add ..." button resets
        on the next rerun; without it the "Save" submission could never be handled.
        """
        if type_.name == DetailsType.CONTRACTOR_DETAILS.name:
            new_item = ContractorDetails()
        elif type_.name == DetailsType.PERSONAL_DETAILS.name:
            new_item = PersonalDetails()
        else:
            new_item = LocationDetails()
        with st.form("new item", border=False):
            fields_labels = new_item.widget_labels()
            for field_name, label in fields_labels.items():
                st.text_input(label=label, key=field_name)
            if st.form_submit_button("Save"):
                for field_name in fields_labels:
                    setattr(new_item, field_name, ss[field_name])
                self.lsm.save_details(new_item)
                # Full app rerun: closes the dialog and redraws the sidebar lists.
                st.rerun()

    @staticmethod
    def _remove_tag(text: str, tag: str) -> str:
        """Return *text* without whole ``@tag`` tokens (``@tag2`` is left untouched)."""
        return re.sub(rf"@{re.escape(tag)}(?!\S)", "", text)

    def _update_user_prompt(self, type_: DetailsType, key: str):
        """Keep the @tags in the request text in sync with the sidebar checkboxes.

        on_change callback of every detail checkbox. Selecting a box deselects the
        other boxes of the same type, removes their @tags and appends ``@key``;
        deselecting a box removes ``@key``.

        :param type_: details type of the checkbox that changed.
        :param key: detail key of the checkbox that changed.
        """
        if self.update_in_progress:
            return
        self.update_in_progress = True
        try:
            current = ss.get("user_input") or ""
            updated = current
            if ss.get(f"{type_.name}_{key}") is True:
                for other in self.lsm.get_details(type_):
                    if other.key == key:
                        continue
                    # Only one detail per type may be selected at a time.
                    ss[f"{type_.name}_{other.key}"] = False
                    updated = self._remove_tag(updated, other.key)
                if key not in find_tags_regex.findall(updated):
                    updated = f"{updated} @{key}"
            else:
                updated = self._remove_tag(updated, key)
            updated = updated.strip()
            if updated != current:
                ss["user_input"] = updated
        finally:
            # Always release the guard, even if a lookup above raised.
            self.update_in_progress = False
um = UIManager()


def use_streamlit():
    """Render the page for the wizard step stored in session state.

    A step value with no matching entry renders nothing.
    """
    page_builders = {
        Steps.INITIAL_STATE.value: lambda: um.build_ui_for_initial_state(user_msg),
        Steps.PARSING_ANSWERS.value: um.build_ui_for_parsing_answers,
        Steps.PARSING_ERROR.value: um.build_ui_for_validate_data_after_correction,
        Steps.ASK_AGAIN.value: um.build_ui_for_ask_again,
        Steps.FIND_CATEGORIES.value: um.build_ui_for_check_category,
        Steps.VALIDATE_DATA.value: um.build_ui_to_confirm_form_data,
        Steps.FORM_CREATED.value: um.build_ui_for_form_created,
    }
    build_page = page_builders.get(um.get_current_step())
    if build_page is not None:
        build_page()


use_streamlit()