openfree commited on
Commit
1ab8a04
·
verified ·
1 Parent(s): 052f42d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -4
app.py CHANGED
@@ -25,8 +25,39 @@ huggingface_token = os.getenv("HF_TOKEN")
25
 
26
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
27
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  #Load prompts for randomization
31
  df = pd.read_csv('prompts.csv', header=None)
32
  prompt_values = df.values.flatten()
@@ -826,6 +857,38 @@ input:focus, textarea:focus {
826
  margin: 0 !important; /* auto에서 0으로 변경 */
827
  margin-left: 20px !important; /* 왼쪽 여백 추가 */
828
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829
  '''
830
 
831
  with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
@@ -857,11 +920,30 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
857
 
858
  with gr.Tab(label="Generate"):
859
  # Prompt and Generate Button
 
860
  with gr.Row():
861
  with gr.Column(scale=3):
862
- prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
 
 
 
 
863
  with gr.Column(scale=1):
864
- generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
865
 
866
  # LoRA Selection Area
867
  with gr.Row(elem_id="loaded_loras"):
@@ -1009,13 +1091,21 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
1009
  selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
1010
  lora_image_1, lora_image_2, lora_image_3]
1011
  )
 
 
 
 
 
 
 
1012
 
 
1013
  gr.on(
1014
  triggers=[generate_button.click, prompt.submit],
1015
  fn=run_lora,
1016
  inputs=[prompt, input_image, image_strength, cfg_scale, steps,
1017
- selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
1018
- randomize_seed, seed, width, height, loras_state],
1019
  outputs=[result, seed, progress_bar]
1020
  ).then(
1021
  fn=lambda x, history: update_history(x, history) if x is not None else history,
@@ -1023,6 +1113,7 @@ with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
1023
  outputs=history_gallery
1024
  )
1025
 
 
1026
  if __name__ == "__main__":
1027
  app.queue(max_size=20)
1028
  app.launch(debug=True)
 
25
 
26
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cpu")
27
 
28
+ # Hugging Face 클라이언트 초기화
29
+ client = InferenceClient(
30
+ model="CohereForAI/c4ai-command-r-plus-08-2024",
31
+ token=huggingface_token
32
+ )
33
 
34
+ def augment_prompt(prompt):
35
+ try:
36
+ if not prompt:
37
+ return "", gr.Warning("Please enter a prompt first")
38
+
39
+ system_prompt = """You are an expert at writing detailed image generation prompts.
40
+ Enhance the given prompt by adding more descriptive details, artistic style, and technical aspects
41
+ that would help in generating better images. Keep the core meaning but make it more comprehensive."""
42
+
43
+ full_prompt = f"{system_prompt}\nOriginal prompt: {prompt}\nEnhanced prompt:"
44
 
45
+ # 클라이언트 인퍼런스 방식으로 호출
46
+ response = client.text_generation(
47
+ full_prompt,
48
+ max_new_tokens=300,
49
+ temperature=0.7,
50
+ top_p=0.95,
51
+ repetition_penalty=1.1,
52
+ do_sample=True
53
+ )
54
+
55
+ enhanced_prompt = response.strip()
56
+ return enhanced_prompt
57
+ except Exception as e:
58
+ print(f"Error in prompt augmentation: {str(e)}")
59
+ return prompt
60
+
61
  #Load prompts for randomization
62
  df = pd.read_csv('prompts.csv', header=None)
63
  prompt_values = df.values.flatten()
 
857
  margin: 0 !important; /* auto에서 0으로 변경 */
858
  margin-left: 20px !important; /* 왼쪽 여백 추가 */
859
  }
860
+
861
+ .enhance-button, .generate-button {
862
+ flex: 1 !important;
863
+ min-width: 120px !important;
864
+ height: 40px !important;
865
+ margin: 0 5px !important;
866
+ }
867
+
868
+ .enhance-button {
869
+ background-color: #6b7280 !important;
870
+ color: white !important;
871
+ }
872
+
873
+ .enhance-button:hover {
874
+ background-color: #4b5563 !important;
875
+ }
876
+ .generate-button {
877
+ background-color: #3b82f6 !important;
878
+ color: white !important;
879
+ }
880
+
881
+ .generate-button:hover {
882
+ background-color: #2563eb !important;
883
+ }
884
+
885
+ /* 버튼 컨테이너 스타일링 */
886
+ .button-container {
887
+ display: flex !important;
888
+ justify-content: space-between !important;
889
+ gap: 10px !important;
890
+ width: 100% !important;
891
+ }
892
  '''
893
 
894
  with gr.Blocks(theme=custom_theme, css=css, delete_cache=(60, 3600)) as app:
 
920
 
921
  with gr.Tab(label="Generate"):
922
  # Prompt and Generate Button
923
+ # Prompt and Generate Buttons 영역 수정
924
  with gr.Row():
925
  with gr.Column(scale=3):
926
+ prompt = gr.Textbox(
927
+ label="Prompt",
928
+ lines=1,
929
+ placeholder="Type a prompt after selecting a LoRA"
930
+ )
931
  with gr.Column(scale=1):
932
+ with gr.Row():
933
+ augment_button = gr.Button(
934
+ "✨프롬프트 증강",
935
+ variant="secondary",
936
+ size="sm",
937
+ scale=1,
938
+ elem_classes=["enhance-button"]
939
+ )
940
+ generate_button = gr.Button(
941
+ "🎨생성",
942
+ variant="primary",
943
+ size="sm",
944
+ scale=1,
945
+ elem_classes=["generate-button"]
946
+ )
947
 
948
  # LoRA Selection Area
949
  with gr.Row(elem_id="loaded_loras"):
 
1091
  selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
1092
  lora_image_1, lora_image_2, lora_image_3]
1093
  )
1094
+ # 기존 이벤트 핸들러들 위에 추가
1095
+ augment_button.click(
1096
+ fn=augment_prompt,
1097
+ inputs=[prompt],
1098
+ outputs=[prompt],
1099
+ api_name="enhance_prompt"
1100
+ )
1101
 
1102
+ # 기존 generate 이벤트 핸들러 유지
1103
  gr.on(
1104
  triggers=[generate_button.click, prompt.submit],
1105
  fn=run_lora,
1106
  inputs=[prompt, input_image, image_strength, cfg_scale, steps,
1107
+ selected_indices, lora_scale_1, lora_scale_2, lora_scale_3,
1108
+ randomize_seed, seed, width, height, loras_state],
1109
  outputs=[result, seed, progress_bar]
1110
  ).then(
1111
  fn=lambda x, history: update_history(x, history) if x is not None else history,
 
1113
  outputs=history_gallery
1114
  )
1115
 
1116
+
1117
  if __name__ == "__main__":
1118
  app.queue(max_size=20)
1119
  app.launch(debug=True)