ginipick committed on
Commit
4e85f51
·
verified ·
1 Parent(s): 844ec2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -32
app.py CHANGED
@@ -753,39 +753,78 @@ model_zero_init = False
753
  # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
754
 
755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
  @spaces.GPU
757
  @torch.no_grad()
758
  def generate_image(
759
  prompt, width, height, guidance, inference_steps, seed,
760
  do_img2img, init_image, image2image_strength, resize_img,
 
761
  progress=gr.Progress(track_tqdm=True),
762
  ):
763
  translated_prompt = prompt
764
 
765
- # 한글 또는 일본어 문자 감지
766
- def contains_korean(text):
767
- return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
768
-
769
- def contains_japanese(text):
770
- return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
771
 
772
- # 한글이나 일본어가 있으면 번역
773
- if contains_korean(prompt):
774
- translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
775
- print(f"Translated Korean prompt: {translated_prompt}")
776
- prompt = translated_prompt
777
- elif contains_japanese(prompt):
778
- translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
779
- print(f"Translated Japanese prompt: {translated_prompt}")
780
- prompt = translated_prompt
781
-
782
  if seed == 0:
783
  seed = int(random.random() * 1000000)
784
 
785
  device = "cuda" if torch.cuda.is_available() else "cpu"
786
  torch_device = torch.device(device)
787
-
788
-
789
 
790
  global model, model_zero_init
791
  if not model_zero_init:
