"""Use spatial relations extracted from the parses.""" | |
from typing import Dict, Any, Callable, List, Tuple, NamedTuple | |
from numbers import Number | |
from collections import defaultdict | |
from overrides import overrides | |
import numpy as np | |
import spacy | |
from spacy.tokens.token import Token | |
from spacy.tokens.span import Span | |
from argparse import Namespace | |
from .ref_method import RefMethod | |
from lattice import Product as L | |
from heuristics import Heuristics | |
from entity_extraction import Entity, expand_chunks | |


def get_conjunct(ent: Entity, chunks: Dict[int, Span], heuristics: Heuristics) -> Entity:
    """If an entity represents a conjunction of two entities, pull the other conjunct out."""
    head = ent.head.root  # `ent.head.root`, not `...root.head` -- the attribute names are confusingly similar.
    if not any(child.text == "and" for child in head.children):
        return None
    for child in head.children:
        if child.i in chunks and head.i != child.i:
            return Entity.extract(child, chunks, heuristics)
    return None
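
# Illustrative example (hypothetical caption, not from the codebase): in
# "the cup between the laptop and the phone", the second relation argument
# is the conjunction "the laptop and the phone". Given the Entity for
# "the laptop", `get_conjunct` looks for an "and" among the children of its
# head token and extracts the other conjunct ("the phone") as a separate
# Entity, so a ternary relation like "between" can receive both arguments.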


class Parse(RefMethod):
    """A REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.

    The process is as follows:
      1. Use spacy to parse the document.
      2. Extract a semantic entity tree from the parse.
      3. Execute the entity tree to yield a distribution over boxes.
    """

    nlp = spacy.load('en_core_web_sm')

    def __init__(self, args: Namespace = None):
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.baseline_threshold = args.baseline_threshold
        self.temperature = args.temperature
        self.superlative_head_only = args.superlative_head_only
        self.expand_chunks = args.expand_chunks
        self.branch = not args.parse_no_branch
        self.possessive_expand = not args.possessive_no_expand
        # Lists of keyword heuristics to use.
        self.heuristics = Heuristics(args)
        # Metrics for debugging relation extraction behavior.
        self.counts = defaultdict(int)

    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
        # Start by using the full caption, as in Baseline.
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        ori_probs = probs

        # Extend the baseline using the dependency parse.
        doc = self.nlp(caption)
        head = self.get_head(doc)
        chunks = self.get_chunks(doc)
        if self.expand_chunks:
            chunks = expand_chunks(doc, chunks)
        entity = Entity.extract(head, chunks, self.heuristics)

        # If no head noun is found, take the first one.
        if entity is None and len(list(doc.noun_chunks)) > 0:
            head = list(doc.noun_chunks)[0]
            entity = Entity.extract(head.root.head, chunks, self.heuristics)
            self.counts["n_0th_noun"] += 1

        # If we have found some head noun, filter based on it.
        has_rel_or_sup = any(
            any(token.text in h.keywords for h in self.heuristics.relations + self.heuristics.superlatives)
            for token in doc
        )
        if entity is not None and (has_rel_or_sup or not self.branch):
            ent_probs, texts = self.execute_entity(entity, env, chunks)
            probs = L.meet(probs, ent_probs)
        else:
            texts = [caption]
            self.counts["n_full_expr"] += 1

        # With a single candidate box, fall back to the baseline distribution.
        if len(ori_probs) == 1:
            probs = ori_probs

        self.counts["n_total"] += 1
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
            "texts": texts,
        }
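    # A minimal usage sketch (hypothetical; the flag values below are
    # assumptions, and `env` must provide the `filter`/`boxes` interface
    # used above):
    #
    #   args = Namespace(box_area_threshold=0.0, baseline_threshold=2.0,
    #                    temperature=1.0, superlative_head_only=False,
    #                    expand_chunks=False, parse_no_branch=False,
    #                    possessive_no_expand=False, no_possessive=False,
    #                    sigmoid=False)
    #   method = Parse(args)
    #   result = method.execute("the dog to the left of the tree", env)
    #   result["pred"]   # index of the highest-probability box
    #   result["box"]    # the corresponding box from env.boxes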

    def execute_entity(self,
                       ent: Entity,
                       env: "Environment",
                       chunks: Dict[int, Span],
                       root: bool = True,
                       ) -> np.ndarray:
        """Execute an `Entity` tree recursively, yielding a distribution over boxes."""
        self.counts["n_rec"] += 1
        # Score the head noun phrase on its own first.
        probs = self._filter(ent.text, env, root=root)
        head_probs = probs

        # Only use relations if the head baseline isn't certain.
        if len(probs) == 1 or len(env.boxes) == 1:
            return probs, [ent.text]
        m1, m2 = probs[(-probs).argsort()[:2]]
        text = ent.text
        rel_probs = []
        if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
            self.counts["n_rec_rel"] += 1
            for tokens, ent2 in ent.relations:
                self.counts["n_rel"] += 1
                rel = None
                # Heuristically decide which spatial relation is represented.
                for heuristic in self.heuristics.relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                # Filter and normalize by the spatial relation.
                if rel is not None:
                    probs2 = self.execute_entity(ent2, env, chunks, root=False)
                    events = L.meet(np.expand_dims(probs2, axis=0), rel)
                    new_probs = L.join_reduce(events)
                    rel_probs.append((ent2.text, new_probs, probs2))
                    continue
                # This case specifically handles "between", which takes two noun arguments.
                rel = None
                for heuristic in self.heuristics.ternary_relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                if rel is not None:
                    ent3 = get_conjunct(ent2, chunks, self.heuristics)
                    if ent3 is not None:
                        probs2 = self.execute_entity(ent2, env, chunks, root=False)
                        probs2 = np.expand_dims(probs2, axis=[0, 2])
                        probs3 = self.execute_entity(ent3, env, chunks, root=False)
                        probs3 = np.expand_dims(probs3, axis=[0, 1])
                        events = L.meet(L.meet(probs2, probs3), rel)
                        new_probs = L.join_reduce(L.join_reduce(events))
                        probs = L.meet(probs, new_probs)
                    continue
                # Otherwise, treat the relation as a possessive relation.
                if not self.args.no_possessive:
                    if self.possessive_expand:
                        text = ent.expand(ent2.head)
                    else:
                        text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'

        #poss_probs = self._filter(text, env, root=root, expand=.3)
        probs = self._filter(text, env, root=root)
        texts = [text]
        return_probs = [(probs.tolist(), probs.tolist())]
        for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
            probs = L.meet(probs, new_probs)
            probs /= probs.sum()
            texts.append(ent2_text)
            return_probs.append((probs.tolist(), ent2_only_probs.tolist()))

        # Only use superlatives if thresholds work out.
        m1, m2 = probs[(-probs).argsort()[:2]]
        if m1 < self.baseline_threshold * m2:
            self.counts["n_rec_sup"] += 1
            for tokens in ent.superlatives:
                self.counts["n_sup"] += 1
                sup = None
                for heuristic in self.heuristics.superlatives:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        texts.append('sup:' + ' '.join(tok.text for tok in tokens if tok.text in heuristic.keywords))
                        sup = heuristic.callback(env)
                        self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
                        break
                if sup is not None:
                    # Could use `probs` or `head_probs` here?
                    precond = head_probs if self.superlative_head_only else probs
                    probs = L.meet(np.expand_dims(precond, axis=1) * np.expand_dims(precond, axis=0), sup).sum(axis=1)
                    probs = probs / probs.sum()
                    return_probs.append((probs.tolist(), None))

        if root:
            assert len(texts) == len(return_probs)
            return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
        return probs
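    # Standalone numeric sketch of the binary-relation composition above,
    # assuming `Product` lattice semantics in which `meet` is elementwise
    # multiplication and `join_reduce` reduces away the landmark axis (the
    # exact reduction lives in `lattice.py`; max is used here purely for
    # illustration). With two boxes:
    #
    #   rel = np.array([[0.0, 0.9],   # rel[i, j] ~ P(box i satisfies the
    #                   [0.1, 0.0]])  #   relation with respect to box j)
    #   probs2 = np.array([0.2, 0.8])             # landmark distribution
    #   events = np.expand_dims(probs2, 0) * rel   # meet: joint event scores
    #   new_probs = events.max(axis=-1)            # join_reduce over landmarks
    #   # new_probs[i] scores "box i stands in the relation to the landmark",
    #   # and is later combined with the head-noun scores via another meet.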

    def get_head(self, doc) -> Token:
        """Return the token that is the head of the dependency parse."""
        for token in doc:
            # In spaCy's dependency parse, the root token is its own head.
            if token.head.i == token.i:
                return token
        return None

    def get_chunks(self, doc) -> Dict[int, Any]:
        """Return a dictionary mapping token indices to the noun chunk containing them."""
        chunks = {}
        for chunk in doc.noun_chunks:
            for idx in range(chunk.start, chunk.end):
                chunks[idx] = chunk
        return chunks
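    # For a hypothetical caption "the dog near the tall tree", spacy finds the
    # noun chunks "the dog" (tokens 0-1) and "the tall tree" (tokens 3-5), so
    # the returned dict maps indices {0, 1} to the first chunk and {3, 4, 5}
    # to the second; token 2 ("near") stays unmapped and can anchor a relation.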

    def get_stats(self) -> Dict[str, Number]:
        """Summary statistics that have been tracked on this object."""
        stats = dict(self.counts)
        n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
        n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
        stats.update({
            "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
            "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
            "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
            "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
            "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
            "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
            "avg_rec": self.counts["n_rec"] / (self.counts["n_total"] + 1e-9),
        })
        return stats
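    # For example, if 10 relations were extracted (n_rel == 10) and 7 of them
    # matched some keyword heuristic (summed over the n_rel_* counters), then
    # p_rel_caught is ~0.7; the 1e-9 terms only guard against division by zero
    # before any captions have been processed.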

    def _filter(self,
                caption: str,
                env: "Environment",
                root: bool = False,
                expand: float = None,
                ) -> np.ndarray:
        """Wrap a filter call in a consistent way for all recursions."""
        kwargs = {
            "softmax": not self.args.sigmoid,
            "temperature": self.args.temperature,
        }
        if root:
            return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
        else:
            return env.filter(caption, **kwargs)
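    # Note on the two call paths (an observation, not documented behavior):
    # only the root call passes `area_threshold`, so implausibly small boxes
    # are pruned once for the whole recursion while recursive sub-entity calls
    # score the full box set; with `args.sigmoid` set, `softmax=False` means
    # boxes are presumably scored independently rather than normalized
    # against one another.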