Spaces:
Running
on
Zero
Running
on
Zero
"""A naive baseline method: just pass the full expression to CLIP.""" | |
from overrides import overrides | |
from typing import Dict, Any, List | |
import numpy as np | |
import torch | |
import spacy | |
from argparse import Namespace | |
from .ref_method import RefMethod | |
from lattice import Product as L | |
class Baseline(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression.

    When ``args.baseline_head`` is truthy, the per-box scores for the full
    expression are additionally combined (via ``L.meet``) with the per-box
    scores for the expression's head noun chunk.
    """

    # spaCy English pipeline, loaded once when the class body is executed and
    # shared by every instance; used to find the root token and noun chunks.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace):
        """Store configuration read from ``args``.

        Args:
            args: Namespace providing at least ``box_area_threshold``,
                ``batch_size`` and (read in ``execute``) ``baseline_head``.
        """
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.batch_size = args.batch_size
        # Not read anywhere in this class; presumably used by other code —
        # verify before removing.
        self.batch = []

    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Score every candidate box against ``caption`` and pick the best one.

        Args:
            caption: The referring expression.
            env: Environment whose ``filter`` returns per-box scores for a text
                and whose ``boxes`` holds the candidate boxes.

        Returns:
            A dict with ``"probs"`` (the per-box scores), ``"pred"`` (the
            ``np.argmax`` index of those scores) and ``"box"``
            (``env.boxes[pred]``).
        """
        probs = env.filter(
            caption, area_threshold=self.box_area_threshold, softmax=True
        )
        if self.args.baseline_head:
            # Parse the caption only when the head chunk is actually used; the
            # spaCy parse used to run on every call even when discarded.
            head_chunk = self.get_chunk_texts(caption)[0]
            probs2 = env.filter(
                head_chunk, area_threshold=self.box_area_threshold, softmax=True
            )
            probs = L.meet(probs, probs2)
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
        }

    def get_chunk_texts(self, expression: str) -> List[str]:
        """Return the noun-chunk texts of ``expression``, head chunk first.

        The head chunk is the noun chunk that contains the first token which
        is its own head (in spaCy, a sentence root). If no chunk contains that
        token, the first noun chunk is used instead. If there are no noun
        chunks at all, the whole expression is used.

        Args:
            expression: Text to parse with ``self.nlp``.

        Returns:
            ``[head_chunk]`` followed by the texts of the remaining noun chunks
            in document order, with every chunk whose text equals
            ``head_chunk`` removed.
        """
        doc = self.nlp(expression)
        # First token that is its own head; None only if the doc has no such
        # token (e.g. an empty expression).
        head = next((token for token in doc if token.head.i == token.i), None)

        # doc.noun_chunks is a generator; materialize it once rather than
        # re-running the chunker for every use below.
        chunks = list(doc.noun_chunks)
        chunk_texts = [chunk.text for chunk in chunks]

        head_chunk = None
        if head is not None:
            for chunk in chunks:
                if chunk.start <= head.i < chunk.end:
                    head_chunk = chunk.text

        if head_chunk is None:
            head_chunk = chunk_texts[0] if chunk_texts else expression

        return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk]