Deadmon committed on
Commit
bea7193
·
verified ·
1 Parent(s): d2fcb60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -72
app.py CHANGED
@@ -10,19 +10,33 @@ import os
10
  os.environ["FAL_KEY"] = "b6fa8d06-4225-4ec3-9aaf-4d01e960d899:cc6a52d0fc818c6f892b2760fd341ee4"
11
  fal_client.api_key = os.environ["FAL_KEY"]
12
 
13
- # Model choices (base models)
14
  base_model_paths = {
15
- "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
16
- "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
17
  "Deliberate": "Yntec/Deliberate",
18
- "Deliberate V2": "Yntec/Deliberate2",
19
- "Dreamshaper 8": "Lykon/dreamshaper-8",
20
- "Epic Realism": "emilianJR/epiCRealism"
21
  }
22
 
23
- async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
 
24
  """
25
  Submit the image generation process using the fal_client's submit method with the ip-adapter-face-id model.
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  """
27
  try:
28
  handler = fal_client.submit(
@@ -35,11 +49,13 @@ async def generate_image(image_url: str, prompt: str, negative_prompt: str, mode
35
  "seed": seed,
36
  "guidance_scale": guidance_scale,
37
  "num_inference_steps": num_inference_steps,
38
- "num_samples": num_samples,
39
  "width": width,
40
  "height": height,
41
- "base_1_5_model_repo": base_model_paths[base_model], # Base model selected by user
42
- "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0", # SDXL model as default
 
 
43
  },
44
  )
45
  # Retrieve the result synchronously
@@ -53,45 +69,34 @@ async def generate_image(image_url: str, prompt: str, negative_prompt: str, mode
53
  print(f"Error generating image: {e}")
54
  return None
55
 
 
56
  def fetch_image_from_url(url: str) -> Image.Image:
57
- """
58
- Download the image from the given URL and return it as a PIL Image.
59
- """
60
  response = requests.get(url)
61
  return Image.open(io.BytesIO(response.content))
62
 
63
- async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
64
- """
65
- Asynchronous function to handle image upload, prompt inputs and generate the final image.
66
- """
67
- # Upload the image and get a valid URL
68
  image_url = await upload_image_to_server(image)
69
 
70
  if not image_url:
71
  return None
72
 
73
- # Run the image generation
74
- image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
75
 
76
  if image_info and "url" in image_info:
77
  return fetch_image_from_url(image_info["url"]), image_info # Return both the image and the metadata
78
 
79
  return None, None
80
 
 
81
  async def upload_image_to_server(image: Image.Image) -> str:
82
- """
83
- Upload an image to the fal_client and return the uploaded image URL.
84
- """
85
- # Convert PIL image to byte stream for upload
86
  byte_arr = io.BytesIO()
87
  image.save(byte_arr, format='PNG')
88
  byte_arr.seek(0)
89
 
90
- # Convert BytesIO to a file-like object that fal_client can handle
91
  with open("temp_image.png", "wb") as f:
92
  f.write(byte_arr.getvalue())
93
 
94
- # Upload the image using fal_client's asynchronous method
95
  try:
96
  upload_url = await fal_client.upload_file_async("temp_image.png")
97
  return upload_url
@@ -99,84 +104,56 @@ async def upload_image_to_server(image: Image.Image) -> str:
99
  print(f"Error uploading image: {e}")
100
  return ""
101
 
 
102
  def change_style(style):
103
- """
104
- Changes the style for 'Photorealistic' or 'Stylized' generation type.
105
- """
106
  if style == "Photorealistic":
107
  return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
108
  else:
109
  return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
110
 
111
- def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
112
- """
113
- Wrapper function to run asynchronous code in a synchronous environment like Gradio.
114
- """
115
  loop = asyncio.new_event_loop()
116
  asyncio.set_event_loop(loop)
117
 
118
- # Execute the async process_inputs function
119
- result_image, image_info = loop.run_until_complete(process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height))
 
 
120
  if result_image:
