Update joycaption.py
joycaption.py (+4 -4)
@@ -1,6 +1,6 @@
 import os
-if os.environ.get("SPACES_ZERO_GPU") is not None:
-    import spaces
+#if os.environ.get("SPACES_ZERO_GPU") is not None:
+#import spaces
 else:
     class spaces:
         @staticmethod
@@ -266,7 +266,7 @@ load_text_model(MODEL_PATH, None, LOAD_IN_NF4, True)
 #print(f"pixtral_processor: {type(pixtral_processor)}") #
 
 
-@spaces.GPU()
+#@spaces.GPU()
 @torch.inference_mode()
 def stream_chat_mod(input_image: Image.Image, caption_type: str, caption_length: Union[str, int], extra_options: list[str], name_input: str, custom_prompt: str,
                     max_new_tokens: int=300, top_p: float=0.9, temperature: float=0.6, model_name: str=MODEL_PATH, progress=gr.Progress(track_tqdm=True)) -> tuple[str, str]:
@@ -469,7 +469,7 @@ def get_repo_gguf(repo_id: str):
     else: return gr.update(value=files[0], choices=files)
 
 
-@spaces.GPU
+#@spaces.GPU
 def change_text_model(model_name: str=MODEL_PATH, use_client: bool=False, gguf_file: Union[str, None]=None,
                       is_nf4: bool=True, is_lora: bool=True, progress=gr.Progress(track_tqdm=True)):
     global use_inference_client, llm_models
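For context, the first hunk toggles the usual ZeroGPU guard: on a ZeroGPU Space the real spaces package is imported, while everywhere else a stub class stands in for it so the @spaces.GPU decorator still resolves. The stub's body is cut off in this diff after @staticmethod, so the following is only a minimal sketch of the common fallback pattern; the no-op GPU decorator body is an assumption, not code from this commit:

import os

if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces  # real Hugging Face ZeroGPU helper package
else:
    class spaces:
        # Assumed no-op stand-in (the real stub body is truncated in the diff).
        # Supports both the bare @spaces.GPU and the called @spaces.GPU(...) forms.
        @staticmethod
        def GPU(*args, **kwargs):
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]        # bare form: @spaces.GPU
            return lambda func: func  # called form: @spaces.GPU() / @spaces.GPU(duration=...)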
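The second and third hunks comment out the @spaces.GPU decorators on stream_chat_mod and change_text_model, so those functions no longer request a ZeroGPU slot when invoked; together with the disabled import above, this effectively switches off ZeroGPU handling for the app. For reference, this is roughly how the decorator is normally applied on a ZeroGPU Space (generate_caption and the duration value are illustrative, not taken from this file):

import spaces

@spaces.GPU(duration=60)  # request a GPU for up to ~60 s per call (illustrative value)
def generate_caption(prompt: str) -> str:
    # hypothetical body standing in for the real inference code
    return prompt.upper()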