MegaTronX commited on
Commit
50f0254
·
verified ·
1 Parent(s): 5f23d3c

Update joycaption.py

Browse files
Files changed (1) hide show
  1. joycaption.py +4 -4
joycaption.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
- if os.environ.get("SPACES_ZERO_GPU") is not None:
3
- import spaces
4
  else:
5
  class spaces:
6
  @staticmethod
@@ -266,7 +266,7 @@ load_text_model(MODEL_PATH, None, LOAD_IN_NF4, True)
266
  #print(f"pixtral_processor: {type(pixtral_processor)}") #
267
 
268
 
269
- @spaces.GPU()
270
  @torch.inference_mode()
271
  def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_length: Union[str, int], extra_options: list[str], name_input: str, custom_prompt: str,
272
  max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> tuple[str, str]:
@@ -469,7 +469,7 @@ def get_repo_gguf(repo_id: str):
469
  else: return gr.update(value=files[0], choices=files)
470
 
471
 
472
- @spaces.GPU
473
  def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: Union[str, None]=None,
474
  is_nf4: bool=True, is_lora: bool=True, progress=gr.Progress(track_tqdm=True)):
475
  global use_inference_client, llm_models
 
1
  import os
2
+ #if os.environ.get("SPACES_ZERO_GPU") is not None:
3
+ #import spaces
4
  else:
5
  class spaces:
6
  @staticmethod
 
266
  #print(f"pixtral_processor: {type(pixtral_processor)}") #
267
 
268
 
269
+ #@spaces.GPU()
270
  @torch.inference_mode()
271
  def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_length: Union[str, int], extra_options: list[str], name_input: str, custom_prompt: str,
272
  max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> tuple[str, str]:
 
469
  else: return gr.update(value=files[0], choices=files)
470
 
471
 
472
+ #@spaces.GPU
473
  def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: Union[str, None]=None,
474
  is_nf4: bool=True, is_lora: bool=True, progress=gr.Progress(track_tqdm=True)):
475
  global use_inference_client, llm_models