ginipick commited on
Commit
9fd5050
·
verified ·
1 Parent(s): f2c840c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -13
app.py CHANGED
@@ -20,8 +20,6 @@ from transformers import CLIPTextModel, CLIPTokenizer
20
  from transformers import T5EncoderModel, T5Tokenizer
21
  from transformers import pipeline, AutoTokenizer, MarianMTModel
22
 
23
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
24
- # ---------------- Encoders ----------------
25
 
26
 
27
  class HFEmbedder(nn.Module):
@@ -793,38 +791,62 @@ TRANSLATORS = {
793
  # 번역기 캐시 딕셔너리
794
  translators_cache = {}
795
 
 
796
  def get_translator(lang):
 
 
 
797
  if lang not in translators_cache:
798
- model_name = TRANSLATORS[lang]
799
  try:
 
800
  tokenizer = AutoTokenizer.from_pretrained(model_name)
801
  model = MarianMTModel.from_pretrained(
802
  model_name,
803
- torch_dtype=torch.float16,
804
  low_cpu_mem_usage=True
805
- )
806
- translators_cache[lang] = pipeline(
807
- "translation",
808
- model=model,
809
- tokenizer=tokenizer,
810
- device=-1 # CPU 사용
811
- )
812
  except Exception as e:
813
  print(f"Error loading translator for {lang}: {e}")
814
  return None
 
815
  return translators_cache[lang]
816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
817
  def translate_prompt(prompt, source_lang):
818
  if source_lang == "English":
819
  return prompt
820
 
821
  translator = get_translator(source_lang)
822
  if translator is None:
823
- print(f"Translation failed for {source_lang}, using original prompt")
824
  return prompt
825
 
826
  try:
827
- translated = translator(prompt, max_length=512)[0]['translation_text']
828
  print(f"Translated from {source_lang}: {translated}")
829
  return translated
830
  except Exception as e:
@@ -840,6 +862,7 @@ def generate_image(
840
  ):
841
  try:
842
  translated_prompt = translate_prompt(prompt, source_lang)
 
843
  except Exception as e:
844
  print(f"Translation failed: {e}")
845
  translated_prompt = prompt
 
20
  from transformers import T5EncoderModel, T5Tokenizer
21
  from transformers import pipeline, AutoTokenizer, MarianMTModel
22
 
 
 
23
 
24
 
25
  class HFEmbedder(nn.Module):
 
791
  # 번역기 캐시 딕셔너리
792
  translators_cache = {}
793
 
794
+ # 번역기 초기화 부분 수정
795
  def get_translator(lang):
796
+ if lang == "English":
797
+ return None
798
+
799
  if lang not in translators_cache:
 
800
  try:
801
+ model_name = TRANSLATORS[lang]
802
  tokenizer = AutoTokenizer.from_pretrained(model_name)
803
  model = MarianMTModel.from_pretrained(
804
  model_name,
805
+ torch_dtype=torch.float32, # float16 대신 float32 사용
806
  low_cpu_mem_usage=True
807
+ ).to("cpu") # 명시적으로 CPU 지정
808
+
809
+ translators_cache[lang] = {
810
+ "model": model,
811
+ "tokenizer": tokenizer
812
+ }
 
813
  except Exception as e:
814
  print(f"Error loading translator for {lang}: {e}")
815
  return None
816
+
817
  return translators_cache[lang]
818
 
819
+ def translate_text(text, translator_info):
820
+ if translator_info is None:
821
+ return text
822
+
823
+ try:
824
+ tokenizer = translator_info["tokenizer"]
825
+ model = translator_info["model"]
826
+
827
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
828
+ translated = model.generate(**inputs)
829
+ result = tokenizer.decode(translated[0], skip_special_tokens=True)
830
+
831
+ print(f"Original text: {text}")
832
+ print(f"Translated text: {result}")
833
+
834
+ return result
835
+ except Exception as e:
836
+ print(f"Translation error: {e}")
837
+ return text
838
+
839
  def translate_prompt(prompt, source_lang):
840
  if source_lang == "English":
841
  return prompt
842
 
843
  translator = get_translator(source_lang)
844
  if translator is None:
845
+ print(f"No translator available for {source_lang}, using original prompt")
846
  return prompt
847
 
848
  try:
849
+ translated = translate_text(prompt, translator)
850
  print(f"Translated from {source_lang}: {translated}")
851
  return translated
852
  except Exception as e:
 
862
  ):
863
  try:
864
  translated_prompt = translate_prompt(prompt, source_lang)
865
+ print(f"Using prompt: {translated_prompt}")
866
  except Exception as e:
867
  print(f"Translation failed: {e}")
868
  translated_prompt = prompt