Spaces: Running on Zero
Update app.py
Browse files
app.py CHANGED
@@ -20,8 +20,6 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|
20 |
from transformers import T5EncoderModel, T5Tokenizer
|
21 |
from transformers import pipeline, AutoTokenizer, MarianMTModel
|
22 |
|
23 |
-
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
|
24 |
-
# ---------------- Encoders ----------------
|
25 |
|
26 |
|
27 |
class HFEmbedder(nn.Module):
|
@@ -793,38 +791,62 @@ TRANSLATORS = {
|
|
793 |
# 번역기 캐시 딕셔너리
|
794 |
translators_cache = {}
|
795 |
|
|
|
796 |
def get_translator(lang):
|
|
|
|
|
|
|
797 |
if lang not in translators_cache:
|
798 |
-
model_name = TRANSLATORS[lang]
|
799 |
try:
|
|
|
800 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
801 |
model = MarianMTModel.from_pretrained(
|
802 |
model_name,
|
803 |
-
torch_dtype=torch.float16
|
804 |
low_cpu_mem_usage=True
|
805 |
-
)
|
806 |
-
|
807 |
-
|
808 |
-
model
|
809 |
-
tokenizer
|
810 |
-
|
811 |
-
)
|
812 |
except Exception as e:
|
813 |
print(f"Error loading translator for {lang}: {e}")
|
814 |
return None
|
|
|
815 |
return translators_cache[lang]
|
816 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
817 |
def translate_prompt(prompt, source_lang):
|
818 |
if source_lang == "English":
|
819 |
return prompt
|
820 |
|
821 |
translator = get_translator(source_lang)
|
822 |
if translator is None:
|
823 |
-
print(f"
|
824 |
return prompt
|
825 |
|
826 |
try:
|
827 |
-
translated =
|
828 |
print(f"Translated from {source_lang}: {translated}")
|
829 |
return translated
|
830 |
except Exception as e:
|
@@ -840,6 +862,7 @@ def generate_image(
|
|
840 |
):
|
841 |
try:
|
842 |
translated_prompt = translate_prompt(prompt, source_lang)
|
|
|
843 |
except Exception as e:
|
844 |
print(f"Translation failed: {e}")
|
845 |
translated_prompt = prompt
|
|
|
20 |
from transformers import T5EncoderModel, T5Tokenizer
|
21 |
from transformers import pipeline, AutoTokenizer, MarianMTModel
|
22 |
|
|
|
|
|
23 |
|
24 |
|
25 |
class HFEmbedder(nn.Module):
|
|
|
791 |
# 번역기 캐시 딕셔너리
|
792 |
translators_cache = {}
|
793 |
|
794 |
+
# 번역기 초기화 부분 수정
|
795 |
def get_translator(lang):
|
796 |
+
if lang == "English":
|
797 |
+
return None
|
798 |
+
|
799 |
if lang not in translators_cache:
|
|
|
800 |
try:
|
801 |
+
model_name = TRANSLATORS[lang]
|
802 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
803 |
model = MarianMTModel.from_pretrained(
|
804 |
model_name,
|
805 |
+
torch_dtype=torch.float32, # float16 대신 float32 사용
|
806 |
low_cpu_mem_usage=True
|
807 |
+
).to("cpu") # 명시적으로 CPU 지정
|
808 |
+
|
809 |
+
translators_cache[lang] = {
|
810 |
+
"model": model,
|
811 |
+
"tokenizer": tokenizer
|
812 |
+
}
|
|
|
813 |
except Exception as e:
|
814 |
print(f"Error loading translator for {lang}: {e}")
|
815 |
return None
|
816 |
+
|
817 |
return translators_cache[lang]
|
818 |
|
819 |
+
def translate_text(text, translator_info):
|
820 |
+
if translator_info is None:
|
821 |
+
return text
|
822 |
+
|
823 |
+
try:
|
824 |
+
tokenizer = translator_info["tokenizer"]
|
825 |
+
model = translator_info["model"]
|
826 |
+
|
827 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
828 |
+
translated = model.generate(**inputs)
|
829 |
+
result = tokenizer.decode(translated[0], skip_special_tokens=True)
|
830 |
+
|
831 |
+
print(f"Original text: {text}")
|
832 |
+
print(f"Translated text: {result}")
|
833 |
+
|
834 |
+
return result
|
835 |
+
except Exception as e:
|
836 |
+
print(f"Translation error: {e}")
|
837 |
+
return text
|
838 |
+
|
839 |
def translate_prompt(prompt, source_lang):
|
840 |
if source_lang == "English":
|
841 |
return prompt
|
842 |
|
843 |
translator = get_translator(source_lang)
|
844 |
if translator is None:
|
845 |
+
print(f"No translator available for {source_lang}, using original prompt")
|
846 |
return prompt
|
847 |
|
848 |
try:
|
849 |
+
translated = translate_text(prompt, translator)
|
850 |
print(f"Translated from {source_lang}: {translated}")
|
851 |
return translated
|
852 |
except Exception as e:
|
|
|
862 |
):
|
863 |
try:
|
864 |
translated_prompt = translate_prompt(prompt, source_lang)
|
865 |
+
print(f"Using prompt: {translated_prompt}")
|
866 |
except Exception as e:
|
867 |
print(f"Translation failed: {e}")
|
868 |
translated_prompt = prompt
|