ginipick committed on
Commit
429cbeb
·
verified ·
1 Parent(s): 67016e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -86
app.py CHANGED
@@ -10,7 +10,6 @@ from dataclasses import dataclass
10
  import math
11
  from typing import Callable
12
 
13
-
14
  from tqdm import tqdm
15
  import bitsandbytes as bnb
16
  from bitsandbytes.nn.modules import Params4bit, QuantState
@@ -25,6 +24,9 @@ from transformers import T5EncoderModel, T5Tokenizer
25
  # from optimum.quanto import freeze, qfloat8, quantize
26
  from transformers import pipeline
27
 
 
 
 
28
  class HFEmbedder(nn.Module):
29
  def __init__(self, version: str, max_length: int, **hf_kwargs):
30
  super().__init__()
@@ -747,48 +749,10 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
747
  result = model.load_state_dict(sd)
748
  model_zero_init = False
749
 
 
 
750
 
751
 
752
- ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
753
- ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
754
- zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
755
-
756
-
757
- from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
758
-
759
- def translate_text(text):
760
- try:
761
- # M2M100은 다국어 번역을 한 모델로 처리할 수 있습니다
762
- model_name = "facebook/m2m100_418M"
763
- tokenizer = M2M100Tokenizer.from_pretrained(model_name)
764
- model = M2M100ForConditionalGeneration.from_pretrained(model_name).to(device)
765
-
766
- # 언어 감지
767
- def detect_language(text):
768
- if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text):
769
- return 'ko'
770
- elif any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in text):
771
- return 'ja'
772
- elif any('\u4e00' <= c <= '\u9fff' for c in text):
773
- return 'zh'
774
- return None
775
-
776
- src_lang = detect_language(text)
777
- if src_lang is None:
778
- return text
779
-
780
- tokenizer.src_lang = src_lang
781
- encoded = tokenizer(text, return_tensors="pt").to(device)
782
- generated_tokens = model.generate(
783
- **encoded,
784
- forced_bos_token_id=tokenizer.get_lang_id("en"),
785
- max_length=128
786
- )
787
- return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
788
- except Exception as e:
789
- print(f"Translation error: {e}")
790
- return text
791
-
792
  @spaces.GPU
793
  @torch.no_grad()
794
  def generate_image(
@@ -796,36 +760,24 @@ def generate_image(
796
  do_img2img, init_image, image2image_strength, resize_img,
797
  progress=gr.Progress(track_tqdm=True),
798
  ):
799
- translated_prompt = translate_text(prompt)
800
- if translated_prompt != prompt:
801
- print(f"Translated prompt: {translated_prompt}")
802
- prompt = translated_prompt
803
-
804
- if seed == 0:
805
- seed = int(random.random() * 1000000)
806
- def translate_text(text, src_lang, model_name):
807
- try:
808
- tokenizer = MarianTokenizer.from_pretrained(model_name)
809
- model = MarianMTModel.from_pretrained(model_name)
810
- model = model.to(device)
811
 
812
- inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)
813
- translated = model.generate(**inputs)
814
- translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
815
- return translated_text
816
- except:
817
- return text # 번역 실패시 원본 텍스트 반환
818
-
819
- # 기존의 translator 정의 부분을 삭제하고 아래 코드로 대체
820
- def translate_if_needed(prompt):
821
- if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt): # Korean
822
- return translate_text(prompt, 'ko', 'Helsinki-NLP/opus-mt-ko-en')
823
- elif any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' for c in prompt): # Japanese
824
- return translate_text(prompt, 'ja', 'Helsinki-NLP/opus-mt-ja-en')
825
- elif any('\u4e00' <= c <= '\u9fff' for c in prompt): # Chinese
826
- return translate_text(prompt, 'zh', 'Helsinki-NLP/opus-mt-zh-en')
827
- return prompt
828
-
829
 
830
  if seed == 0:
831
  seed = int(random.random() * 1000000)
@@ -888,12 +840,13 @@ footer {
888
  }
