Spaces:
Runtime error
Runtime error
Commit
·
858fe91
1
Parent(s):
a98574a
enforce model not null constraint
Browse files- demo_watermark.py +24 -10
demo_watermark.py
CHANGED
@@ -706,7 +706,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
|
|
706 |
detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
|
707 |
|
708 |
# State management logic
|
709 |
-
# update callbacks that change the state dict
|
710 |
def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
|
711 |
def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
|
712 |
def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
|
@@ -769,20 +769,35 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
|
|
769 |
# return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
|
770 |
# else:
|
771 |
return AutoTokenizer.from_pretrained(model_name_or_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
772 |
# registering callbacks for toggling the visibilty of certain parameters based on the values of others
|
773 |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
|
774 |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
|
775 |
decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
|
776 |
-
model_selector.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
|
777 |
decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
|
778 |
-
model_selector.change(toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams])
|
779 |
-
model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma])
|
780 |
-
model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta])
|
781 |
-
model_selector.change(toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma])
|
782 |
-
model_selector.change(toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta])
|
783 |
-
model_selector.change(update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer])
|
784 |
# registering all state update callbacks
|
785 |
-
model_selector.change(update_model,inputs=[session_args, model_selector], outputs=[session_args])
|
786 |
decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
|
787 |
sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
|
788 |
generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
|
@@ -798,7 +813,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
|
|
798 |
# register additional callback on button clicks that updates the shown parameters window
|
799 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
800 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
801 |
-
model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
802 |
# When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
|
803 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
804 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
|
|
706 |
detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
|
707 |
|
708 |
# State management logic
|
709 |
+
# define update callbacks that change the state dict
|
710 |
def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
|
711 |
def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
|
712 |
def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
|
|
|
769 |
# return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
|
770 |
# else:
|
771 |
return AutoTokenizer.from_pretrained(model_name_or_path)
|
772 |
+
|
773 |
+
def check_model(value): return value if (value!="" and value is not None) else args.model_name_or_path
|
774 |
+
# enforce constraint that model cannot be null or empty
|
775 |
+
# then attach model callbacks in particular
|
776 |
+
model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
|
777 |
+
toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams]
|
778 |
+
).then(
|
779 |
+
toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams]
|
780 |
+
).then(
|
781 |
+
toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma]
|
782 |
+
).then(
|
783 |
+
toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta]
|
784 |
+
).then(
|
785 |
+
toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma]
|
786 |
+
).then(
|
787 |
+
toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta]
|
788 |
+
).then(
|
789 |
+
update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]
|
790 |
+
).then(
|
791 |
+
update_model,inputs=[session_args, model_selector], outputs=[session_args]
|
792 |
+
).then(
|
793 |
+
lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
|
794 |
+
)
|
795 |
# registering callbacks for toggling the visibilty of certain parameters based on the values of others
|
796 |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
|
797 |
decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
|
798 |
decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
|
|
|
799 |
decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
|
|
|
|
|
|
|
|
|
|
|
|
|
800 |
# registering all state update callbacks
|
|
|
801 |
decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
|
802 |
sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
|
803 |
generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
|
|
|
813 |
# register additional callback on button clicks that updates the shown parameters window
|
814 |
generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
815 |
detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
|
|
816 |
# When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
|
817 |
delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|
818 |
gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
|