| """Use spatial relations extracted from the parses.""" | |
| from typing import Dict, Any, Callable, List, Tuple, NamedTuple | |
| from numbers import Number | |
| from collections import defaultdict | |
| from overrides import overrides | |
| import numpy as np | |
| import spacy | |
| from spacy.tokens.token import Token | |
| from spacy.tokens.span import Span | |
| from argparse import Namespace | |
| from .ref_method import RefMethod | |
| from lattice import Product as L | |
| from heuristics import Heuristics | |
| from entity_extraction import Entity, expand_chunks | |


def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity:
    """If an entity represents a conjunction of two entities, pull them apart."""
    head = ent.head.root  # Not ...root.head. Confusing names here.
    if not any(child.text == "and" for child in head.children):
        return None
    for child in head.children:
        if child.i in chunks and head.i != child.i:
            return Entity.extract(child, chunks, heuristics)
    return None
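
# Illustrative example (a sketch only; the exact behavior depends on the spacy
# parse and the `Entity` implementation): for "the book between the lamp and
# the plant", applying `get_conjunct` to the entity for "the lamp" finds the
# "and" child of its head and returns the coordinated chunk "the plant",
# which supplies the second landmark of the ternary "between" relation.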


class Parse(RefMethod):
    """A REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.

    The process is as follows:
        1. Use spacy to parse the document.
        2. Extract a semantic entity tree from the parse.
        3. Execute the entity tree to yield a distribution over boxes.
    """

    nlp = spacy.load('en_core_web_sm')
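    # The spacy model must be available locally, e.g. installed via
    # `python -m spacy download en_core_web_sm`.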

    def __init__(self, args: Namespace = None):
        self.args = args
        self.box_area_threshold = args.box_area_threshold
        self.baseline_threshold = args.baseline_threshold
        self.temperature = args.temperature
        self.superlative_head_only = args.superlative_head_only
        self.expand_chunks = args.expand_chunks
        self.branch = not args.parse_no_branch
        self.possessive_expand = not args.possessive_no_expand
        # Lists of keyword heuristics to use.
        self.heuristics = Heuristics(args)
        # Metrics for debugging relation extraction behavior.
        self.counts = defaultdict(int)
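
    # The `args` namespace must carry at least the flags read above, plus
    # `no_possessive` and `sigmoid` (used further down). A minimal sketch,
    # with guessed values (the flag names come from this file, the defaults
    # are assumptions):
    #
    #     args = Namespace(box_area_threshold=0.0, baseline_threshold=float("inf"),
    #                      temperature=1.0, superlative_head_only=False,
    #                      expand_chunks=False, parse_no_branch=False,
    #                      possessive_no_expand=False, no_possessive=False,
    #                      sigmoid=False)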

    def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
        """Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
        # Start by using the full caption, as in Baseline.
        probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
        ori_probs = probs

        # Extend the baseline using the parse.
        doc = self.nlp(caption)
        head = self.get_head(doc)
        chunks = self.get_chunks(doc)
        if self.expand_chunks:
            chunks = expand_chunks(doc, chunks)
        entity = Entity.extract(head, chunks, self.heuristics)

        # If no head noun is found, take the first one.
        if entity is None and len(list(doc.noun_chunks)) > 0:
            head = list(doc.noun_chunks)[0]
            entity = Entity.extract(head.root.head, chunks, self.heuristics)
            self.counts["n_0th_noun"] += 1

        # If we have found some head noun, filter based on it.
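        # (When `parse_no_branch` is set, `self.branch` is False and the entity
        # tree is always executed; otherwise it is only executed when some token
        # in the caption matches a relation or superlative keyword.)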
        if entity is not None and (not self.branch or any(
                any(token.text in h.keywords
                    for h in self.heuristics.relations + self.heuristics.superlatives)
                for token in doc)):
            ent_probs, texts = self.execute_entity(entity, env, chunks)
            probs = L.meet(probs, ent_probs)
        else:
            texts = [caption]
            self.counts["n_full_expr"] += 1

        if len(ori_probs) == 1:
            probs = ori_probs
        self.counts["n_total"] += 1
        pred = np.argmax(probs)
        return {
            "probs": probs,
            "pred": pred,
            "box": env.boxes[pred],
            "texts": texts,
        }

    def execute_entity(self,
                       ent: Entity,
                       env: "Environment",
                       chunks: Dict[int, Span],
                       root: bool = True,
                       ) -> np.ndarray:
        """Execute an `Entity` tree recursively, yielding a distribution over boxes (plus, at the root, the texts that were scored)."""
        self.counts["n_rec"] += 1
        probs = [1, 1]  # Placeholder scores; the head text is only filtered below, after possessive expansion.
        head_probs = probs

        # Only use relations if the head baseline isn't certain.
        if len(probs) == 1 or len(env.boxes) == 1:
            return (probs, [ent.text]) if root else probs
        m1, m2 = probs[:2]  # probs[(-probs).argsort()[:2]]
        text = ent.text
        rel_probs = []
        if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
            self.counts["n_rec_rel"] += 1
            for tokens, ent2 in ent.relations:
                self.counts["n_rel"] += 1
                rel = None
                # Heuristically decide which spatial relation is represented.
                for heuristic in self.heuristics.relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
                # Filter and normalize by the spatial relation.
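                # (Shape sketch, assuming `Product.meet` is pointwise and
                # `join_reduce` marginalizes the last axis: `rel` is an
                # (n_boxes, n_boxes) matrix scoring "box i <relation> box j",
                # `probs2` is broadcast along the landmark axis, and the
                # reduction sums the landmark box out.)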
                if rel is not None:
                    probs2 = self.execute_entity(ent2, env, chunks, root=False)
                    events = L.meet(np.expand_dims(probs2, axis=0), rel)
                    new_probs = L.join_reduce(events)
                    rel_probs.append((ent2.text, new_probs, probs2))
                    continue
                # This case specifically handles "between", which takes two noun arguments.
                rel = None
                for heuristic in self.heuristics.ternary_relations:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        rel = heuristic.callback(env)
                        self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
                        break
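                # (Sketch of the intended shapes: the ternary `rel` tensor pairs
                # a target box with two landmark boxes, so the two landmark
                # distributions are expanded to orthogonal axes and join-reduced
                # twice to sum both landmarks out. The exact axis convention
                # lives in the heuristic callbacks.)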
                if rel is not None:
                    ent3 = get_conjunct(ent2, chunks, self.heuristics)
                    if ent3 is not None:
                        probs2 = self.execute_entity(ent2, env, chunks, root=False)
                        probs2 = np.expand_dims(probs2, axis=[0, 2])
                        probs3 = self.execute_entity(ent3, env, chunks, root=False)
                        probs3 = np.expand_dims(probs3, axis=[0, 1])
                        events = L.meet(L.meet(probs2, probs3), rel)
                        new_probs = L.join_reduce(L.join_reduce(events))
                        probs = L.meet(probs, new_probs)
                    continue
                # Otherwise, treat the relation as a possessive relation.
                if not self.args.no_possessive:
                    if self.possessive_expand:
                        text = ent.expand(ent2.head)
                    else:
                        text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'

        # poss_probs = self._filter(text, env, root=root, expand=.3)
        probs = self._filter(text, env, root=root)
        texts = [text]
        return_probs = [(probs.tolist(), probs.tolist())]
        for ent2_text, new_probs, ent2_only_probs in rel_probs:
            probs = L.meet(probs, new_probs)
            probs /= probs.sum()
            texts.append(ent2_text)
            return_probs.append((probs.tolist(), ent2_only_probs.tolist()))

        # Only use superlatives if thresholds work out.
        m1, m2 = probs[(-probs).argsort()[:2]]
        if m1 < self.baseline_threshold * m2:
            self.counts["n_rec_sup"] += 1
            for tokens in ent.superlatives:
                self.counts["n_sup"] += 1
                sup = None
                for heuristic in self.heuristics.superlatives:
                    if any(tok.text in heuristic.keywords for tok in tokens):
                        texts.append('sup:' + ' '.join(
                            tok.text for tok in tokens if tok.text in heuristic.keywords))
                        sup = heuristic.callback(env)
                        self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
                        break
                if sup is not None:
                    # Could use `probs` or `head_probs` here?
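                    # (The outer product of `precond` with itself scores every
                    # (target, competitor) box pair; meeting it with `sup`
                    # keeps pairs where the target wins the superlative
                    # comparison, if `sup[i, j]` scores "box i more-X-than
                    # box j" as the callbacks appear to assume, and the sum
                    # over axis 1 marginalizes out the competitor.)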
                    precond = head_probs if self.superlative_head_only else probs
                    probs = L.meet(np.expand_dims(precond, axis=1)
                                   * np.expand_dims(precond, axis=0), sup).sum(axis=1)
                    probs = probs / probs.sum()
                    return_probs.append((probs.tolist(), None))

        if root:
            assert len(texts) == len(return_probs)
            return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
        return probs

    def get_head(self, doc) -> Token:
        """Return the root token of the dependency parse, i.e. the token that is its own head."""
        for token in doc:
            if token.head.i == token.i:
                return token
        return None

    def get_chunks(self, doc) -> Dict[int, Any]:
        """Return a dictionary mapping token indices to the noun chunk containing them."""
        chunks = {}
        for chunk in doc.noun_chunks:
            for idx in range(chunk.start, chunk.end):
                chunks[idx] = chunk
        return chunks
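
    # Illustrative example (the actual chunking is up to spacy): for
    # "the dog near the red fence", `doc.noun_chunks` yields "the dog"
    # (tokens 0-1) and "the red fence" (tokens 3-5), so `get_chunks` maps
    # {0: "the dog", 1: "the dog", 3: "the red fence", 4: ..., 5: ...}.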

    def get_stats(self) -> Dict[str, Number]:
        """Summary statistics that have been tracked on this object."""
        stats = dict(self.counts)
        n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
        n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
        stats.update({
            "p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
            "p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
            "p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
            "p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
            "p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
            "p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
            "avg_rec": self.counts["n_rec"] / (self.counts["n_total"] + 1e-9),
        })
        return stats

    def _filter(self,
                caption: str,
                env: "Environment",
                root: bool = False,
                expand: float = None,
                ) -> np.ndarray:
        """Wrap a filter call in a consistent way for all recursions."""
        kwargs = {
            "softmax": not self.args.sigmoid,
            "temperature": self.args.temperature,
        }
        if root:
            return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
        return env.filter(caption, **kwargs)
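

# Usage sketch. The `Environment` type is defined elsewhere in this repo; the
# calls above assume it exposes a `boxes` sequence plus a `filter(caption, ...)`
# method accepting `area_threshold`, `softmax`, and `temperature` keyword
# arguments and returning a numpy array of per-box probabilities. Under those
# assumptions:
#
#     method = Parse(args)  # `args` as sketched near `__init__` above
#     result = method.execute("the cat to the left of the dog", env)
#     print(result["pred"], result["box"], result["texts"])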