ginipick committed on
Commit
f1cb913
·
verified ·
1 Parent(s): 5b664f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -26
app.py CHANGED
@@ -19,6 +19,7 @@ from torch import Tensor, nn
19
  from transformers import CLIPTextModel, CLIPTokenizer
20
  from transformers import T5EncoderModel, T5Tokenizer
21
  from transformers import MarianMTModel, MarianTokenizer, pipeline
 
22
 
23
  class HFEmbedder(nn.Module):
24
  def __init__(self, version: str, max_length: int, **hf_kwargs):
@@ -777,20 +778,41 @@ TRANSLATORS = {
777
 
778
  translators_cache = {}
779
 
780
- from transformers import MarianMTModel, MarianTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
 
782
  def get_translator(lang):
783
- """단일 번역기를 초기화하고 반환하는 함수"""
784
  if lang == "English":
785
  return None
786
 
787
  if lang not in translators_cache:
788
  try:
789
  model_name = TRANSLATORS[lang]
790
- tokenizer = MarianTokenizer.from_pretrained(model_name)
791
- model = MarianMTModel.from_pretrained(model_name)
 
 
 
 
 
792
 
793
- # CPU에서 실행
794
  model = model.to("cpu")
795
 
796
  translators_cache[lang] = {
@@ -804,35 +826,26 @@ def get_translator(lang):
804
 
805
  return translators_cache[lang]
806
 
807
- def translate_prompt(prompt, source_lang):
808
- """프롬프트를 번역하는 함수"""
809
- if source_lang == "English":
810
- return prompt
811
-
812
- translator_info = get_translator(source_lang)
813
  if translator_info is None:
814
- print(f"No translator available for {source_lang}, using original prompt")
815
- return prompt
816
-
817
  try:
818
  tokenizer = translator_info["tokenizer"]
819
  model = translator_info["model"]
820
 
821
- # 텍스트를 토큰화하고 모델 입력으로 변환
822
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
823
-
824
- # 번역 수행
825
  translated = model.generate(**inputs)
 
826
 
827
- # 번역된 텍스트 디코딩
828
- translated_text = tokenizer.decode(translated[0], skip_special_tokens=True)
829
 
830
- print(f"Original ({source_lang}): {prompt}")
831
- print(f"Translated: {translated_text}")
832
- return translated_text
833
  except Exception as e:
834
- print(f"Translation error for {source_lang}: {e}")
835
- return prompt
836
 
837
  @spaces.GPU
838
  @torch.no_grad()
@@ -843,13 +856,15 @@ def generate_image(
843
  ):
844
  try:
845
  if source_lang != "English":
846
- translated_prompt = translate_prompt(prompt, source_lang)
 
847
  print(f"Using translated prompt: {translated_prompt}")
848
  else:
849
  translated_prompt = prompt
850
  except Exception as e:
851
  print(f"Translation failed: {e}")
852
  translated_prompt = prompt
 
853
 
854
 
855
 
 
19
  from transformers import CLIPTextModel, CLIPTokenizer
20
  from transformers import T5EncoderModel, T5Tokenizer
21
  from transformers import MarianMTModel, MarianTokenizer, pipeline
22
+ from huggingface_hub import snapshot_download
23
 
24
  class HFEmbedder(nn.Module):
25
  def __init__(self, version: str, max_length: int, **hf_kwargs):
 
778
 
779
  translators_cache = {}
780
 
781
+
782
+
783
# Model cache directory: point HF Transformers at a writable tmp location
# (Spaces containers only guarantee write access under /tmp).
os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'

def download_model(model_name):
    """Pre-download a Hugging Face model snapshot into the local cache.

    Args:
        model_name: Hub repo id, e.g. "Helsinki-NLP/opus-mt-ko-en".

    Returns:
        The cache directory path on success, or None on any failure
        (best-effort: the caller treats None as "translator unavailable").
    """
    try:
        # Key the cache dir on the FULL repo id (with '/' flattened) so two
        # repos that share a tail name under different orgs do not collide.
        cache_dir = os.path.join('/tmp/transformers_cache', model_name.replace('/', '--'))
        snapshot_download(
            repo_id=model_name,
            cache_dir=cache_dir,
            local_files_only=False
        )
        return cache_dir
    except Exception as e:
        # Deliberately broad: any download error degrades to "no translator"
        # rather than crashing the app.
        print(f"Error downloading model {model_name}: {e}")
        return None
799
 
800
  def get_translator(lang):
801
+ """번역기 초기화 반환"""
802
  if lang == "English":
803
  return None
804
 
805
  if lang not in translators_cache:
806
  try:
807
  model_name = TRANSLATORS[lang]
808
+ cache_dir = download_model(model_name)
809
+
810
+ if cache_dir is None:
811
+ return None
812
+
813
+ tokenizer = MarianTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
814
+ model = MarianMTModel.from_pretrained(model_name, cache_dir=cache_dir)
815
 
 
816
  model = model.to("cpu")
817
 
818
  translators_cache[lang] = {
 
826
 
827
  return translators_cache[lang]
828
 
829
def translate_text(text, translator_info):
    """Translate *text* with the supplied translator bundle.

    Args:
        text: the prompt string to translate.
        translator_info: dict with "tokenizer" and "model" entries, or
            None when no translator exists for the source language.

    Returns:
        The translated string, or the original *text* unchanged when no
        translator is available or translation fails (best-effort).
    """
    # No translator registered for this language — pass the text through.
    if translator_info is None:
        return text

    try:
        tok = translator_info["tokenizer"]
        mt_model = translator_info["model"]

        # Tokenize, generate the translation, then decode back to a string.
        encoded = tok(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        generated = mt_model.generate(**encoded)
        decoded = tok.decode(generated[0], skip_special_tokens=True)

        print(f"Original text: {text}")
        print(f"Translated text: {decoded}")

        return decoded
    except Exception as e:
        # Any failure falls back to the untranslated prompt so image
        # generation can still proceed.
        print(f"Translation error: {e}")
        return text
849
 
850
  @spaces.GPU
851
  @torch.no_grad()
 
856
  ):
857
  try:
858
  if source_lang != "English":
859
+ translator_info = get_translator(source_lang)
860
+ translated_prompt = translate_text(prompt, translator_info)
861
  print(f"Using translated prompt: {translated_prompt}")
862
  else:
863
  translated_prompt = prompt
864
  except Exception as e:
865
  print(f"Translation failed: {e}")
866
  translated_prompt = prompt
867
+
868
 
869
 
870