# Scraped file-viewer header (not part of the module):
# hpc-yekin · "initial commit" · 92e0882 · raw · history blame · 10.7 kB
"""Use spatial relations extracted from the parses."""
from typing import Dict, Any, Callable, List, Tuple, NamedTuple
from numbers import Number
from collections import defaultdict
from overrides import overrides
import numpy as np
import spacy
from spacy.tokens.token import Token
from spacy.tokens.span import Span
from argparse import Namespace
from .ref_method import RefMethod
from lattice import Product as L
from heuristics import Heuristics
from entity_extraction import Entity, expand_chunks
def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity:
    """If an entity represents a conjunction of two entities, pull out the other one.

    The head token of ``ent`` is inspected: when one of its children is the word
    "and", the first child that lies inside a noun chunk (and is not the head
    itself) is extracted as a separate `Entity`.

    Args:
        ent: Entity whose ``head.root`` token is examined.
        chunks: Container of token indices that belong to noun chunks; only
            membership (``in``) is used here before it is forwarded.
        heuristics: Forwarded unchanged to `Entity.extract`.

    Returns:
        Whatever `Entity.extract` returns for the first qualifying child, or
        ``None`` when the head has no "and" child or no child qualifies.
    """
    head = ent.head.root  # Not ...root.head. Confusing names here.
    # Materialize once: the children are scanned twice below.
    children = list(head.children)

    if all(child.text != "and" for child in children):
        return None

    for child in children:
        # Compare token indices by value. The original used `is not`, which
        # tests object identity and is not a reliable inequality check on ints.
        if child.i in chunks and child.i != head.i:
            return Entity.extract(child, chunks, heuristics)
    return None
