Ryukijano commited on
Commit
8c5c8af
·
1 Parent(s): e41b6f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -1,35 +1,34 @@
1
  import gradio as gr
2
- from transformers import FlaxAutoModel, AutoTokenizer
3
 
4
- def load_model_and_tokenizer(model_name):
5
- model = FlaxAutoModel.from_pretrained(model_name)
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- return model, tokenizer
 
 
8
 
9
- def generate_image(model, tokenizer, prompt):
10
- # Process the input prompt and generate the output image using the model and tokenizer
11
- # ...
12
- # output_image = ... (your implementation here)
13
- return output_image
14
 
15
- def infer(prompt):
16
- model_name = "Ryukijano/controlnet-fill-circle"
17
- model, tokenizer = load_model_and_tokenizer(model_name)
18
- output_image = generate_image(model, tokenizer, prompt)
19
- return output_image
 
20
 
21
- iface = gr.Interface(
22
- fn=infer,
23
- inputs=["text"],
24
- outputs="image",
25
- title="ControlNet Fill Circle",
26
- description="This is a demo of ControlNet Fill Circle.",
27
- examples=[
28
- ["red circle with blue background"],
29
- ["cyan circle with brown floral background"]
30
- ],
31
- theme="gradio/soft",
32
- )
33
 
34
- iface.launch()
 
 
 
 
 
 
 
 
 
35
 
 
1
  import gradio as gr
2
+ from transformers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline, AutoTokenizer
3
 
4
+ def load_model(model_name):
 
5
  tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ controlnet = FlaxControlNetModel.from_pretrained(model_name)
7
+ pipeline = FlaxStableDiffusionControlNetPipeline.from_pretrained(model_name)
8
+ return tokenizer, controlnet, pipeline
9
 
10
+ model_name = "Ryukijano/controlnet-fill-circle"
11
+ tokenizer, controlnet, pipeline = load_model(model_name)
 
 
 
12
 
13
+ def infer_fill_circle(prompt, negative_prompt, image):
14
+ # Your inference function for fill circle control
15
+ inputs = tokenizer(prompt, return_tensors="jax")
16
+ # Implement your image preprocessing here
17
+ outputs = pipeline.generate(inputs, image)
18
+ return outputs
19
 
20
+ with gr.Blocks(theme='gradio/soft') as demo:
21
+ gr.Markdown("## Stable Diffusion with Fill Circle Control")
22
+ gr.Markdown("In this app, you can find the ControlNet with Fill Circle control.")
 
 
 
 
 
 
 
 
 
23
 
24
+ with gr.Tab("ControlNet Fill Circle"):
25
+ prompt_input_fill_circle = gr.Textbox(label="Prompt")
26
+ negative_prompt_fill_circle = gr.Textbox(label="Negative Prompt")
27
+ fill_circle_input = gr.Image(label="Input Image")
28
+ fill_circle_output = gr.Image(label="Output Image")
29
+ submit_btn = gr.Button(value="Submit")
30
+ fill_circle_inputs = [prompt_input_fill_circle, negative_prompt_fill_circle, fill_circle_input]
31
+ submit_btn.click(fn=infer_fill_circle, inputs=fill_circle_inputs, outputs=[fill_circle_output])
32
+
33
+ demo.launch()
34