Source code for udapi.block.corefud.fixinterleaved

from udapi.core.block import Block
import udapi.core.coref
import itertools

class FixInterleaved(Block):
    """Fix mentions with interleaved or crossing spans.

    https://github.com/ufal/corefUD/issues/25

    For every offending pair of mentions found in a tree, the first mention
    is enlarged to cover the words of both and the second one is detached.
    If the enlarged mention then covers exactly the same words as another
    remaining mention, the enlarged mention is detached as well.
    """

    def __init__(self, same_entity_only=True, both_discontinuous=False,
                 crossing_only=False, nested_same_subspan=True, **kwargs):
        """Store the options selecting which mention pairs get fixed.

        Args:
            same_entity_only: if true, skip pairs whose mentions belong
                to different entities.
            both_discontinuous: if true, skip pairs unless the span string
                of both mentions contains a comma.
            crossing_only: if true, non-nested pairs are fixed only when
                they share at least one word; if false, they are fixed
                unless one mention ends before the other one starts.
            nested_same_subspan: if true, a pair where one word set contains
                the other is fixed when the two span strings share
                a comma-separated part; if false, such pairs are never fixed.
            **kwargs: passed on to the parent `Block` constructor.
        """
        super().__init__(**kwargs)
        self.same_entity_only = same_entity_only
        self.both_discontinuous = both_discontinuous
        self.crossing_only = crossing_only
        self.nested_same_subspan = nested_same_subspan

    @staticmethod
    def _detach(mention, words):
        """Unlink `mention` from each of `words` and from its entity.

        Best effort: links that are already missing are silently ignored.
        """
        for word in words:
            try:
                word._mentions.remove(mention)
            except ValueError:
                pass
        try:
            mention.entity.mentions.remove(mention)
        except ValueError:
            pass

    def process_tree(self, tree):
        """Merge offending mention pairs among the mentions of `tree`."""
        mentions, deleted = set(), set()
        for node in tree.descendants_and_empty:
            for m in node.coref_mentions:
                mentions.add(m)

        # NOTE(review): `mentions` is a set, so the order of the pairs (and
        # thus which mention of a pair is kept as mA) follows set iteration
        # order -- confirm whether a deterministic order is wanted here.
        for mA, mB in itertools.combinations(mentions, 2):
            if mA in deleted or mB in deleted:
                continue
            if self.same_entity_only and mA.entity != mB.entity:
                continue

            # Fully nested spans are OK, except for same-subspan.
            sA, sB = set(mA.words), set(mB.words)
            if (sA <= sB) or (sB <= sA):
                if not self.nested_same_subspan:
                    continue
                elif not set(mA.span.split(',')).intersection(set(mB.span.split(','))):
                    continue

            # Crossing or interleaved+crossing?
            elif self.crossing_only:
                if not sA.intersection(sB):
                    continue
            else:
                # One mention ends before the other starts: nothing to fix.
                if mA.words[0] < mB.words[0] and mA.words[-1] < mB.words[0]:
                    continue
                if mB.words[0] < mA.words[0] and mB.words[-1] < mA.words[0]:
                    continue

            if self.both_discontinuous and (',' not in mA.span or ',' not in mB.span):
                continue

            # Merge: mA takes over the words of both mentions, mB goes away.
            mA.words = list(sA.union(sB))
            self._detach(mB, sB)
            deleted.add(mB)

            # By changing the mA.words, we could have created another error:
            # making the span same as another mention. Let's fix it.
            sA = set(mA.words)
            for mC in sorted(mentions):
                if mC in deleted or mC is mA or mC is mB:
                    continue
                if sA != set(mC.words):
                    continue
                # So mA and mC have the same span and we need to delete one of them to fix it.
                # We will delete mA because it has the artificially enlarged span,
                # while mC is from the original annotation.
                self._detach(mA, sA)
                # Record the deletion *before* leaving the loop, so that no
                # later pair merges a live mention into the detached mA.
                deleted.add(mA)
                break