idan shenfeld committed on
Commit
549219e
·
1 Parent(s): 7ca0870

code cleanup

Browse files
Files changed (1) hide show
  1. app/app.py +60 -73
app/app.py CHANGED
@@ -33,16 +33,17 @@ TEXT_ONLY = (
33
 
34
  def create_inference_client(
35
  model: Optional[str] = None, base_url: Optional[str] = None
36
- ) -> InferenceClient:
37
  """Create an InferenceClient instance with the given model or environment settings.
38
  This function will run the model locally if ZERO_GPU is set to True.
39
  This function will run the model locally if ZERO_GPU is set to True.
40
 
41
  Args:
42
  model: Optional model identifier to use. If not provided, will use environment settings.
 
43
 
44
  Returns:
45
- InferenceClient: Configured client instance
46
  """
47
  if ZERO_GPU:
48
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
@@ -67,11 +68,17 @@ def create_inference_client(
67
  CLIENT = create_inference_client()
68
 
69
 
70
- def load_languages() -> dict[str, str]:
71
- """Load languages from JSON file or persistent storage"""
72
- # First check if we have persistent storage available
73
- persistent_path = Path("/data/languages.json")
74
- local_path = Path(__file__).parent / "languages.json"
 
 
 
 
 
 
75
 
76
  # Check if persistent storage is available and writable
77
  use_persistent = False
@@ -86,35 +93,44 @@ def load_languages() -> dict[str, str]:
86
  print("Persistent storage exists but is not writable, falling back to local storage")
87
  use_persistent = False
88
 
89
- # Use persistent storage if available and writable, otherwise fall back to local file
90
- if use_persistent and persistent_path.exists():
91
- languages_path = persistent_path
92
- else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  languages_path = local_path
94
-
95
- # If persistent storage is available and writable but file doesn't exist yet,
96
- # copy the local file to persistent storage
97
- if use_persistent:
98
- try:
99
- # Ensure local file exists first
100
- if local_path.exists():
101
- import shutil
102
- # Copy the file to persistent storage
103
- shutil.copy(local_path, persistent_path)
104
- languages_path = persistent_path
105
- print(f"Copied languages to persistent storage at {persistent_path}")
106
- else:
107
- # Create an empty languages file in persistent storage
108
- with open(persistent_path, "w", encoding="utf-8") as f:
109
- json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
110
- languages_path = persistent_path
111
- print(f"Created new languages file in persistent storage at {persistent_path}")
112
- except Exception as e:
113
- print(f"Error setting up persistent storage: {e}")
114
- languages_path = local_path # Fall back to local path if any error occurs
115
 
116
- with open(languages_path, "r", encoding="utf-8") as f:
117
- return json.load(f)
 
 
 
 
 
 
118
 
119
 
120
  # Initial load
@@ -257,6 +273,7 @@ def add_fake_like_data(
257
 
258
  @spaces.GPU
259
  def call_pipeline(messages: list, language: str):
 
260
  if ZERO_GPU:
261
  # Format the messages using the tokenizer's chat template
262
  tokenizer = CLIENT["tokenizer"]
@@ -274,16 +291,14 @@ def call_pipeline(messages: list, language: str):
274
  )
275
 
276
  # Extract the generated content
277
- content = response[0]["generated_text"]
278
- return content
279
  else:
280
  response = CLIENT(
281
  messages,
282
  clean_up_tokenization_spaces=False,
283
  max_length=2000,
284
  )
285
- content = response[0]["generated_text"][-1]["content"]
286
- return content
287
 
288
 
289
  def respond(
@@ -291,11 +306,12 @@ def respond(
291
  language: str,
292
  temperature: Optional[float] = None,
293
  seed: Optional[int] = None,
294
- ) -> list: # -> list:
295
  """Respond to the user message with a system message
296
 
297
  Return the history with the new message"""
298
  messages = format_history_as_messages(history)
 
299
  if ZERO_GPU:
300
  content = call_pipeline(messages, language)
301
  else:
@@ -307,17 +323,7 @@ def respond(
307
  temperature=temperature,
308
  )
309
  content = response.choices[0].message.content
310
- if ZERO_GPU:
311
- content = call_pipeline(messages, language)
312
- else:
313
- response = CLIENT.chat.completions.create(
314
- messages=messages,
315
- max_tokens=2000,
316
- stream=False,
317
- seed=seed,
318
- temperature=temperature,
319
- )
320
- content = response.choices[0].message.content
321
  message = gr.ChatMessage(role="assistant", content=content)
322
  history.append(message)
323
  return history
@@ -510,26 +516,10 @@ def save_new_language(lang_name, system_prompt):
510
  """Save the new language and system prompt to persistent storage if available, otherwise to local file."""
511
  global LANGUAGES # Access the global variable
512
 
513
- # First determine where to save the file
514
- persistent_path = Path("/data/languages.json")
515
  local_path = Path(__file__).parent / "languages.json"
516
 
517
- # Check if persistent storage is available and writable
518
- use_persistent = False
519
- if Path("/data").exists() and Path("/data").is_dir():
520
- try:
521
- # Test if we can write to the directory
522
- test_file = Path("/data/write_test.tmp")
523
- test_file.touch()
524
- test_file.unlink() # Remove the test file
525
- use_persistent = True
526
- except (PermissionError, OSError):
527
- print("Persistent storage exists but is not writable, falling back to local storage")
528
- use_persistent = False
529
-
530
- # Use persistent storage if available and writable, otherwise fall back to local file
531
- languages_path = persistent_path if use_persistent else local_path
532
-
533
  # Load existing languages
534
  if languages_path.exists():
535
  with open(languages_path, "r", encoding="utf-8") as f:
@@ -545,7 +535,7 @@ def save_new_language(lang_name, system_prompt):
545
  json.dump(data, f, ensure_ascii=False, indent=2)
546
 
547
  # If we're using persistent storage, also update the local file as backup
548
- if use_persistent and local_path != persistent_path:
549
  try:
550
  with open(local_path, "w", encoding="utf-8") as f:
551
  json.dump(data, f, ensure_ascii=False, indent=2)
@@ -555,11 +545,8 @@ def save_new_language(lang_name, system_prompt):
555
  # Update the global LANGUAGES variable with the new data
556
  LANGUAGES.update({lang_name: system_prompt})
557
 
558
- # Update the dropdown choices
559
- new_choices = list(LANGUAGES.keys())
560
-
561
  # Return a message that will trigger a JavaScript refresh
562
- return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=new_choices)
563
 
564
 
565
  css = """
 
33
 
34
  def create_inference_client(
35
  model: Optional[str] = None, base_url: Optional[str] = None
36
+ ) -> InferenceClient | dict:
37
  """Create an InferenceClient instance with the given model or environment settings.
38
  This function will run the model locally if ZERO_GPU is set to True.
39
  This function will run the model locally if ZERO_GPU is set to True.
40
 
41
  Args:
42
  model: Optional model identifier to use. If not provided, will use environment settings.
43
+ base_url: Optional base URL for the inference API.
44
 
45
  Returns:
46
+ Either an InferenceClient instance or a dictionary with pipeline and tokenizer
47
  """
48
  if ZERO_GPU:
49
  tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
 
68
  CLIENT = create_inference_client()
69
 
70
 
71
+ def get_persistent_storage_path(filename: str) -> tuple[Path, bool]:
72
+ """Check if persistent storage is available and return the appropriate path.
73
+
74
+ Args:
75
+ filename: The name of the file to check/create
76
+
77
+ Returns:
78
+ A tuple containing (file_path, is_persistent)
79
+ """
80
+ persistent_path = Path("/data") / filename
81
+ local_path = Path(__file__).parent / filename
82
 
83
  # Check if persistent storage is available and writable
84
  use_persistent = False
 
93
  print("Persistent storage exists but is not writable, falling back to local storage")
94
  use_persistent = False
95
 
96
+ return (persistent_path if use_persistent else local_path, use_persistent)
97
+
98
+
99
+ def load_languages() -> dict[str, str]:
100
+ """Load languages from JSON file or persistent storage"""
101
+ languages_path, use_persistent = get_persistent_storage_path("languages.json")
102
+ local_path = Path(__file__).parent / "languages.json"
103
+
104
+ # If persistent storage is available but file doesn't exist yet,
105
+ # copy the local file to persistent storage
106
+ if use_persistent and not languages_path.exists():
107
+ try:
108
+ if local_path.exists():
109
+ import shutil
110
+ # Copy the file to persistent storage
111
+ shutil.copy(local_path, languages_path)
112
+ print(f"Copied languages to persistent storage at {languages_path}")
113
+ else:
114
+ # Create an empty languages file in persistent storage
115
+ with open(languages_path, "w", encoding="utf-8") as f:
116
+ json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
117
+ print(f"Created new languages file in persistent storage at {languages_path}")
118
+ except Exception as e:
119
+ print(f"Error setting up persistent storage: {e}")
120
+ languages_path = local_path # Fall back to local path if any error occurs
121
+
122
+ # If the file doesn't exist at the chosen path but exists at the local path, use local
123
+ if not languages_path.exists() and local_path.exists():
124
  languages_path = local_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # If the file exists, load it
127
+ if languages_path.exists():
128
+ with open(languages_path, "r", encoding="utf-8") as f:
129
+ return json.load(f)
130
+ else:
131
+ # Return a default if no file exists
132
+ default_languages = {"English": "You are a helpful assistant."}
133
+ return default_languages
134
 
135
 
136
  # Initial load
 
273
 
274
  @spaces.GPU
275
  def call_pipeline(messages: list, language: str):
276
+ """Call the appropriate model pipeline based on configuration"""
277
  if ZERO_GPU:
278
  # Format the messages using the tokenizer's chat template
279
  tokenizer = CLIENT["tokenizer"]
 
291
  )
292
 
293
  # Extract the generated content
294
+ return response[0]["generated_text"]
 
295
  else:
296
  response = CLIENT(
297
  messages,
298
  clean_up_tokenization_spaces=False,
299
  max_length=2000,
300
  )
301
+ return response[0]["generated_text"][-1]["content"]
 
302
 
303
 
304
  def respond(
 
306
  language: str,
307
  temperature: Optional[float] = None,
308
  seed: Optional[int] = None,
309
+ ) -> list:
310
  """Respond to the user message with a system message
311
 
312
  Return the history with the new message"""
313
  messages = format_history_as_messages(history)
314
+
315
  if ZERO_GPU:
316
  content = call_pipeline(messages, language)
317
  else:
 
323
  temperature=temperature,
324
  )
325
  content = response.choices[0].message.content
326
+
 
 
 
 
 
 
 
 
 
 
327
  message = gr.ChatMessage(role="assistant", content=content)
328
  history.append(message)
329
  return history
 
516
  """Save the new language and system prompt to persistent storage if available, otherwise to local file."""
517
  global LANGUAGES # Access the global variable
518
 
519
+ # Get the appropriate path
520
+ languages_path, use_persistent = get_persistent_storage_path("languages.json")
521
  local_path = Path(__file__).parent / "languages.json"
522
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  # Load existing languages
524
  if languages_path.exists():
525
  with open(languages_path, "r", encoding="utf-8") as f:
 
535
  json.dump(data, f, ensure_ascii=False, indent=2)
536
 
537
  # If we're using persistent storage, also update the local file as backup
538
+ if use_persistent and local_path != languages_path:
539
  try:
540
  with open(local_path, "w", encoding="utf-8") as f:
541
  json.dump(data, f, ensure_ascii=False, indent=2)
 
545
  # Update the global LANGUAGES variable with the new data
546
  LANGUAGES.update({lang_name: system_prompt})
547
 
 
 
 
548
  # Return a message that will trigger a JavaScript refresh
549
+ return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=list(LANGUAGES.keys()))
550
 
551
 
552
  css = """