# CountGD - Multimodal open-world object counting



## Setup

The following cells set up the runtime environment by performing these steps:

- Mount Google Drive
- Install dependencies for running the model
- Load the model into memory

### Mount Google Drive (if running on colab)

The following bit of code will mount your Google Drive folder at `/content/drive`, allowing you to process files directly from it as well as store the results alongside it.

Once you execute the next cell, you will be requested to share access with the notebook. Please follow the instructions on screen to do so.
If you are not running this on colab, you will still be able to use the files available on your environment.

In [1]:
# Check if running colab
import logging

# Emit timestamped, INFO-level log records for the rest of the notebook.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)-8s %(name)s %(message)s'
)

# Detect Colab by probing for its runtime package. Only ImportError is caught
# so that unrelated failures (or a KeyboardInterrupt) are not silently
# interpreted as "not running in Colab".
try:
    import google.colab  # noqa: F401  (imported only to test availability)
except ImportError:
    RUNNING_IN_COLAB = False
else:
    RUNNING_IN_COLAB = True

if RUNNING_IN_COLAB:
    from google.colab import drive

    # Make the user's Google Drive available under /content/drive.
    drive.mount('/content/drive')

from IPython.core.magic import register_cell_magic
from IPython import get_ipython
@register_cell_magic
def skip_if(line, cell):
    """Cell magic that runs the cell body unless the given expression is truthy.

    ``line`` (the text following the magic name) is evaluated as a Python
    expression with ``eval``. When the result is falsy, ``cell`` (the cell
    body) is executed through the active IPython shell; otherwise nothing runs.

    NOTE(review): ``eval`` executes arbitrary code. That is acceptable here only
    because the expression is typed by the notebook author, never external input.
    """
    should_skip = eval(line)
    if not should_skip:
        get_ipython().run_cell(cell)


# Export the flag as an environment variable so the `%%bash` install cell can branch on it.
%env RUNNING_IN_COLAB {RUNNING_IN_COLAB}


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
env: RUNNING_IN_COLAB=True


### Install Dependencies

The environment will be set up with the code, models and required dependencies.

In [3]:
%%bash

# -e: abort on the first failing command; -u: treat unset variables as errors;
# -x: echo each command before running it; pipefail: a pipeline fails if any stage fails.
set -euxo pipefail

# RUNNING_IN_COLAB is expected in the environment (the notebook exports it with %env).
# Because of `set -u`, the cell aborts here if the variable is unset.
if [ "${RUNNING_IN_COLAB}" == "True" ]; then
 echo "Downloading the repository..."
 # Clone only when the directory is missing, so re-running the cell reuses the checkout.
 if [ ! -d /content/countgd ]; then
 git clone "https://huggingface.co/spaces/nikigoli/countgd" /content/countgd
 fi
 cd /content/countgd
 # Fetch the remote ref refs/pr/5 into origin/pr/5, then check out branch pr/5.
 git fetch origin refs/pr/5:refs/remotes/origin/pr/5
 git checkout pr/5
else
 # TODO check if cwd is the correct git repo
 # If users use vscode, then we set the default start directory to root of the repo
 echo "Running in $(pwd)"
fi

# TODO check for gcc-11 or above

# Install pip packages
# requirements.txt is resolved relative to the current directory
# (/content/countgd on Colab, the starting directory otherwise).
pip install --upgrade pip setuptools wheel
pip install -r requirements.txt

# Compile modules
# CUDA_HOME is hard-coded and exported for the build below
# (presumably read by setup.py to locate the CUDA toolkit - confirm).
export CUDA_HOME=/usr/local/cuda/
cd models/GroundingDINO/ops
python3 setup.py build
pip install .
# Run the ops directory's test.py after installing
# (presumably a sanity check of the compiled extension - confirm).
python3 test.py

Downloading the repository...
Branch 'pr/5' set up to track remote branch 'pr/5' from 'origin'.
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu121
Collecting filetype (from -r requirements.txt (line 15))
 Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Downloading filetype-1.2.0-py2.py3-none-any.whl (19 kB)
Installing collected packages: filetype
Successfully installed filetype-1.2.0
inside get_extensions
/usr/local/cuda/
running build
running build_py
copying modules/ms_deform_attn.py -> build/lib.linux-x86_64-cpython-311/modules
copying modules/__init__.py -> build/lib.linux-x86_64-cpython-311/modules
copying functions/__init__.py -> build/lib.linux-x86_64-cpython-311/functions
copying functions/ms_deform_attn_func.py -> build/lib.linux-x86_64-cpython-311/functions
running build_ext
Processing /content/countgd/models/GroundingDINO/ops
 Preparing metadata (setup.py): started
 Preparing metadata (setup.py): finished with status 'done'


