"""Gradio demo for CarveKit: background removal and trimap generation.

Two tabs are exposed:

* "Remove background" runs the full CarveKit ``Interface`` (TracerUniversalB7
  segmentation, then FBA matting driven by a generated trimap).
* "Generate trimap" runs only the segmentation network and converts its mask
  into a trimap with ``TrimapGenerator``.
"""
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Check doc strings for more information
seg_net = TracerUniversalB7(device=device, batch_size=1)
fba = FBAMatting(device=device, input_tensor_size=2048, batch_size=1)
trimap = TrimapGenerator()
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(matting_module=fba, trimap_generator=trimap, device=device)
interface = Interface(pre_pipe=preprocessing, post_pipe=postprocessing, seg_pipe=seg_net)


def _require_image(image):
    """Raise a user-facing Gradio error when no image was supplied.

    Gradio passes ``None`` to the handler if the button is clicked before an
    image is uploaded; without this check the models fail with an opaque error.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")


def generate_trimap(original):
    """Return a trimap for ``original`` built from the segmentation mask.

    Args:
        original: Input image as delivered by ``gr.Image(type="pil")``.

    Raises:
        gr.Error: If no image was supplied.
    """
    _require_image(original)
    # seg_net operates on a batch; wrap the single image and take its mask.
    mask = seg_net([original])
    return trimap(original_image=original, mask=mask[0])


def predict(image):
    """Return ``image`` with its background removed by the CarveKit pipeline.

    Args:
        image: Input image as delivered by ``gr.Image(type="pil")``.

    Raises:
        gr.Error: If no image was supplied.
    """
    _require_image(image)
    # Interface also works on a batch; send one image and unwrap the result.
    return interface([image])[0]


footer = r"""
<center>
<p>CarveKit</p>
<p>Demo based on CarveKit</p>
</center>
"""

with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<h1><center>CarveKit</center></h1>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs() as tabs:
        with gr.TabItem("Remove background", id=0):
            with gr.Row().style(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")
            run_btn.click(predict, [input_img], [output_img])

        with gr.TabItem("Generate trimap", id=1):
            with gr.Row().style(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")
            trimap_btn.click(generate_trimap, [trimap_input], [trimap_output])

    # TODO: re-enable example images once examples/01.jpg..03.jpg are shipped.
    # with gr.Row():
    #     examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
    #     examples = gr.Dataset(components=[input_img], samples=examples_data)
    #     examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

# Only start the web server when run as a script (e.g. `python app.py`),
# so importing this module does not block on a running server.
if __name__ == "__main__":
    app.launch(share=False, debug=True, enable_queue=True, show_error=True)