ginipick committed on
Commit
1c3c162
·
verified ·
1 Parent(s): 6c83a88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -753,7 +753,30 @@ ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
753
  ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
754
  zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
755
 
 
756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
 
758
  @spaces.GPU
759
  @torch.no_grad()
@@ -762,7 +785,11 @@ def generate_image(
762
  do_img2img, init_image, image2image_strength, resize_img,
763
  progress=gr.Progress(track_tqdm=True),
764
  ):
765
- translated_prompt = prompt
 
 
 
 
766
 
767
  # 한글, 일본어, 중국어 문자 감지
768
  def contains_korean(text):
@@ -887,16 +914,14 @@ def create_demo():
887
  outputs=[init_image, image2image_strength, resize_img]
888
  )
889
 
890
- examples = [
891
- ["A magical fairy garden with glowing mushrooms and floating lanterns"], # English
892
- ["아름다운 벚꽃이 흩날리는 한옥 정원에서 한복을 입은 소녀"], # Korean
893
- ["夕暮れの富士山と桜の木の下で休んでいる可愛い柴犬"], # Japanese
894
- ["古老的中国庭园里,一只熊猫正在竹林中悠闲地吃着竹子"] # Chinese
895
- ]
896
-
897
  gr.Examples(
898
- examples=examples,
899
- inputs=prompt,
 
 
 
 
 
900
  outputs=[output_image, output_seed],
901
  fn=generate_image,
902
  cache_examples=True
 
753
  ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
754
  zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
755
 
756
+ from transformers import MarianMTModel, MarianTokenizer
757
 
758
+ def translate_text(text, src_lang, model_name):
759
+ try:
760
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
761
+ model = MarianMTModel.from_pretrained(model_name)
762
+ model = model.to(device)
763
+
764
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
765
+ translated = model.generate(**inputs)
766
+ translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
767
+ return translated_text
768
+ except:
769
+ return text # 번역 실패시 원본 텍스트 반환
770
+
771
+ # Replace the previously defined translator pipelines with the code below
772
+ def translate_if_needed(prompt):
773
+ if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt): # Korean
774
+ return translate_text(prompt, 'ko', 'Helsinki-NLP/opus-mt-ko-en')
775
+ elif any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in prompt): # Japanese
776
+ return translate_text(prompt, 'ja', 'Helsinki-NLP/opus-mt-ja-en')
777
+ elif any('\u4e00' <= c <= '\u9fff' for c in prompt): # Chinese
778
+ return translate_text(prompt, 'zh', 'Helsinki-NLP/opus-mt-zh-en')
779
+ return prompt
780
 
781
  @spaces.GPU
782
  @torch.no_grad()
 
785
  do_img2img, init_image, image2image_strength, resize_img,
786
  progress=gr.Progress(track_tqdm=True),
787
  ):
788
+ translated_prompt = translate_if_needed(prompt)
789
+ if translated_prompt != prompt:
790
+ print(f"Translated prompt: {translated_prompt}")
791
+ prompt = translated_prompt
792
+
793
 
794
  # 한글, 일본어, 중국어 문자 감지
795
  def contains_korean(text):
 
914
  outputs=[init_image, image2image_strength, resize_img]
915
  )
916
 
 
 
 
 
 
 
 
917
  gr.Examples(
918
+ examples=[
919
+ ["A magical fairy garden with glowing mushrooms and floating lanterns", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # English
920
+ ["아름다운 벚꽃이 흩날리는 한옥 정원에서 한복을 입은 소녀", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # Korean
921
+ ["夕暮れの富士山と桜の木の下で休んでいる可愛い柴犬", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # Japanese
922
+ ["古老的中国庭园里,一只熊猫正在竹林中悠闲地吃着竹子", 768, 768, 3.5, 30, 0, False, None, 0.8, True] # Chinese
923
+ ],
924
+ inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
925
  outputs=[output_image, output_seed],
926
  fn=generate_image,
927
  cache_examples=True