+ '[' True == True ']'
+ echo 'Downloading the repository...'
+ '[' '!' -d /content/countgd ']'
+ cd /content/countgd
+ git fetch origin refs/pr/5:refs/remotes/origin/pr/5
From https://huggingface.co/spaces/nikigoli/countgd
 * [new ref] refs/pr/5 -> origin/pr/5
+ git checkout pr/5
Switched to a new branch 'pr/5'
+ pip install --upgrade pip setuptools wheel
+ pip install -r requirements.txt
+ export CUDA_HOME=/usr/local/cuda/
+ CUDA_HOME=/usr/local/cuda/
+ cd models/GroundingDINO/ops
+ python3 setup.py build
+ pip install .
+ python3 test.py


In [4]:
# Work from the cloned repository on Colab and leave the directory unchanged elsewhere
# (presumably so that `import app` and relative paths resolve against the repo - confirm).
%cd {"/content/countgd" if RUNNING_IN_COLAB else '.'}

/content/countgd


## Inference

### Loading the model

In [11]:
import app
import importlib
importlib.reload(app)
from app import (
 build_model_and_transforms,
 get_device,
 get_args_parser,
 generate_heatmap,
 predict,
)
# Build the model from the parser's defaults (an empty argv is parsed on purpose,
# so no notebook/kernel command-line arguments leak in) and move it to the device.
args = get_args_parser().parse_args([])
device = get_device()
model, transform = build_model_and_transforms(args)
model = model.to(device)


def run(image, text):
    """Call ``predict`` on one image with a text prompt using the loaded model.

    The fifth positional argument of ``predict`` is always passed as ``None``
    (NOTE(review): presumably the optional visual exemplars - confirm in app.py).
    Returns whatever ``predict`` returns.
    """
    return predict(model, transform, image, text, None, device)


def get_output(image, boxes):
    """Return ``(number of boxes, heatmap generated from the image and boxes)``."""
    return (len(boxes), generate_heatmap(image, boxes))


Some weights of BertModel were not initialized from the model checkpoint at checkpoints/bert-base-uncased and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


final text_encoder_type: checkpoints/bert-base-uncased
load tokenizer done.
final text_encoder_type: checkpoints/bert-base-uncased
load tokenizer done.


### Input / Output Utils

Helper functions for reading / writing to zipfiles and csv

In [17]:
import io
import csv
from pathlib import Path
from contextlib import contextmanager
import zipfile
import filetype
from PIL import Image
logger = logging.getLogger()

def images_from_zipfile(p: Path):
    """Lazily yield ``(filename, image)`` pairs for the readable images in a zip.

    Args:
        p: Path of the zip archive to read.

    Yields:
        ``(name, image)`` tuples, where ``name`` is the entry's name inside the
        archive and ``image`` is the loaded PIL image. Directory entries,
        non-image files and files that fail to decode are skipped, so every
        yielded value can safely be unpacked into two names.

    Raises:
        ValueError: If ``p`` is not a zip archive. Because this is a generator,
            the error is raised on the first iteration, not at call time.
    """
    if not zipfile.is_zipfile(p):
        raise ValueError(f'{p} is not a zipfile!')

    with zipfile.ZipFile(p, 'r') as zipf:
        def process_entry(info: zipfile.ZipInfo):
            """Return ``(filename, image)`` for one entry, or ``None`` to skip it."""
            with zipf.open(info) as f:
                if not filetype.is_image(f):
                    logger.debug(f'Skipping file - {info.filename} as it is not an image')
                    return None
                # Try loading the file
                try:
                    with Image.open(f) as im:
                        # Read the pixel data now, while the archive member is still open.
                        im.load()
                        return (info.filename, im)
                except Exception:
                    # Log with traceback and skip: one bad file must not abort the batch.
                    logger.exception(f'Error reading file {info.filename}')
                    return None

        # Collect the file entries once; the list feeds both the count and the loop.
        file_infos = [info for info in zipf.infolist() if not info.is_dir()]
        logger.info(f'Found {len(file_infos)} file(s) in the zip')
        for info in file_infos:
            entry = process_entry(info)
            # Skipped entries are dropped here; yielding None would break consumers
            # that unpack each item as ``filename, image``.
            if entry is not None:
                yield entry

@contextmanager
def zipfile_writer(p: Path):
    """Context manager that opens a new zip at ``p`` and yields an image writer.

    The yielded callable takes ``(image, image_filename)``, serialises ``image``
    with ``image.save(..., 'PNG')`` into memory and stores the bytes in the
    archive under ``image_filename``. The archive is closed when the ``with``
    block exits.
    """
    archive = zipfile.ZipFile(p, 'w')
    with archive:
        def store_image(image, image_filename):
            # Encode in memory first; the zip entry is then written in one call.
            png_buffer = io.BytesIO()
            image.save(png_buffer, 'PNG')
            png_bytes = png_buffer.getvalue()
            archive.writestr(image_filename, png_bytes)

        yield store_image

