hpc-yekin
initial commit
92e0882
raw
history blame
1.93 kB
"""A naive baseline method: just pass the full expression to CLIP."""
from overrides import overrides
from typing import Dict, Any, List
import numpy as np
import torch
import spacy
from argparse import Namespace
from .ref_method import RefMethod
from lattice import Product as L
class Baseline(RefMethod):
    """CLIP-only baseline where each box is evaluated with the full expression.

    When ``args.baseline_head`` is set, the per-box scores for the full
    caption are additionally combined (via ``L.meet``) with the per-box
    scores for the caption's head noun chunk.
    """

    # spaCy English pipeline shared by every instance. It is loaded once, when
    # the class body executes (i.e. when this module is imported).
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace):
        """Store the options this method reads from ``args``.

        Args:
            args: parsed options. ``box_area_threshold`` and ``batch_size``
                are read here; ``baseline_head`` is read in ``execute``.
        """
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.batch_size = args.batch_size
        # NOTE(review): ``batch_size`` and ``batch`` are not used inside this
        # class; presumably kept for parity with other RefMethod subclasses.
        self.batch = []

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Score every box in ``env`` against ``caption`` and pick the best one.

        Args:
            caption: the referring expression.
            env: environment exposing ``filter`` (per-box scores for a text)
                and ``boxes``.

        Returns:
            dict with ``"probs"`` (per-box scores), ``"pred"`` (``np.argmax``
            of those scores) and ``"box"`` (``env.boxes[pred]``).
        """
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        if self.args.baseline_head:
            # Parse the caption only when the head chunk is actually needed;
            # otherwise the spaCy pass is wasted work on every call.
            head_text = self.get_chunk_texts(caption)[0]
            probs2 = env.filter(head_text, area_threshold=self.box_area_threshold, softmax=True)
            probs = L.meet(probs, probs2)
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
        }

    def get_chunk_texts(self, expression: str) -> List[str]:
        """Return the noun-chunk texts of ``expression``, head chunk first.

        The head token is the first token that is its own syntactic head
        (spaCy's convention for a sentence root). The head chunk is the noun
        chunk containing that token. If no chunk contains it, the first noun
        chunk is used instead; if there are no noun chunks at all, the whole
        ``expression`` is used. The remaining chunk texts follow in document
        order, with any text equal to the head chunk dropped.

        Returns:
            list of str whose first element is the head chunk (never empty).
        """
        doc = self.nlp(expression)
        # First token that is its own head, or None for an empty doc.
        head = next((token for token in doc if token.head.i == token.i), None)
        # ``doc.noun_chunks`` is a generator; materialize it once instead of
        # re-running the chunker for every use below.
        chunks = list(doc.noun_chunks)
        chunk_texts = [chunk.text for chunk in chunks]
        head_chunk = None
        if head is not None:
            for chunk in chunks:
                if chunk.start <= head.i < chunk.end:
                    head_chunk = chunk.text
        if head_chunk is None:
            # Fall back to the first noun chunk, else the raw expression.
            head_chunk = chunk_texts[0] if chunk_texts else expression
        return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk]