Spaces:
Sleeping
Sleeping
idan shenfeld
commited on
Commit
·
549219e
1
Parent(s):
7ca0870
code cleanup
Browse files- app/app.py +60 -73
app/app.py
CHANGED
@@ -33,16 +33,17 @@ TEXT_ONLY = (
|
|
33 |
|
34 |
def create_inference_client(
|
35 |
model: Optional[str] = None, base_url: Optional[str] = None
|
36 |
-
) -> InferenceClient:
|
37 |
"""Create an InferenceClient instance with the given model or environment settings.
|
38 |
This function will run the model locally if ZERO_GPU is set to True.
|
39 |
This function will run the model locally if ZERO_GPU is set to True.
|
40 |
|
41 |
Args:
|
42 |
model: Optional model identifier to use. If not provided, will use environment settings.
|
|
|
43 |
|
44 |
Returns:
|
45 |
-
InferenceClient
|
46 |
"""
|
47 |
if ZERO_GPU:
|
48 |
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
@@ -67,11 +68,17 @@ def create_inference_client(
|
|
67 |
CLIENT = create_inference_client()
|
68 |
|
69 |
|
70 |
-
def
|
71 |
-
"""
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# Check if persistent storage is available and writable
|
77 |
use_persistent = False
|
@@ -86,35 +93,44 @@ def load_languages() -> dict[str, str]:
|
|
86 |
print("Persistent storage exists but is not writable, falling back to local storage")
|
87 |
use_persistent = False
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
languages_path = local_path
|
94 |
-
|
95 |
-
# If persistent storage is available and writable but file doesn't exist yet,
|
96 |
-
# copy the local file to persistent storage
|
97 |
-
if use_persistent:
|
98 |
-
try:
|
99 |
-
# Ensure local file exists first
|
100 |
-
if local_path.exists():
|
101 |
-
import shutil
|
102 |
-
# Copy the file to persistent storage
|
103 |
-
shutil.copy(local_path, persistent_path)
|
104 |
-
languages_path = persistent_path
|
105 |
-
print(f"Copied languages to persistent storage at {persistent_path}")
|
106 |
-
else:
|
107 |
-
# Create an empty languages file in persistent storage
|
108 |
-
with open(persistent_path, "w", encoding="utf-8") as f:
|
109 |
-
json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
|
110 |
-
languages_path = persistent_path
|
111 |
-
print(f"Created new languages file in persistent storage at {persistent_path}")
|
112 |
-
except Exception as e:
|
113 |
-
print(f"Error setting up persistent storage: {e}")
|
114 |
-
languages_path = local_path # Fall back to local path if any error occurs
|
115 |
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
|
120 |
# Initial load
|
@@ -257,6 +273,7 @@ def add_fake_like_data(
|
|
257 |
|
258 |
@spaces.GPU
|
259 |
def call_pipeline(messages: list, language: str):
|
|
|
260 |
if ZERO_GPU:
|
261 |
# Format the messages using the tokenizer's chat template
|
262 |
tokenizer = CLIENT["tokenizer"]
|
@@ -274,16 +291,14 @@ def call_pipeline(messages: list, language: str):
|
|
274 |
)
|
275 |
|
276 |
# Extract the generated content
|
277 |
-
|
278 |
-
return content
|
279 |
else:
|
280 |
response = CLIENT(
|
281 |
messages,
|
282 |
clean_up_tokenization_spaces=False,
|
283 |
max_length=2000,
|
284 |
)
|
285 |
-
|
286 |
-
return content
|
287 |
|
288 |
|
289 |
def respond(
|
@@ -291,11 +306,12 @@ def respond(
|
|
291 |
language: str,
|
292 |
temperature: Optional[float] = None,
|
293 |
seed: Optional[int] = None,
|
294 |
-
) -> list:
|
295 |
"""Respond to the user message with a system message
|
296 |
|
297 |
Return the history with the new message"""
|
298 |
messages = format_history_as_messages(history)
|
|
|
299 |
if ZERO_GPU:
|
300 |
content = call_pipeline(messages, language)
|
301 |
else:
|
@@ -307,17 +323,7 @@ def respond(
|
|
307 |
temperature=temperature,
|
308 |
)
|
309 |
content = response.choices[0].message.content
|
310 |
-
|
311 |
-
content = call_pipeline(messages, language)
|
312 |
-
else:
|
313 |
-
response = CLIENT.chat.completions.create(
|
314 |
-
messages=messages,
|
315 |
-
max_tokens=2000,
|
316 |
-
stream=False,
|
317 |
-
seed=seed,
|
318 |
-
temperature=temperature,
|
319 |
-
)
|
320 |
-
content = response.choices[0].message.content
|
321 |
message = gr.ChatMessage(role="assistant", content=content)
|
322 |
history.append(message)
|
323 |
return history
|
@@ -510,26 +516,10 @@ def save_new_language(lang_name, system_prompt):
|
|
510 |
"""Save the new language and system prompt to persistent storage if available, otherwise to local file."""
|
511 |
global LANGUAGES # Access the global variable
|
512 |
|
513 |
-
#
|
514 |
-
|
515 |
local_path = Path(__file__).parent / "languages.json"
|
516 |
|
517 |
-
# Check if persistent storage is available and writable
|
518 |
-
use_persistent = False
|
519 |
-
if Path("/data").exists() and Path("/data").is_dir():
|
520 |
-
try:
|
521 |
-
# Test if we can write to the directory
|
522 |
-
test_file = Path("/data/write_test.tmp")
|
523 |
-
test_file.touch()
|
524 |
-
test_file.unlink() # Remove the test file
|
525 |
-
use_persistent = True
|
526 |
-
except (PermissionError, OSError):
|
527 |
-
print("Persistent storage exists but is not writable, falling back to local storage")
|
528 |
-
use_persistent = False
|
529 |
-
|
530 |
-
# Use persistent storage if available and writable, otherwise fall back to local file
|
531 |
-
languages_path = persistent_path if use_persistent else local_path
|
532 |
-
|
533 |
# Load existing languages
|
534 |
if languages_path.exists():
|
535 |
with open(languages_path, "r", encoding="utf-8") as f:
|
@@ -545,7 +535,7 @@ def save_new_language(lang_name, system_prompt):
|
|
545 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
546 |
|
547 |
# If we're using persistent storage, also update the local file as backup
|
548 |
-
if use_persistent and local_path !=
|
549 |
try:
|
550 |
with open(local_path, "w", encoding="utf-8") as f:
|
551 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
@@ -555,11 +545,8 @@ def save_new_language(lang_name, system_prompt):
|
|
555 |
# Update the global LANGUAGES variable with the new data
|
556 |
LANGUAGES.update({lang_name: system_prompt})
|
557 |
|
558 |
-
# Update the dropdown choices
|
559 |
-
new_choices = list(LANGUAGES.keys())
|
560 |
-
|
561 |
# Return a message that will trigger a JavaScript refresh
|
562 |
-
return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=
|
563 |
|
564 |
|
565 |
css = """
|
|
|
33 |
|
34 |
def create_inference_client(
|
35 |
model: Optional[str] = None, base_url: Optional[str] = None
|
36 |
+
) -> InferenceClient | dict:
|
37 |
"""Create an InferenceClient instance with the given model or environment settings.
|
38 |
This function will run the model locally if ZERO_GPU is set to True.
|
39 |
This function will run the model locally if ZERO_GPU is set to True.
|
40 |
|
41 |
Args:
|
42 |
model: Optional model identifier to use. If not provided, will use environment settings.
|
43 |
+
base_url: Optional base URL for the inference API.
|
44 |
|
45 |
Returns:
|
46 |
+
Either an InferenceClient instance or a dictionary with pipeline and tokenizer
|
47 |
"""
|
48 |
if ZERO_GPU:
|
49 |
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
|
|
68 |
CLIENT = create_inference_client()
|
69 |
|
70 |
|
71 |
+
def get_persistent_storage_path(filename: str) -> tuple[Path, bool]:
|
72 |
+
"""Check if persistent storage is available and return the appropriate path.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
filename: The name of the file to check/create
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
A tuple containing (file_path, is_persistent)
|
79 |
+
"""
|
80 |
+
persistent_path = Path("/data") / filename
|
81 |
+
local_path = Path(__file__).parent / filename
|
82 |
|
83 |
# Check if persistent storage is available and writable
|
84 |
use_persistent = False
|
|
|
93 |
print("Persistent storage exists but is not writable, falling back to local storage")
|
94 |
use_persistent = False
|
95 |
|
96 |
+
return (persistent_path if use_persistent else local_path, use_persistent)
|
97 |
+
|
98 |
+
|
99 |
+
def load_languages() -> dict[str, str]:
|
100 |
+
"""Load languages from JSON file or persistent storage"""
|
101 |
+
languages_path, use_persistent = get_persistent_storage_path("languages.json")
|
102 |
+
local_path = Path(__file__).parent / "languages.json"
|
103 |
+
|
104 |
+
# If persistent storage is available but file doesn't exist yet,
|
105 |
+
# copy the local file to persistent storage
|
106 |
+
if use_persistent and not languages_path.exists():
|
107 |
+
try:
|
108 |
+
if local_path.exists():
|
109 |
+
import shutil
|
110 |
+
# Copy the file to persistent storage
|
111 |
+
shutil.copy(local_path, languages_path)
|
112 |
+
print(f"Copied languages to persistent storage at {languages_path}")
|
113 |
+
else:
|
114 |
+
# Create an empty languages file in persistent storage
|
115 |
+
with open(languages_path, "w", encoding="utf-8") as f:
|
116 |
+
json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
|
117 |
+
print(f"Created new languages file in persistent storage at {languages_path}")
|
118 |
+
except Exception as e:
|
119 |
+
print(f"Error setting up persistent storage: {e}")
|
120 |
+
languages_path = local_path # Fall back to local path if any error occurs
|
121 |
+
|
122 |
+
# If the file doesn't exist at the chosen path but exists at the local path, use local
|
123 |
+
if not languages_path.exists() and local_path.exists():
|
124 |
languages_path = local_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
|
126 |
+
# If the file exists, load it
|
127 |
+
if languages_path.exists():
|
128 |
+
with open(languages_path, "r", encoding="utf-8") as f:
|
129 |
+
return json.load(f)
|
130 |
+
else:
|
131 |
+
# Return a default if no file exists
|
132 |
+
default_languages = {"English": "You are a helpful assistant."}
|
133 |
+
return default_languages
|
134 |
|
135 |
|
136 |
# Initial load
|
|
|
273 |
|
274 |
@spaces.GPU
|
275 |
def call_pipeline(messages: list, language: str):
|
276 |
+
"""Call the appropriate model pipeline based on configuration"""
|
277 |
if ZERO_GPU:
|
278 |
# Format the messages using the tokenizer's chat template
|
279 |
tokenizer = CLIENT["tokenizer"]
|
|
|
291 |
)
|
292 |
|
293 |
# Extract the generated content
|
294 |
+
return response[0]["generated_text"]
|
|
|
295 |
else:
|
296 |
response = CLIENT(
|
297 |
messages,
|
298 |
clean_up_tokenization_spaces=False,
|
299 |
max_length=2000,
|
300 |
)
|
301 |
+
return response[0]["generated_text"][-1]["content"]
|
|
|
302 |
|
303 |
|
304 |
def respond(
|
|
|
306 |
language: str,
|
307 |
temperature: Optional[float] = None,
|
308 |
seed: Optional[int] = None,
|
309 |
+
) -> list:
|
310 |
"""Respond to the user message with a system message
|
311 |
|
312 |
Return the history with the new message"""
|
313 |
messages = format_history_as_messages(history)
|
314 |
+
|
315 |
if ZERO_GPU:
|
316 |
content = call_pipeline(messages, language)
|
317 |
else:
|
|
|
323 |
temperature=temperature,
|
324 |
)
|
325 |
content = response.choices[0].message.content
|
326 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
327 |
message = gr.ChatMessage(role="assistant", content=content)
|
328 |
history.append(message)
|
329 |
return history
|
|
|
516 |
"""Save the new language and system prompt to persistent storage if available, otherwise to local file."""
|
517 |
global LANGUAGES # Access the global variable
|
518 |
|
519 |
+
# Get the appropriate path
|
520 |
+
languages_path, use_persistent = get_persistent_storage_path("languages.json")
|
521 |
local_path = Path(__file__).parent / "languages.json"
|
522 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
# Load existing languages
|
524 |
if languages_path.exists():
|
525 |
with open(languages_path, "r", encoding="utf-8") as f:
|
|
|
535 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
536 |
|
537 |
# If we're using persistent storage, also update the local file as backup
|
538 |
+
if use_persistent and local_path != languages_path:
|
539 |
try:
|
540 |
with open(local_path, "w", encoding="utf-8") as f:
|
541 |
json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|
545 |
# Update the global LANGUAGES variable with the new data
|
546 |
LANGUAGES.update({lang_name: system_prompt})
|
547 |
|
|
|
|
|
|
|
548 |
# Return a message that will trigger a JavaScript refresh
|
549 |
+
return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=list(LANGUAGES.keys()))
|
550 |
|
551 |
|
552 |
css = """
|