@contextmanager
def csvfile_writer(p: Path):
    """Context manager that creates a CSV at ``p`` and yields a row writer.

    The header row ``filename,count`` is written immediately. The yielded
    callable accepts one dict with the keys ``'filename'`` and ``'count'`` and
    appends it as a row. The file is closed when the ``with`` block exits.

    Args:
        p: Path of the CSV file to create (an existing file is overwritten).
    """
    # newline='' leaves line endings to the csv module. The encoding is pinned
    # to UTF-8 so that non-ASCII image names are written the same way on every
    # platform instead of depending on the locale's default encoding.
    with p.open('w', newline='', encoding='utf-8') as csvfile:
        fieldnames = ['filename', 'count']
        csv_writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        csv_writer.writeheader()

        yield csv_writer.writerow

In [15]:
from tqdm import tqdm
import os
def process_zipfile(input_zipfile: Path, text: str):
    """Count the objects named by ``text`` in every image of a zip file.

    For each image, the model is run through ``run`` and the count and heatmap
    are obtained from ``get_output``. Results are written next to the input:
    ``<stem>_countgd.zip`` receives the heatmaps (stored under the original
    entry names) and ``<stem>.csv`` receives one ``filename,count`` row per image.

    Invalid input is reported through the logger and the function returns
    without creating any output file.

    Args:
        input_zipfile: Path to the zip archive containing the input images.
        text: Prompt naming the object to count; must not be empty.
    """
    # is_file() is already False for a missing path, so no separate exists() check.
    if not input_zipfile.is_file() or not os.access(input_zipfile, os.R_OK):
        logger.error(f'Cannot open / read zipfile: {input_zipfile}. Please check if it exists')
        return

    # Validate the archive before the output files are opened; otherwise a
    # non-zip input would leave empty output files behind and raise mid-way.
    if not zipfile.is_zipfile(input_zipfile):
        logger.error(f'{input_zipfile} is not a zipfile!')
        return

    if not text:
        logger.error('Please provide the object you would like to count')
        return

    output_zipfile = input_zipfile.parent / f'{input_zipfile.stem}_countgd.zip'
    output_csvfile = input_zipfile.parent / f'{input_zipfile.stem}.csv'

    logger.info(f'Writing outputs to {output_zipfile.name} and {output_csvfile.name} in {input_zipfile.parent} folder')
    with zipfile_writer(output_zipfile) as add_to_zip, csvfile_writer(output_csvfile) as write_row:
        for entry in tqdm(images_from_zipfile(input_zipfile)):
            # Defensive: the reader may hand back None for an entry it skipped
            # (not an image, or unreadable). Unpacking None would raise a
            # TypeError and abort the whole batch.
            if entry is None:
                continue
            filename, im = entry
            boxes, _ = run(im, text)
            count, heatmap = get_output(im, boxes)
            write_row({'filename': filename, 'count': count})
            add_to_zip(heatmap, filename)

### Run

Use the form on Colab to set the parameters: the path to the zip file containing the input images, and a prompt text naming the object you want to count.

If you are not running on Colab, edit the values directly in the next cell.

Make sure to run the cell once you change the value.

In [8]:
# @title ## Parameters { display-mode: "form", run: "auto" }
# @markdown Set the following options to pass to the CountGD Model

# @markdown ---
# @markdown ### Enter a file path to a zip:
# Both values are read by the Run button's callback when it is clicked; a relative
# zip path is resolved against the notebook's current working directory.
zipfile_path = "test_images.zip" # @param {type:"string"}
# @markdown
# @markdown ### Which object would you like to count?
prompt = "strawberry" # @param {type:"string"}
# @markdown ---

In [18]:
import ipywidgets as widgets
from IPython.display import display
def on_button_clicked(_clicked_button):
    """Click handler: process the zip file with the current parameter values.

    ``zipfile_path`` and ``prompt`` are module-level names looked up at click
    time, so the values set most recently in the parameters cell are used.
    The widget passed in by ipywidgets is not used. No Output widget is
    involved; whatever ``process_zipfile`` logs is emitted through the normal
    logging setup.
    """
    process_zipfile(Path(zipfile_path), prompt)


# Create the button, attach the handler, then render it in the cell output.
button = widgets.Button(description="Run")
button.on_click(on_button_clicked)
display(button)

Button(description='Run', style=ButtonStyle())

11it [00:12, 1.14s/it]