121
- # Display both the image and metadata
122
  metadata = f"File Name: {image_info['file_name']}\nFile Size: {image_info['file_size']} bytes\nDimensions: {image_info['width']}x{image_info['height']} px\nSeed: {image_info.get('seed', 'N/A')}"
123
  return result_image, metadata
124
  return None, "Error generating image"
125
 
126
- # Gradio Interface
127
  with gr.Blocks() as demo:
128
  gr.Markdown("## Image Generation with Fal API and Gradio")
129
 
130
  with gr.Row():
131
  with gr.Column():
132
- # Image input
133
  image_input = gr.Image(label="Upload Image", type="pil")
134
-
135
- # Textbox for prompt
136
  prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
137
-
138
- # Textbox for negative prompt
139
  negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
140
-
141
- # Radio buttons for model type (Photorealistic or Stylized)
142
  style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
143
-
144
- # Dropdown for selecting the base model
145
- base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")
146
-
147
- # Seed input
148
- seed_input = gr.Number(label="Seed", value=42, precision=0)
149
-
150
- # Guidance scale slider
151
- guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
152
-
153
- # Inference steps slider
154
- num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)
155
-
156
- # Samples slider
157
- num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)
158
-
159
- # Image dimensions sliders
160
- width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
161
- height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)
162
-
163
- # Button to trigger image generation
164
  generate_button = gr.Button("Generate Image")
165
 
166
  with gr.Column():
167
- # Display generated image and metadata
168
  generated_image = gr.Image(label="Generated Image")
169
  metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
170
 
171
- # Style change functionality
172
- style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])
173
-
174
- # Define the interaction between inputs and output
175
  generate_button.click(
176
  fn=gradio_interface,
177
- inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
178
  outputs=[generated_image, metadata_output]
179
  )
180
 
181
- # Launch the Gradio interface
 
182
  demo.launch()
 
10
  os.environ["FAL_KEY"] = "b6fa8d06-4225-4ec3-9aaf-4d01e960d899:cc6a52d0fc818c6f892b2760fd341ee4"
11
  fal_client.api_key = os.environ["FAL_KEY"]
12
 
13
+ # Base model paths for model switching
14
  base_model_paths = {
15
+ "RealisticVisionV4": "SG161222/Realistic_Vision_V4.0_noVAE",
16
+ "RealisticVisionV6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
17
  "Deliberate": "Yntec/Deliberate",
18
+ "DeliberateV2": "Yntec/Deliberate2",
19
+ "Dreamshaper8": "Lykon/dreamshaper-8",
20
+ "EpicRealism": "emilianJR/epiCRealism"
21
  }
22
 
23
+ # Updated function to include the API call to the Fal model
24
+ async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, width: int, height: int):
25
  """
26
  Submit the image generation process using the fal_client's submit method with the ip-adapter-face-id model.
27
+ Arguments:
28
+ image_url: URL of the input image to use.
29
+ prompt: Text prompt for generating the image.
30
+ negative_prompt: Text for negative prompt to avoid unwanted characteristics in the output.
31
+ model_type: Model type to use.
32
+ base_model: Base model to use for image generation.
33
+ seed: Seed for random generation.
34
+ guidance_scale: CFG scale for how closely the model sticks to the prompt.
35
+ num_inference_steps: Number of inference steps.
36
+ width: Width of the generated image.
37
+ height: Height of the generated image.
38
+ Returns:
39
+ The URL of the generated image along with other attributes like file size, dimensions, etc., or None if failed.
40
  """
41
  try:
42
  handler = fal_client.submit(
 
49
  "seed": seed,
50
  "guidance_scale": guidance_scale,
51
  "num_inference_steps": num_inference_steps,
52
+ "num_samples": 1, # Adjusted to 1 sample
53
  "width": width,
54
  "height": height,
55
+ "face_id_det_size": 640,
56
+ "base_1_5_model_repo": base_model_paths[base_model], # Base model
57
+ "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",
58
+ "face_images_data_url": None
59
  },
