Update app.py

app.py CHANGED
@@ -19,6 +19,7 @@ from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 from transformers import MarianMTModel, MarianTokenizer, pipeline
+from huggingface_hub import snapshot_download
 
 class HFEmbedder(nn.Module):
     def __init__(self, version: str, max_length: int, **hf_kwargs):
@@ -777,20 +778,41 @@ TRANSLATORS = {
 
 translators_cache = {}
 
-
+
+
+# Set up the model cache directory
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
+
+def download_model(model_name):
+    """Download the model in advance"""
+    try:
+        cache_dir = os.path.join('/tmp/transformers_cache', model_name.split('/')[-1])
+        snapshot_download(
+            repo_id=model_name,
+            cache_dir=cache_dir,
+            local_files_only=False
+        )
+        return cache_dir
+    except Exception as e:
+        print(f"Error downloading model {model_name}: {e}")
+        return None
 
 def get_translator(lang):
-    """
+    """Initialize and return a translator"""
     if lang == "English":
         return None
 
     if lang not in translators_cache:
         try:
             model_name = TRANSLATORS[lang]
-
-
+            cache_dir = download_model(model_name)
+
+            if cache_dir is None:
+                return None
+
+            tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+            model = MarianMTModel.from_pretrained(model_name, cache_dir=cache_dir)
 
-            # Run on CPU
             model = model.to("cpu")
 
             translators_cache[lang] = {
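For context, a minimal usage sketch of the new download-then-load path. The TRANSLATORS mapping is defined earlier in app.py and its entries are not visible in this diff, so the Korean entry below is purely an assumption (Helsinki-NLP's opus-mt checkpoints are the usual MarianMT choice):

    # Hypothetical entry; the real TRANSLATORS dict sits above this hunk.
    TRANSLATORS = {"Korean": "Helsinki-NLP/opus-mt-ko-en"}

    translator = get_translator("Korean")  # snapshot_download to /tmp, then load on CPU
    translator = get_translator("Korean")  # second call hits translators_cache, no reload
    # translator is None only if the snapshot download or the model load failed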
@@ -804,35 +826,26 @@ def get_translator(lang):
 
     return translators_cache[lang]
 
-def
-    """
-    if source_lang == "English":
-        return prompt
-
-    translator_info = get_translator(source_lang)
+def translate_text(text, translator_info):
+    """Perform the translation"""
     if translator_info is None:
-
-
-
+        return text
+
     try:
         tokenizer = translator_info["tokenizer"]
         model = translator_info["model"]
 
-
-        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-        # Perform the translation
+        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
         translated = model.generate(**inputs)
+        result = tokenizer.decode(translated[0], skip_special_tokens=True)
 
-
-
+        print(f"Original text: {text}")
+        print(f"Translated text: {result}")
 
-
-        print(f"Translated: {translated_text}")
-        return translated_text
+        return result
     except Exception as e:
-        print(f"Translation error: {e}")
-        return prompt
+        print(f"Translation error: {e}")
+        return text
 
 @spaces.GPU
 @torch.no_grad()
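A short usage sketch of the refactored helper pair, again assuming a Korean entry in TRANSLATORS; the sample string and its output are illustrative only:

    info = get_translator("Korean")                    # None when the model is unavailable
    english = translate_text("소파 위의 고양이", info)  # Korean for "a cat on the sofa"
    # With a working ko-en model this logs the original and translated text and
    # returns the English string; with info=None it returns the input unchanged.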
@@ -843,13 +856,15 @@ def generate_image(
 ):
     try:
         if source_lang != "English":
-
+            translator_info = get_translator(source_lang)
+            translated_prompt = translate_text(prompt, translator_info)
             print(f"Using translated prompt: {translated_prompt}")
         else:
             translated_prompt = prompt
     except Exception as e:
         print(f"Translation failed: {e}")
         translated_prompt = prompt
+
 
 
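Taken together, the hunks above make translation fail-soft at every layer; condensed, the chain inside generate_image is:

    # download_model fails      -> get_translator returns None
    # translator_info is None   -> translate_text returns the prompt unchanged
    # tokenize/generate raises  -> translate_text's except returns the prompt
    # anything else raises      -> generate_image's except falls back to the prompt
    translator_info = get_translator(source_lang)
    translated_prompt = translate_text(prompt, translator_info)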