889
  """
890
 
 
891
  def create_demo():
892
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
893
 
894
  with gr.Row():
895
  with gr.Column():
896
- prompt = gr.Textbox(label="Prompt(한글/일본어/중국어 가능)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
897
 
898
  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
899
  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
@@ -922,27 +875,19 @@ def create_demo():
922
  outputs=[init_image, image2image_strength, resize_img]
923
  )
924
 
925
- gr.Examples(
926
- examples=[
927
- ["A magical fairy garden with glowing mushrooms and floating lanterns", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # English
928
- ["아름다운 벚꽃이 흩날리는 한옥 정원에서 한복을 입은 소녀", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # Korean
929
- ["夕暮れの富士山と桜の木の下で休んでいる可愛い柴犬", 768, 768, 3.5, 30, 0, False, None, 0.8, True], # Japanese
930
- ["古老的中国庭园里,一只熊猫正在竹林中悠闲地吃着竹子", 768, 768, 3.5, 30, 0, False, None, 0.8, True] # Chinese
931
- ],
932
- inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
933
- outputs=[output_image, output_seed],
934
- fn=generate_image,
935
- cache_examples=True
936
- )
937
-
938
  generate_button.click(
939
  fn=generate_image,
940
  inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
941
  outputs=[output_image, output_seed]
942
  )
 
 
 
 
 
 
943
 
944
  return demo
945
 
946
  if __name__ == "__main__":
947
- demo = create_demo()
948
- demo.launch()
 
10
  import math
11
  from typing import Callable
12
 
 
13
  from tqdm import tqdm
14
  import bitsandbytes as bnb
15
  from bitsandbytes.nn.modules import Params4bit, QuantState
 
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
  from transformers import pipeline
26
 
27
+ ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
28
+ ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
29
+
30
  class HFEmbedder(nn.Module):
31
  def __init__(self, version: str, max_length: int, **hf_kwargs):
32
  super().__init__()
 
749
  result = model.load_state_dict(sd)
750
  model_zero_init = False
751
 
752
+ # model = Flux().to(dtype=torch.bfloat16, device="cuda")
753
+ # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
754
 
755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
  @spaces.GPU
757
  @torch.no_grad()
758
  def generate_image(
 
760
  do_img2img, init_image, image2image_strength, resize_img,
761
  progress=gr.Progress(track_tqdm=True),
762
  ):
763
+ translated_prompt = prompt
764
+
765
+ # 한글 또는 일본어 문자 감지
766
+ def contains_korean(text):
767
+ return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
 
 
 
 
 
 
 
768
 
769
+ def contains_japanese(text):
770
+ return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
771
+
772
+ # 한글이나 일본어가 있으면 번역
773
+ if contains_korean(prompt):
774
+ translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
775
+ print(f"Translated Korean prompt: {translated_prompt}")
776
+ prompt = translated_prompt
777
+ elif contains_japanese(prompt):
778
+ translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
779
+ print(f"Translated Japanese prompt: {translated_prompt}")
780
+ prompt = translated_prompt
 
 
 
 
 
781
 
782
  if seed == 0:
783
  seed = int(random.random() * 1000000)
 
840
  }
841
  """
842
 
843
+
844
  def create_demo():
845
  with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
846
 
847
  with gr.Row():
848
  with gr.Column():
849
+ prompt = gr.Textbox(label="Prompt(한글 가능)", value="A cute and fluffy golden retriever puppy sitting upright, holding a neatly designed white sign with bold, colorful lettering that reads 'Have a Happy Day!' in cheerful fonts. The puppy has expressive, sparkling eyes, a happy smile, and fluffy ears slightly flopped. The background is a vibrant and sunny meadow with soft-focus flowers, glowing sunlight filtering through the trees, and a warm golden glow that enhances the joyful atmosphere. The sign is framed with small decorative flowers, adding a charming and wholesome touch. Ensure the text on the sign is clear and legible.")
850
 
851
  width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=768)
852
  height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=768)
 
875
  outputs=[init_image, image2image_strength, resize_img]
876
  )
877
 
 
 
 
 
 
 
 
 
 
 
 
 
 
878
  generate_button.click(
879
  fn=generate_image,
880
  inputs=[prompt, width, height, guidance, inference_steps, seed, do_img2img, init_image, image2image_strength, resize_img],
881
  outputs=[output_image, output_seed]
882
  )
883
+
884
+ examples = [
885
+ "a tiny astronaut hatching from an egg on the moon",
886
+ "a cat holding a sign that says hello world",
887
+ "an anime illustration of a wiener schnitzel",
888
+ ]
889
 
890
  return demo
891
 
892
  if __name__ == "__main__":
893
+ demo = create_demo()