class Parse(RefMethod):
    """An REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.

    The process is as follows:
        1. Use spacy to parse the document.
        2. Extract a semantic entity tree from the parse.
        3. Execute the entity tree to yield a distribution over boxes.

    NOTE(review): the block structure of this class was reconstructed from a
    source whose indentation had been lost; the nesting inside
    `execute_entity` in particular should be confirmed against the upstream
    file.
    """

    # spaCy pipeline shared by all instances; loaded when the class body runs.
    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace = None):
        """Copy the relevant options off ``args`` and set up heuristics and counters.

        Args:
            args: Parsed command-line options. Despite the ``None`` default it
                is required in practice: its attributes are read immediately.
        """
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.baseline_threshold = args.baseline_threshold
        self.temperature = args.temperature
        self.superlative_head_only = args.superlative_head_only
        self.expand_chunks = args.expand_chunks
        # The CLI flags are negative ("no_..."); store the positive sense.
        self.branch = not args.parse_no_branch
        self.possessive_expand = not args.possessive_no_expand

        # Lists of keyword heuristics to use.
        self.heuristics = Heuristics(args)

        # Metrics for debugging relation extraction behavior.
        self.counts = defaultdict(int)

    def _mentions_keyword(self, doc) -> bool:
        """Return True if any token's text is a relation or superlative keyword."""
        # Build the combined heuristic list once rather than once per token.
        heuristics = self.heuristics.relations + self.heuristics.superlatives
        return any(token.text in h.keywords for token in doc for h in heuristics)

    @overrides
    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes.

        Args:
            caption: The referring expression to resolve.
            env: Environment providing ``filter`` and ``boxes``.

        Returns:
            A dict with keys ``"probs"`` (final scores), ``"pred"`` (argmax
            index), ``"box"`` (``env.boxes[pred]``) and ``"texts"`` (the text
            info returned by `execute_entity`, or ``[caption]`` when the parse
            was not used).
        """
        # Start by using the full caption, as in Baseline.
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        ori_probs = probs

        # Extend the baseline using parse stuff.
        doc = self.nlp(caption)
        head = self.get_head(doc)
        chunks = self.get_chunks(doc)
        if self.expand_chunks:
            chunks = expand_chunks(doc, chunks)
        entity = Entity.extract(head, chunks, self.heuristics)

        # If no head noun is found, take the first one.
        if entity is None:
            noun_chunks = list(doc.noun_chunks)
            if noun_chunks:
                head = noun_chunks[0]
                entity = Entity.extract(head.root.head, chunks, self.heuristics)
                self.counts["n_0th_noun"] += 1

        # If we have found some head noun, filter based on it. When branching
        # is enabled, the parse is only used if the caption actually contains
        # a relation or superlative keyword.
        if entity is not None and (not self.branch or self._mentions_keyword(doc)):
            ent_probs, texts = self.execute_entity(entity, env, chunks)
            probs = L.meet(probs, ent_probs)
        else:
            texts = [caption]
            self.counts["n_full_expr"] += 1

        # With a single score there is nothing to disambiguate; keep the baseline.
        if len(ori_probs) == 1:
            probs = ori_probs

        self.counts["n_total"] += 1
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
            "texts": texts,
        }

    def execute_entity(self,
                       ent: Entity,
                       env: "Environment",
                       chunks: Dict[int, Span],
                       root: bool = True,
                       ) -> np.ndarray:
        """Execute an `Entity` tree recursively, yielding a distribution over boxes.

        Args:
            ent: Entity to execute; its ``relations`` and ``superlatives`` are
                processed and related entities are executed recursively.
            env: Environment providing ``boxes`` and ``filter``.
            chunks: Noun-chunk lookup, forwarded to `get_conjunct` and reported
                in the root result.
            root: True for the top-level call only.

        Returns:
            For ``root=False``: the score array. For ``root=True``: a pair
            ``(probs, (texts, return_probs, chunk_strings))``. The early exit
            for a single box returns ``(probs, [ent.text])`` regardless of
            ``root``. (The ``np.ndarray`` annotation only describes the
            non-root case.)
        """
        self.counts["n_rec"] += 1
        # NOTE(review): placeholder scores rather than a real head filter call;
        # `head_probs` therefore stays `[1, 1]` for the whole call. Confirm
        # this is intended, since `superlative_head_only` reads it below.
        probs = [1, 1]
        head_probs = probs

        # Only use relations if the head baseline isn't certain.
        if len(probs) == 1 or len(env.boxes) == 1:
            return probs, [ent.text]

        m1, m2 = probs[:2]  # probs[(-probs).argsort()[:2]]
        text = ent.text
        rel_probs = []
        if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
            self.counts["n_rec_rel"] += 1
            for tokens, ent2 in ent.relations:
                self.counts["n_rel"] += 1
                rel = None
                # Heuristically decide which spatial relation is represented.
                for heuristic in self.heuristics.relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break

                # Filter and normalize by the spatial relation.
                if rel is not None:
                    probs2 = self.execute_entity(ent2, env, chunks, root=False)
                    events = L.meet(np.expand_dims(probs2, axis=0), rel)
                    new_probs = L.join_reduce(events)
                    rel_probs.append((ent2.text, new_probs, probs2))
                    continue

                # This case specifically handles "between", which takes two noun arguments.
                rel = None
                for heuristic in self.heuristics.ternary_relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                if rel is not None:
                    ent3 = get_conjunct(ent2, chunks, self.heuristics)
                    if ent3 is not None:
                        probs2 = self.execute_entity(ent2, env, chunks, root=False)
                        probs2 = np.expand_dims(probs2, axis=[0, 2])
                        probs3 = self.execute_entity(ent3, env, chunks, root=False)
                        probs3 = np.expand_dims(probs3, axis=[0, 1])
                        events = L.meet(L.meet(probs2, probs3), rel)
                        new_probs = L.join_reduce(L.join_reduce(events))
                        # NOTE(review): with the nesting reconstructed here,
                        # this value is replaced by the `_filter` call after
                        # the loop, so the ternary result looks unused.
                        probs = L.meet(probs, new_probs)
                    continue

                # Otherwise, treat the relation as a possessive relation.
                if not self.args.no_possessive:
                    if self.possessive_expand:
                        text = ent.expand(ent2.head)
                    else:
                        text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'

        # Score the (possibly possessive-expanded) text of this entity.
        probs = self._filter(text, env, root=root)
        texts = [text]
        return_probs = [(probs.tolist(), probs.tolist())]

        # Fold in each binary relation collected above, renormalizing each time.
        for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
            probs = L.meet(probs, new_probs)
            probs /= probs.sum()
            texts.append(ent2_text)
            return_probs.append((probs.tolist(), ent2_only_probs.tolist()))

        # Only use superlatives if thresholds work out.
        m1, m2 = probs[(-probs).argsort()[:2]]  # two largest scores, descending
        if m1 < self.baseline_threshold * m2:
            self.counts["n_rec_sup"] += 1
            for tokens in ent.superlatives:
                self.counts["n_sup"] += 1
                sup = None
                for heuristic_index, heuristic in enumerate(self.heuristics.superlatives):
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords]))
                        sup = heuristic.callback(env)
                        self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
                        break
                if sup is not None:
                    # Could use `probs` or `head_probs` here?
                    precond = head_probs if self.superlative_head_only else probs
                    # Outer product of the precondition with itself, combined
                    # with the superlative term, then summed over axis 1.
                    probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1)
                    probs = probs / probs.sum()
                    return_probs.append((probs.tolist(), None))

        if root:
            # One entry in `texts` per entry in `return_probs`.
            assert len(texts) == len(return_probs)
            return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
        return probs

    def get_head(self, doc) -> Token:
        """Return the token that is the head of the dependency parse.

        This is the first token that is its own head; ``None`` is returned if
        no such token exists.
        """
        for token in doc:
            if token.head.i == token.i:
                return token
        return None

    def get_chunks(self, doc) -> Dict[int, Any]:
        """Return a dictionary mapping token indices to the noun chunk containing them."""
        chunks = {}
        for chunk in doc.noun_chunks:
            for idx in range(chunk.start, chunk.end):
                chunks[idx] = chunk
        return chunks

    @overrides
    def get_stats(self) -> Dict[str, Number]:
        """Summary statistics that have been tracked on this object.

        Returns:
            A copy of the raw counters plus derived ratios. Every denominator
            carries a small epsilon so the call is safe before any example has
            been executed.
        """
        eps = 1e-9
        stats = dict(self.counts)
        n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
        n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
        stats.update({
            "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + eps),
            "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + eps),
            "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + eps),
            "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + eps),
            "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + eps),
            "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + eps),
            # Previously unguarded: raised ZeroDivisionError when n_total == 0.
            "avg_rec": self.counts["n_rec"] / (self.counts["n_total"] + eps),
        })
        return stats

    def _filter(self,
                caption: str,
                env: "Environment",
                root: bool = False,
                expand: float = None,
                ) -> np.ndarray:
        """Wrap a filter call in a consistent way for all recursions.

        Args:
            caption: Text to score against the boxes.
            env: Environment whose ``filter`` method is called.
            root: When True, the box area threshold is also passed along.
            expand: Currently unused; kept for interface compatibility.

        Returns:
            Whatever ``env.filter`` returns.
        """
        kwargs = {
            "softmax": not self.args.sigmoid,
            "temperature": self.args.temperature,
        }
        if root:
            return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
        return env.filter(caption, **kwargs)