jwkirchenbauer commited on
Commit
858fe91
·
1 Parent(s): a98574a

enforce model not null constraint

Browse files
Files changed (1) hide show
  1. demo_watermark.py +24 -10
demo_watermark.py CHANGED
@@ -706,7 +706,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
706
  detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
707
 
708
  # State management logic
709
- # update callbacks that change the state dict
710
  def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
711
  def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
712
  def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
@@ -769,20 +769,35 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
769
  # return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
770
  # else:
771
  return AutoTokenizer.from_pretrained(model_name_or_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
772
  # registering callbacks for toggling the visibilty of certain parameters based on the values of others
773
  decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
774
  decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
775
  decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
776
- model_selector.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
777
  decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
778
- model_selector.change(toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams])
779
- model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma])
780
- model_selector.change(toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta])
781
- model_selector.change(toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma])
782
- model_selector.change(toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta])
783
- model_selector.change(update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer])
784
  # registering all state update callbacks
785
- model_selector.change(update_model,inputs=[session_args, model_selector], outputs=[session_args])
786
  decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
787
  sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
788
  generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
@@ -798,7 +813,6 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
798
  # register additional callback on button clicks that updates the shown parameters window
799
  generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
800
  detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
801
- model_selector.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
802
  # When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
803
  delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
804
  gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
 
706
  detect_btn.click(fn=detect_partial, inputs=[detection_input,session_args,session_tokenizer], outputs=[detection_result, session_args,session_tokenizer], api_name="detection")
707
 
708
  # State management logic
709
+ # define update callbacks that change the state dict
710
  def update_model(session_state, value): session_state.model_name_or_path = value; return session_state
711
  def update_sampling_temp(session_state, value): session_state.sampling_temp = float(value); return session_state
712
  def update_generation_seed(session_state, value): session_state.generation_seed = int(value); return session_state
 
769
  # return ALPACA_MODEL_TOKENIZER.from_pretrained(ALPACA_TOKENIZER_PATH)
770
  # else:
771
  return AutoTokenizer.from_pretrained(model_name_or_path)
772
+
773
+ def check_model(value): return value if (value!="" and value is not None) else args.model_name_or_path
774
+ # enforce constraint that model cannot be null or empty
775
+ # then attach model callbacks in particular
776
+ model_selector.change(check_model, inputs=[model_selector], outputs=[model_selector]).then(
777
+ toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams]
778
+ ).then(
779
+ toggle_beams_for_api_model,inputs=[model_selector,n_beams], outputs=[n_beams]
780
+ ).then(
781
+ toggle_interactive_for_api_model,inputs=[model_selector], outputs=[gamma]
782
+ ).then(
783
+ toggle_interactive_for_api_model,inputs=[model_selector], outputs=[delta]
784
+ ).then(
785
+ toggle_gamma_for_api_model,inputs=[model_selector,gamma], outputs=[gamma]
786
+ ).then(
787
+ toggle_delta_for_api_model,inputs=[model_selector,delta], outputs=[delta]
788
+ ).then(
789
+ update_tokenizer,inputs=[model_selector], outputs=[session_tokenizer]
790
+ ).then(
791
+ update_model,inputs=[session_args, model_selector], outputs=[session_args]
792
+ ).then(
793
+ lambda value: str(value), inputs=[session_args], outputs=[current_parameters]
794
+ )
795
  # registering callbacks for toggling the visibilty of certain parameters based on the values of others
796
  decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[sampling_temp])
797
  decoding.change(toggle_sampling_vis,inputs=[decoding], outputs=[generation_seed])
798
  decoding.change(toggle_sampling_vis_inv,inputs=[decoding], outputs=[n_beams])
 
799
  decoding.change(toggle_vis_for_api_model,inputs=[model_selector], outputs=[n_beams])
 
 
 
 
 
 
800
  # registering all state update callbacks
 
801
  decoding.change(update_decoding,inputs=[session_args, decoding], outputs=[session_args])
802
  sampling_temp.change(update_sampling_temp,inputs=[session_args, sampling_temp], outputs=[session_args])
803
  generation_seed.change(update_generation_seed,inputs=[session_args, generation_seed], outputs=[session_args])
 
813
  # register additional callback on button clicks that updates the shown parameters window
814
  generate_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
815
  detect_btn.click(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
 
816
  # When the parameters change, display the update and also fire detection, since some detection params dont change the model output.
817
  delta.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])
818
  gamma.change(lambda value: str(value), inputs=[session_args], outputs=[current_parameters])