ginipick commited on
Commit
5b664f6
·
verified ·
1 Parent(s): 6bb005a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -54
app.py CHANGED
@@ -777,6 +777,8 @@ TRANSLATORS = {
777
 
778
  translators_cache = {}
779
 
 
 
780
  def get_translator(lang):
781
  """단일 번역기를 초기화하고 반환하는 함수"""
782
  if lang == "English":
@@ -785,12 +787,16 @@ def get_translator(lang):
785
  if lang not in translators_cache:
786
  try:
787
  model_name = TRANSLATORS[lang]
788
- translator = pipeline(
789
- "translation",
790
- model=model_name,
791
- device="cpu" # CPU 고정
792
- )
793
- translators_cache[lang] = translator
 
 
 
 
794
  print(f"Successfully loaded translator for {lang}")
795
  except Exception as e:
796
  print(f"Error loading translator for {lang}: {e}")
@@ -803,64 +809,30 @@ def translate_prompt(prompt, source_lang):
803
  if source_lang == "English":
804
  return prompt
805
 
806
- translator = get_translator(source_lang)
807
- if translator is None:
808
  print(f"No translator available for {source_lang}, using original prompt")
809
  return prompt
810
 
811
- try:
812
- translation = translator(prompt, max_length=512)
813
- translated_text = translation[0]['translation_text']
814
- print(f"Original ({source_lang}): {prompt}")
815
- print(f"Translated: {translated_text}")
816
- return translated_text
817
- except Exception as e:
818
- print(f"Translation error for {source_lang}: {e}")
819
- return prompt
820
-
821
-
822
- def get_translator(lang):
823
- if lang == "English":
824
- return None
825
-
826
- if lang not in translators_cache:
827
- try:
828
- model_name = TRANSLATORS[lang]
829
- tokenizer = MarianTokenizer.from_pretrained(model_name)
830
- model = MarianMTModel.from_pretrained(model_name).to("cpu")
831
-
832
- translators_cache[lang] = pipeline(
833
- "translation",
834
- model=model,
835
- tokenizer=tokenizer,
836
- device="cpu"
837
- )
838
- except Exception as e:
839
- print(f"Error loading translator for {lang}: {e}")
840
- return None
841
-
842
- return translators_cache[lang]
843
-
844
- def translate_text(text, translator_info):
845
- if translator_info is None:
846
- return text
847
-
848
  try:
849
  tokenizer = translator_info["tokenizer"]
850
  model = translator_info["model"]
851
 
852
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
853
  translated = model.generate(**inputs)
854
- result = tokenizer.decode(translated[0], skip_special_tokens=True)
855
 
856
- print(f"Original text: {text}")
857
- print(f"Translated text: {result}")
858
 
859
- return result
 
 
860
  except Exception as e:
861
- print(f"Translation error: {e}")
862
- return text
863
-
864
 
865
  @spaces.GPU
866
  @torch.no_grad()
@@ -869,7 +841,6 @@ def generate_image(
869
  do_img2img, init_image, image2image_strength, resize_img,
870
  progress=gr.Progress(track_tqdm=True),
871
  ):
872
- # 번역 처리
873
  try:
874
  if source_lang != "English":
875
  translated_prompt = translate_prompt(prompt, source_lang)
@@ -880,6 +851,7 @@ def generate_image(
880
  print(f"Translation failed: {e}")
881
  translated_prompt = prompt
882
 
 
883
 
884
  if seed == 0:
885
  seed = int(random.random() * 1000000)
 
777
 
778
  translators_cache = {}
779
 
780
+ from transformers import MarianMTModel, MarianTokenizer
781
+
782
  def get_translator(lang):
783
  """단일 번역기를 초기화하고 반환하는 함수"""
784
  if lang == "English":
 
787
  if lang not in translators_cache:
788
  try:
789
  model_name = TRANSLATORS[lang]
790
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
791
+ model = MarianMTModel.from_pretrained(model_name)
792
+
793
+ # CPU에서 실행
794
+ model = model.to("cpu")
795
+
796
+ translators_cache[lang] = {
797
+ "model": model,
798
+ "tokenizer": tokenizer
799
+ }
800
  print(f"Successfully loaded translator for {lang}")
801
  except Exception as e:
802
  print(f"Error loading translator for {lang}: {e}")
 
809
  if source_lang == "English":
810
  return prompt
811
 
812
+ translator_info = get_translator(source_lang)
813
+ if translator_info is None:
814
  print(f"No translator available for {source_lang}, using original prompt")
815
  return prompt
816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
817
  try:
818
  tokenizer = translator_info["tokenizer"]
819
  model = translator_info["model"]
820
 
821
+ # 텍스트를 토큰화하고 모델 입력으로 변환
822
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
823
+
824
+ # 번역 수행
825
  translated = model.generate(**inputs)
 
826
 
827
+ # 번역된 텍스트 디코딩
828
+ translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
829
 
830
+ print(f"Original ({source_lang}): {prompt}")
831
+ print(f"Translated: {translated_text}")
832
+ return translated_text
833
  except Exception as e:
834
+ print(f"Translation error for {source_lang}: {e}")
835
+ return prompt
 
836
 
837
  @spaces.GPU
838
  @torch.no_grad()
 
841
  do_img2img, init_image, image2image_strength, resize_img,
842
  progress=gr.Progress(track_tqdm=True),
843
  ):
 
844
  try:
845
  if source_lang != "English":
846
  translated_prompt = translate_prompt(prompt, source_lang)
 
851
  print(f"Translation failed: {e}")
852
  translated_prompt = prompt
853
 
854
+
855
 
856
  if seed == 0:
857
  seed = int(random.random() * 1000000)