60
  )
61
  # Retrieve the result synchronously
 
69
  print(f"Error generating image: {e}")
70
  return None
71
 
72
+ # Fetch the image from the given URL
73
  def fetch_image_from_url(url: str) -> Image.Image:
 
 
 
74
  response = requests.get(url)
75
  return Image.open(io.BytesIO(response.content))
76
 
77
+ # Process input images and handle the image generation
78
+ async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, width: int, height: int):
 
 
 
79
  image_url = await upload_image_to_server(image)
80
 
81
  if not image_url:
82
  return None
83
 
84
+ image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height)
 
85
 
86
  if image_info and "url" in image_info:
87
  return fetch_image_from_url(image_info["url"]), image_info # Return both the image and the metadata
88
 
89
  return None, None
90
 
91
+ # Upload image to server
92
  async def upload_image_to_server(image: Image.Image) -> str:
 
 
 
 
93
  byte_arr = io.BytesIO()
94
  image.save(byte_arr, format='PNG')
95
  byte_arr.seek(0)
96
 
 
97
  with open("temp_image.png", "wb") as f:
98
  f.write(byte_arr.getvalue())
99
 
 
100
  try:
101
  upload_url = await fal_client.upload_file_async("temp_image.png")
102
  return upload_url
 
104
  print(f"Error uploading image: {e}")
105
  return ""
106
 
107
+ # Change style between Photorealistic and Stylized
108
  def change_style(style):
 
 
 
109
  if style == "Photorealistic":
110
  return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
111
  else:
112
  return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
113
 
114
+ # Gradio Interface
115
+ def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height):
 
 
116
  loop = asyncio.new_event_loop()
117
  asyncio.set_event_loop(loop)
118
 
119
+ result_image, image_info = loop.run_until_complete(
120
+ process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height)
121
+ )
122
+
123
  if result_image:
 
124
  metadata = f"File Name: {image_info['file_name']}\nFile Size: {image_info['file_size']} bytes\nDimensions: {image_info['width']}x{image_info['height']} px\nSeed: {image_info.get('seed', 'N/A')}"
125
  return result_image, metadata
126
  return None, "Error generating image"
127
 
128
+ # Main Gradio App
129
  with gr.Blocks() as demo:
130
  gr.Markdown("## Image Generation with Fal API and Gradio")
131
 
132
  with gr.Row():
133
  with gr.Column():
 
134
  image_input = gr.Image(label="Upload Image", type="pil")
 
 
135
  prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
 
 
136
  negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
 
 
137
  style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
138
+ model_type = gr.Dropdown(label="Model Type", choices=["1_5-v1", "SDXL-v2-plus", "1_5-auraface-v1"], value="SDXL-v2-plus")
139
+ base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="RealisticVisionV4")
140
+ seed_input = gr.Slider(label="Seed", value=42, minimum=0, maximum=1000, step=1)
141
+ guidance_scale_input = gr.Slider(label="Guidance Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1)
142
+ num_inference_steps_input = gr.Slider(label="Inference Steps", value=50, minimum=10, maximum=100, step=1)
143
+ width_input = gr.Slider(label="Width", value=1024, minimum=512, maximum=1024, step=64)
144
+ height_input = gr.Slider(label="Height", value=1024, minimum=512, maximum=1024, step=64)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  generate_button = gr.Button("Generate Image")
146
 
147
  with gr.Column():
 
148
  generated_image = gr.Image(label="Generated Image")
149
  metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
150
 
 
 
 
 
151
  generate_button.click(
152
  fn=gradio_interface,
153
+ inputs=[image_input, prompt_input, negative_prompt_input, model_type, base_model, seed_input, guidance_scale_input, num_inference_steps_input, width_input, height_input],
154
  outputs=[generated_image, metadata_output]
155
  )
156
 
157
+ style.change(fn=change_style, inputs=style, outputs=[model_type, guidance_scale_input, num_inference_steps_input])
158
+
159
  demo.launch()