LKCell / app.py
qingke1's picture
remove load ckpt and img from model
96ec051
raw
history blame
4.69 kB
import gradio as gr
import os, requests
import numpy as np
import torch
import cv2
from cell_segmentation.inference.inference_cellvit_experiment_pannuke import InferenceCellViTParser,InferenceCellViT
from cell_segmentation.inference.inference_cellvit_experiment_monuseg import InferenceCellViTMoNuSegParser,MoNuSegInference
# Deployment mode, either "local" or "remote".
# "local" makes the launch code at the bottom of this file bind the server to
# 127.0.0.1:8003; any other value launches Gradio with its default settings.
RUN_MODE = "remote"
# Disabled: download of the model checkpoint and the example images
# (1.png .. 4.png) for non-local runs. Kept for reference only.
# if RUN_MODE != "local":
#     os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/model_best.pth")
#     ## examples
#     os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/1.png")
#     os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/2.png")
#     os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/3.png")
#     os.system("wget https://huggingface.co/xiazhi/LKCell-demo/resolve/main/4.png")
## step 1: set up model
# Weights are loaded onto, and the PanNuke model is moved to, the CPU.
device = "cpu"

## pannuke set
# Build the PanNuke inference helper from the parser's configuration.
pannuke_parser = InferenceCellViTParser()
pannuke_configurations = pannuke_parser.parse_arguments()
pannuke_inf = InferenceCellViT(
    **{
        option: pannuke_configurations[option]
        for option in ("run_dir", "checkpoint_name", "gpu", "magnification")
    }
)

# Restore the trained weights; the checkpoint also records which
# architecture ("arch") has to be instantiated.
_pannuke_checkpoint_path = pannuke_inf.run_dir / pannuke_inf.checkpoint_name
pannuke_checkpoint = torch.load(_pannuke_checkpoint_path, map_location="cpu")
pannuke_model = pannuke_inf.get_model(model_type=pannuke_checkpoint["arch"])
pannuke_model.load_state_dict(pannuke_checkpoint["model_state_dict"])

# Move the model to the target device and put it in eval mode.
pannuke_model.to(device)
pannuke_model.eval()

## monuseg set
# Constructor argument name -> key in the parsed MoNuSeg configuration.
_MONUSEG_ARGUMENT_KEYS = {
    "model_path": "model",
    "dataset_path": "dataset",
    "outdir": "outdir",
    "gpu": "gpu",
    "patching": "patching",
    "magnification": "magnification",
    "overlap": "overlap",
}
monuseg_parser = InferenceCellViTMoNuSegParser()
monuseg_configurations = monuseg_parser.parse_arguments()
monuseg_inf = MoNuSegInference(
    **{
        argument: monuseg_configurations[config_key]
        for argument, config_key in _MONUSEG_ARGUMENT_KEYS.items()
    }
)
def click_process(image_input, type_dataset):
    """Run nuclei segmentation on one image and return the rendered prediction.

    Args:
        image_input: Image handed over by the Gradio input component, or
            ``None`` when the user submitted without providing an image.
        type_dataset: Selected dataset type. ``"pannuke"`` uses the PanNuke
            pipeline; every other value (including ``None``) falls through to
            the MoNuSeg pipeline.

    Returns:
        The prediction image read from ``pred_img.png``, converted from
        OpenCV's BGR channel order to RGB.

    Raises:
        gr.Error: If no input image was provided, or if ``pred_img.png``
            cannot be read after inference.
    """
    # Fail early with a readable message instead of an opaque error from
    # deep inside the inference code.
    if image_input is None:
        raise gr.Error("Please upload an image or pick an example before submitting.")

    if type_dataset == "pannuke":
        pannuke_inf.run_single_image_inference(pannuke_model, image_input)
    else:
        monuseg_inf.run_single_image_inference(monuseg_inf.model, image_input)

    # NOTE(review): presumably run_single_image_inference writes its result
    # to pred_img.png in the working directory -- confirm. The fixed file name
    # is shared by all requests, so concurrent submissions could overwrite
    # each other's output.
    image_output = cv2.imread("pred_img.png")
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as a cvtColor assertion.
    if image_output is None:
        raise gr.Error("Inference finished but no prediction image (pred_img.png) could be read.")
    return cv2.cvtColor(image_output, cv2.COLOR_BGR2RGB)
# Gradio UI: input image + dataset selector on the left, prediction and
# action buttons on the right, followed by usage notes and examples.
demo = gr.Blocks(title="LkCell")
with demo:
    # Emoji below were previously mojibake (UTF-8 bytes decoded with a
    # single-byte code page); they are restored to the intended characters.
    gr.Markdown(value="""
**Gradio demo for LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels**. Check our [Github Repo](https://github.com/ziwei-cui/LKCellv1) 😛.
""")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            with gr.Row():
                Image_input = gr.Image(type="numpy", label="Input", interactive=True, height=480)
            with gr.Row():
                Type_dataset = gr.Radio(
                    choices=["pannuke", "monuseg"],
                    label=" input image's dataset type",
                    value="pannuke",
                )
        # Right column: output and controls.
        with gr.Column():
            with gr.Row():
                image_output = gr.Image(type="numpy", label="Output", height=480)
            with gr.Row():
                Button_run = gr.Button("🚀 Submit (发送) ")
                clear_button = gr.ClearButton(
                    components=[Image_input, Type_dataset, image_output],
                    value="🧹 Clear (清除)",
                )
    Button_run.click(fn=click_process, inputs=[Image_input, Type_dataset], outputs=[image_output])
    ## guideline
    gr.Markdown(value="""
🔔**Guideline**
1. Upload your image or select one from the examples.
2. Set up the arguments: "Type_dataset".
3. Run the Submit button to get the output.
""")
    # Example images are expected next to this script (1.png .. 4.png).
    gr.Examples(
        examples=[
            ['1.png', "pannuke"],
            ['2.png', "pannuke"],
            ['3.png', "monuseg"],
            ['4.png', "monuseg"],
        ],
        inputs=[Image_input, Type_dataset],
        outputs=[image_output],
        label="Examples",
    )
    gr.HTML(value="""
<p style="text-align:center; color:orange"> <a href='https://github.com/ziwei-cui/LKCellv1' target='_blank'>Github Repo</a></p>
""")
    gr.Markdown(value="""
Template is adapted from [Here](https://huggingface.co/spaces/menghanxia/disco)
""")
# Local runs pin the server to loopback on port 8003; any other mode
# launches with Gradio's default settings.
_launch_options = (
    {"server_name": '127.0.0.1', "server_port": 8003}
    if RUN_MODE == "local"
    else {}
)
demo.launch(**_launch_options)