@@ -802,10 +841,11 @@ def generate_image(
802
  height = init_image.shape[-2]
803
  width = init_image.shape[-1]
804
  init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
805
- init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
806
 
807
  generator = torch.Generator(device=device).manual_seed(seed)
808
- x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
 
809
 
810
  num_steps = inference_steps
811
  timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
@@ -816,12 +856,9 @@ def generate_image(
816
  timesteps = timesteps[t_idx:]
817
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
818
 
819
- inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
820
  x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
821
 
822
- # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
823
- # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
824
-
825
  x = unpack(x.float(), height, width)
826
  with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
827
  x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
@@ -831,22 +868,33 @@ def generate_image(
831
  x = rearrange(x[0], "c h w -> h w c")
832
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
833
 
834
-
835
  return img, seed, translated_prompt
836
-
 
837
  css = """
838
  footer {
839
  visibility: hidden;
840
  }
841
  """
842
 
843
-
844
  def create_demo():
845
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
846
-
847
  with gr.Row():
848
  with gr.Column():
849
- prompt = gr.Textbox(label="Prompt(한글 가능)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
 
 
 
 
 
 
 
 
 
 
 
 
 
850
 
851
  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
852
  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
@@ -861,13 +909,21 @@ def create_demo():
861
  seed = gr.Number(label="Seed", precision=-1)
862
  do_img2img = gr.Checkbox(label="Image to Image", value=False)
863
  init_image = gr.Image(label="Input Image", visible=False)
864
- image2image_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Noising strength", value=0.8, visible=False)
 
 
 
 
 
 
 
865
  resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
866
  generate_button = gr.Button("Generate")
867
 
868
  with gr.Column():
869
  output_image = gr.Image(label="Generated Image")
870
  output_seed = gr.Text(label="Used Seed")
 
871
 
872
  do_img2img.change(
873
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
@@ -877,8 +933,12 @@ def create_demo():
877
 
878
  generate_button.click(
879
  fn=generate_image,
880
- inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
881
- outputs=[output_image, output_seed]
 
 
 
 
882
  )
883
 
884
  examples = [
 
753
  # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
754
 
755
 
756
+ # 기존 import 문들은 유지...
757
+
758
+ # 언어 모델 딕셔너리 추가
759
+ LANGUAGE_MODELS = {
760
+ "Korean": "Helsinki-NLP/opus-mt-ko-en",
761
+ "Japanese": "Helsinki-NLP/opus-mt-ja-en",
762
+ "Chinese": "Helsinki-NLP/opus-mt-zh-en",
763
+ "Russian": "Helsinki-NLP/opus-mt-ru-en",
764
+ "Spanish": "Helsinki-NLP/opus-mt-es-en",
765
+ "French": "Helsinki-NLP/opus-mt-fr-en",
766
+ "Arabic": "Helsinki-NLP/opus-mt-ar-en",
767
+ "Bengali": "Helsinki-NLP/opus-mt-bn-en",
768
+ "Estonian": "Helsinki-NLP/opus-mt-et-en",
769
+ "Polish": "Helsinki-NLP/opus-mt-pl-en",
770
+ "Swedish": "Helsinki-NLP/opus-mt-sv-en",
771
+ "Thai": "Helsinki-NLP/opus-mt-th-en",
772
+ "Urdu": "Helsinki-NLP/opus-mt-ur-en",
773
+ "Bulgarian": "Helsinki-NLP/opus-mt-bg-en",
774
+ "Catalan": "Helsinki-NLP/opus-mt-ca-en",
775
+ "Czech": "Helsinki-NLP/opus-mt-cs-en",
776
+ "Azerbaijani": "Helsinki-NLP/opus-mt-az-en",
777
+ "Baltic": "Helsinki-NLP/opus-mt-bat-en",
778
+ "Bicolano": "Helsinki-NLP/opus-mt-bcl-en",
779
+ "Bemba": "Helsinki-NLP/opus-mt-bem-en",
780
+ "Berber": "Helsinki-NLP/opus-mt-ber-en",
781
+ "Bislama": "Helsinki-NLP/opus-mt-bi-en",
782
+ "Bantu": "Helsinki-NLP/opus-mt-bnt-en",
783
+ "Brazilian Sign Language": "Helsinki-NLP/opus-mt-bzs-en",
784
+ "Caucasian": "Helsinki-NLP/opus-mt-cau-en",
785
+ "Cebuano": "Helsinki-NLP/opus-mt-ceb-en",
786
+ "Celtic": "Helsinki-NLP/opus-mt-cel-en",
787
+ "Chuukese": "Helsinki-NLP/opus-mt-chk-en",
788
+ "Creoles and pidgins (French)": "Helsinki-NLP/opus-mt-cpf-en",
789
+ "Seychelles Creole": "Helsinki-NLP/opus-mt-crs-en",
790
+ "American Sign Language": "Helsinki-NLP/opus-mt-ase-en",
791
+ "Artificial Language": "Helsinki-NLP/opus-mt-art-en",
792
+ "Atlantic-Congo": "Helsinki-NLP/opus-mt-alv-en",
793
+ "Afroasiatic": "Helsinki-NLP/opus-mt-afa-en",
794
+ "Afrikaans": "Helsinki-NLP/opus-mt-af-en",
795
+ "Austroasiatic": "Helsinki-NLP/opus-mt-aav-en"
796
+ }
797
+
798
+ # 번역기 딕셔너리를 저장할 전역 변수
799
+ translators = {}
800
+
801
+ def get_translator(language):
802
+ """필요할 때만 번역기를 로드하는 지연 초기화 함수"""
803
+ if language not in translators and language in LANGUAGE_MODELS:
804
+ translators[language] = pipeline("translation", model=LANGUAGE_MODELS[language])
805
+ return translators.get(language)
806
+
807
  @spaces.GPU
808
  @torch.no_grad()
809
  def generate_image(
810
  prompt, width, height, guidance, inference_steps, seed,
811
  do_img2img, init_image, image2image_strength, resize_img,
812
+ selected_language="Auto",
813
  progress=gr.Progress(track_tqdm=True),
814
  ):
815
  translated_prompt = prompt
816
 
817
+ if selected_language != "Auto":
818
+ translator = get_translator(selected_language)
819
+ if translator:
820
+ translated_prompt = translator(prompt, max_length=512)[0]['translation_text']
821
+ print(f"Translated from {selected_language}: {translated_prompt}")
 
822
 
 
 
 
 
 
 
 
 
 
 
823
  if seed == 0:
824
  seed = int(random.random() * 1000000)
825
 
826
  device = "cuda" if torch.cuda.is_available() else "cpu"
827
  torch_device = torch.device(device)
 
 
828
 
829
  global model, model_zero_init
830
  if not model_zero_init:
 
841
  height = init_image.shape[-2]
842
  width = init_image.shape[-1]
843
  init_image = ae.encode(init_image.to(torch_device).to(torch.bfloat16)).latent_dist.sample()
844
+ init_image = (init_image - ae.config.shift_factor) * ae.config.scaling_factor
845
 
846
  generator = torch.Generator(device=device).manual_seed(seed)
847
+ x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16),
848
+ device=device, dtype=torch.bfloat16, generator=generator)
849
 
850
  num_steps = inference_steps
851
  timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
 
856
  timesteps = timesteps[t_idx:]
857
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
858
 
859
+ inp = prepare(t5=t5, clip=clip, img=x, prompt=translated_prompt)
860
  x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
861
 
 
 
 
862
  x = unpack(x.float(), height, width)
863
  with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
864
  x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
 
868
  x = rearrange(x[0], "c h w -> h w c")
869
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
870
 
 
871
  return img, seed, translated_prompt
872
+
873
+
874
  css = """
875
  footer {
876
  visibility: hidden;
877
  }
878
  """
879
 
 
880
  def create_demo():
881
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
 
882
  with gr.Row():
883
  with gr.Column():
884
+ # 언어 선택 드롭다운 추가
885
+ language_selector = gr.Dropdown(
886
+ choices=["Auto"] + list(LANGUAGE_MODELS.keys()),
887
+ value="Auto",
888
+ label="Input Language"
892
+ )
893
+
894
+ prompt = gr.Textbox(
895
+ label="Prompt (Multi-language Support)",
896
+ value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible."
897
+ )
898
 
899
  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
900
  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
 
909
  seed = gr.Number(label="Seed", precision=-1)
910
  do_img2img = gr.Checkbox(label="Image to Image", value=False)
911
  init_image = gr.Image(label="Input Image", visible=False)
912
+ image2image_strength = gr.Slider(
913
+ minimum=0.0,
914
+ maximum=1.0,
915
+ step=0.01,
916
+ label="Noising strength",
917
+ value=0.8,
918
+ visible=False
919
+ )
920
  resize_img = gr.Checkbox(label="Resize image", value=True, visible=False)
921
  generate_button = gr.Button("Generate")
922
 
923
  with gr.Column():
924
  output_image = gr.Image(label="Generated Image")
925
  output_seed = gr.Text(label="Used Seed")
926
+ translated_prompt = gr.Text(label="Translated Prompt")
927
 
928
  do_img2img.change(
929
  fn=lambda x: [gr.update(visible=x), gr.update(visible=x), gr.update(visible=x)],
 
933
 
934
  generate_button.click(
935
  fn=generate_image,
936
+ inputs=[
937
+ prompt, width, height, guidance, inference_steps, seed,
938
+ do_img2img, init_image, image2image_strength, resize_img,
939
+ language_selector
940
+ ],
941
+ outputs=[output_image, output_seed, translated_prompt]
942
  )
943
 
944
  examples = [