Source code for udapi.block.corefud.markpairs

from udapi.core.block import Block
import udapi.core.coref
import itertools
from collections import Counter
import logging

[docs] class MarkPairs(Block): """Find pairs of coreference mentions within the same sentence with given properties. Mark these pairs of mentions (using `misc["Mark"]`), so they can be further processed or printed. Usage: # Find pairs of mentions of the same entity within the same sentence: cat my.conllu | udapy -TM corefud.MarkPairs same_entity=1 | less -R Properties: same_entity - both mentions belong to the same entity (cluster) both_continuous - both mentions have continuous spans both_discontinuous - both mentions have discontinuous spans nested - span of one mention is nested (a subset of) in the span of the other mention crossing - spans are crossing (i.e. intersecting, but neither is subset of the other) interleaved - spans are interleaved (i.e. not intersecting, but neither span precedes the other) same_head - the same node is a head of both mentions same_span - both mentions have the same span (which is invalid according to UD's validate.py) same_subspan - at least one of the mentions is discontinuous and one of its subspans is also a subspan (or span) of the other mention You can combine any number of properties. Each property can have one of the three values: include - this is the default value: include pairs with this property, i.e. ignore the property exclude - exclude (from the marking) pairs of mentions with this property only - pairs of mentions without this property will be excluded As a shortcut, you can use -1 and 1 instead of exclude and only, so e.g. nested=only same_head=exclude can be written as nested=1 same_head=-1 """ def __init__(self, same_entity=0, both_continuous=0, both_discontinuous=0, nested=0, crossing=0, interleaved=0, same_head=0, same_span=0, same_subspan=0, print_form=False, print_total=True, log=True, mark=True, **kwargs): super().__init__(**kwargs) self.same_entity = self._convert(same_entity) self.both_continuous = self._convert(both_continuous) self.both_discontinuous = self._convert(both_discontinuous) self.nested = self._convert(nested) self.crossing = self._convert(crossing) self.interleaved = self._convert(interleaved) self.same_head = self._convert(same_head) self.same_span = self._convert(same_span) self.same_subspan = self._convert(same_subspan) self.print_form = print_form self.print_total = print_total self.log = log self.mark = mark self.counter = Counter() def _convert(self, value): if value in {-1, 0, 1}: return value if value == 'include': return 0 if value == 'only': return 1 if value == 'exclude': return -1 raise ValueError('unknown value ' + value) def _ok(self, condition, value): if value == 0: return True return (condition and value == 1) or (not condition and value==-1) def _print(self, mention): if self.print_form: return mention.entity.eid + ':' + ' '.join([w.form for w in mention.words]) else: return mention.entity.eid + ':' + mention.span
[docs] def process_tree(self, tree): mentions = set() for node in tree.descendants_and_empty: for m in node.coref_mentions: mentions.add(m) self.counter['mentions'] += len(mentions) for mA, mB in itertools.combinations(mentions, 2): self.counter['pairs'] += 1 if not self._ok(mA.entity == mB.entity, self.same_entity): continue if not self._ok(mA.head == mB.head, self.same_head): continue if self.both_continuous or self.both_discontinuous or self.same_span or self.same_subspan: sA, sB = mA.span, mB.span cA, cB = ',' not in sA, ',' not in sB if not self._ok(cA and cB, self.both_continuous): continue if not self._ok(not cA and not cB, self.both_discontinuous): continue if not self._ok(sA == sB, self.same_span): continue if not self._ok(set(sA.split(',')).intersection(set(sB.split(','))), self.same_subspan): continue if self.nested or self.crossing or self.interleaved: wA, wB = set(mA.words), set(mB.words) if not self._ok(wA <= wB or wB <= wA, self.nested): continue if not self._ok(wA.intersection(wB) and not wA <= wB and not wB <= wA, self.crossing): continue if self.interleaved: a_precedes_b = mA.words[0] < mB.words[0] and mA.words[-1] < mB.words[0] b_precedes_a = mB.words[0] < mA.words[0] and mB.words[-1] < mA.words[0] if not self._ok(not wA.intersection(wB) and not a_precedes_b and not b_precedes_a, self.interleaved): continue self.counter['matching'] += 1 if self.mark: for w in mA.words + mB.words: w.misc['Mark'] = 1 mA.words[0].misc['Mark'] = f"{self._print(mA)}+{self._print(mB)}" if self.log: logging.info(f"Found mentions at {tree.sent_id}: {self._print(mA)} + {self._print(mB)}")
[docs] def after_process_document(self, doc): if self.print_total: #if self.max_trees and seen_trees > self.max_trees: # print(f'######## Only first {self.max_trees} matching mentions printed. Use max_trees=0 to see all.') msg = f'######## Mentions = {self.counter["mentions"]}, matching/all pairs = {self.counter["matching"]} / {self.counter["pairs"]}' logging.info(msg) doc.meta["corefud.MarkPairs"] = msg