Module a2t.data

The data module implements dataloaders (Dataset subclasses) for several predefined tasks.

Expand source code
"""The module `data` implements different dataloaders or `Dataset`s for predefined tasks.
"""
from .tacred import TACREDRelationClassificationDataset
from .babeldomains import BabelDomainsTopicClassificationDataset
from .wikievents import WikiEventsArgumentClassificationDataset
from .ace import ACEArgumentClassificationDataset
from .base import Dataset

# Registry of the dataset names understood by the library and the `Dataset` class that loads each one.
PREDEFINED_DATASETS = dict(
    tacred=TACREDRelationClassificationDataset,
    babeldomains=BabelDomainsTopicClassificationDataset,
    wikievents_arguments=WikiEventsArgumentClassificationDataset,
    ace_arguments=ACEArgumentClassificationDataset,
)

# Names exported by `from a2t.data import *`.
__all__ = [
    "Dataset",
    "TACREDRelationClassificationDataset",
    "BabelDomainsTopicClassificationDataset",
    "WikiEventsArgumentClassificationDataset",
    "ACEArgumentClassificationDataset",
]

# Keep the implementation submodules out of the pdoc-generated documentation.
__pdoc__ = dict.fromkeys(("base", "babeldomains", "tacred", "wikievents", "ace"), False)

Classes

class Dataset (labels: List[str], *args, **kwargs)

A simple class to handle the datasets.

Inherits from list, so instances are added to the dataset itself with its append or extend methods.

Args

labels : List[str]
The possible label set of the dataset.
Expand source code
class Dataset(list):
    """A simple class to handle the datasets.

    Inherits from `list`, so instances are added to the dataset itself with its `append` or
    `extend` methods. Every instance is expected to expose a `label` attribute.
    """

    def __init__(self, labels: List[str], *args, **kwargs) -> None:
        """
        Args:
            labels (List[str]): The possible label set of the dataset. Each label is
                identified by its position in this list.
        """
        # Extra positional/keyword arguments are accepted so that subclasses can forward
        # them, but they are deliberately not used to populate the list.
        super().__init__()

        self.labels2id = {label: i for i, label in enumerate(labels)}
        self.id2labels = {i: label for i, label in enumerate(labels)}

    @property
    def labels(self):
        """Integer ids (according to `labels2id`) of the `label` of every instance.

        Returns:
            numpy.ndarray: One integer id per instance, in dataset order (an empty integer
            array for an empty dataset).

        Raises:
            KeyError: If an instance label is not part of the label set.
        """
        # TODO: Unittest
        # The array is rebuilt on every access: the dataset is a mutable list, so a cached
        # array would silently go stale after further `append`/`extend` calls.
        self._labels = np.asarray([self.labels2id[inst.label] for inst in self], dtype=int)
        return self._labels

Ancestors

  • builtins.list

Subclasses

  • a2t.data.ace._ACEDataset
  • a2t.data.babeldomains.BabelDomainsTopicClassificationDataset
  • a2t.data.tacred.TACREDRelationClassificationDataset
  • a2t.data.wikievents._WikiEventsDataset

Instance variables

var labels
Expand source code
@property
def labels(self):
    """Integer ids (according to `labels2id`) of the `label` of every instance.

    Returns:
        numpy.ndarray: One integer id per instance, in dataset order (an empty integer
        array for an empty dataset).

    Raises:
        KeyError: If an instance label is not part of the label set.
    """
    # TODO: Unittest
    # The array is rebuilt on every access: the dataset is a mutable list, so a cached
    # array would silently go stale after further `append`/`extend` calls.
    self._labels = np.asarray([self.labels2id[inst.label] for inst in self], dtype=int)
    return self._labels
class TACREDRelationClassificationDataset (input_path: str, labels: List[str], *args, **kwargs)

A class to handle TACRED datasets.

This class converts TACRED data files into a list of TACREDFeatures.

Args

input_path : str
The path to the input file.
labels : List[str]
The possible label set of the dataset.
Expand source code
class TACREDRelationClassificationDataset(Dataset):
    """A class to handle TACRED datasets.

    This class converts TACRED data files into a list of `a2t.tasks.TACREDFeatures`.
    """

    # Bracket escapes found in the TACRED tokens and the characters they stand for.
    _BRACKET_ESCAPES = (("-LRB-", "("), ("-RRB-", ")"), ("-LSB-", "["), ("-RSB-", "]"))

    def __init__(self, input_path: str, labels: List[str], *args, **kwargs) -> None:
        """
        Args:
            input_path (str): The path to the input file. It must contain a JSON list of
                objects with the keys `token`, `subj_start`, `subj_end`, `obj_start`,
                `obj_end` (inclusive token indices), `subj_type`, `obj_type` and `relation`.
            labels (List[str]): The possible label set of the dataset.
        """
        # `labels` is passed positionally: `labels=labels, *args` would also bind the first
        # extra positional argument to `labels` and raise a TypeError.
        super().__init__(labels, *args, **kwargs)

        with open(input_path, "rt", encoding="utf-8") as f:
            instances = json.load(f)

        for inst in instances:
            tokens = inst["token"]
            self.append(
                TACREDFeatures(
                    subj=self._detokenize(tokens[inst["subj_start"] : inst["subj_end"] + 1]),
                    obj=self._detokenize(tokens[inst["obj_start"] : inst["obj_end"] + 1]),
                    inst_type=f"{inst['subj_type']}:{inst['obj_type']}",
                    context=self._detokenize(tokens),
                    label=inst["relation"],
                )
            )

    @classmethod
    def _detokenize(cls, tokens: List[str]) -> str:
        """Join `tokens` with single spaces and restore the escaped bracket characters."""
        text = " ".join(tokens)
        for escaped, plain in cls._BRACKET_ESCAPES:
            text = text.replace(escaped, plain)
        return text

Ancestors

  • a2t.data.base.Dataset
  • builtins.list
class BabelDomainsTopicClassificationDataset (input_path: str, labels: List[str], *args, **kwargs)

A class to handle BabelDomains datasets.

This class converts BabelDomains data files into a list of TopicClassificationFeatures.

Args

input_path : str
The path to the input file.
labels : List[str]
The possible label set of the dataset.
Expand source code
class BabelDomainsTopicClassificationDataset(Dataset):
    """A class to handle BabelDomains datasets.

    This class converts BabelDomains data files into a list of `a2t.tasks.TopicClassificationFeatures`.
    """

    def __init__(self, input_path: str, labels: List[str], *args, **kwargs) -> None:
        """
        Args:
            input_path (str): The path to the input file. Each non-blank line holds three
                tab-separated fields: an ignored first field, the label and the context.
            labels (List[str]): The possible label set of the dataset.
        """
        # `labels` is passed positionally: `labels=labels, *args` would also bind the first
        # extra positional argument to `labels` and raise a TypeError.
        super().__init__(labels, *args, **kwargs)

        with open(input_path, "rt", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines, e.g. an empty last line.
                    continue
                # maxsplit=2 keeps any further tab characters as part of the context.
                _, label, context = line.split("\t", 2)
                self.append(TopicClassificationFeatures(context=context, label=label))

Ancestors

  • a2t.data.base.Dataset
  • builtins.list
class WikiEventsArgumentClassificationDataset (input_path: str, labels: List[str], *args, mark_trigger: bool = True, **kwargs)

A class to handle WikiEvents datasets.

This class converts WikiEvents data files into a list of EventArgumentClassificationFeatures.

Args

input_path : str
The path to the input file.
labels : List[str]
The possible label set of the dataset.
Expand source code
class WikiEventsArgumentClassificationDataset(_WikiEventsDataset):
    """A class to handle WikiEvents datasets.

    This class converts WikiEvents data files into a list of `a2t.tasks.EventArgumentClassificationFeatures`.
    """

    def __init__(self, input_path: str, labels: List[str], *args, mark_trigger: bool = True, **kwargs) -> None:
        """This class converts WikiEvents data files into a list of `a2t.tasks.EventArgumentClassificationFeatures`.

        For each event mention, one instance is created per annotated argument (labelled with its
        role), followed by one `"no_relation"` instance per remaining entity mention of the
        document. Every instance receives a `docid` attribute with the document id.

        Args:
            input_path (str): The path to the input file.
            labels (List[str]): The possible label set of the dataset.
            mark_trigger (bool, optional): Whether to wrap the trigger text in `[[` and `]]` inside
                the context. Defaults to True.
        """
        super().__init__(labels, *args, **kwargs)

        for instance in self._load(input_path):
            id2ent = {ent["id"]: ent for ent in instance["entity_mentions"]}

            for event in instance["event_mentions"]:
                trigger = event["trigger"]
                type_parts = event["event_type"].replace(":", ".").split(".")
                trigger_type = type_parts[0]
                trigger_subtype = type_parts[-2]
                event_type = ".".join(type_parts)

                # The trigger `start`/`end` are used as slice bounds on the document text.
                context = instance["text"]
                if mark_trigger:
                    context = (
                        context[: trigger["start"]] + "[[" + trigger["text"] + "]]" + context[trigger["end"] :]
                    )

                # Entities not yet used as an argument of this event. An insertion-ordered dict is
                # used instead of a set so the negatives follow document order: iterating a set of
                # string ids depends on hash randomization and changes from run to run.
                candidates = dict.fromkeys(ent["id"] for ent in instance["entity_mentions"])

                examples = []
                for argument in event["arguments"]:
                    entity_id = argument["entity_id"]
                    # Skip arguments whose entity is unknown or was already used for this event.
                    if entity_id not in candidates:
                        continue
                    role = argument["role"]
                    # Every role containing "OOR" is collapsed into the single "OOR" label.
                    examples.append((entity_id, "OOR" if "OOR" in role else role))
                    del candidates[entity_id]

                # Generate negative examples from the entities that are not arguments of the event.
                examples.extend((entity_id, "no_relation") for entity_id in candidates)

                for entity_id, label in examples:
                    entity = id2ent[entity_id]
                    features = EventArgumentClassificationFeatures(
                        context=context,
                        trg=trigger["text"],
                        trg_type=trigger_type,
                        trg_subtype=trigger_subtype,
                        inst_type=f"{event_type}:{entity['entity_type']}",
                        arg=entity["text"],
                        label=label,
                    )
                    self.append(features)
                    features.docid = instance["doc_id"]

Ancestors

  • a2t.data.wikievents._WikiEventsDataset
  • a2t.data.base.Dataset
  • builtins.list
class ACEArgumentClassificationDataset (input_path: str, labels: List[str], *args, mark_trigger: bool = True, **kwargs)

A class to handle ACE datasets.

This class converts ACE data files into a list of EventArgumentClassificationFeatures.

Args

input_path : str
The path to the input file.
labels : List[str]
The possible label set of the dataset.
Expand source code
class ACEArgumentClassificationDataset(_ACEDataset):
    """A class to handle ACE datasets.

    This class converts ACE data files into a list of `a2t.tasks.EventArgumentClassificationFeatures`.
    """

    # Maps "<raw event_type>|<role>" to the role used instead, so that the labels satisfy the
    # guidelines constraints. Pairs not listed here keep their original role.
    # NOTE(review): "Plantiff" looks like a misspelling of "Plaintiff"; if the data spells the role
    # correctly this entry never matches. Confirm against the ACE preprocessing before changing it.
    label_mapping = {
        "Life:Die|Person": "Victim",
        "Movement:Transport|Place": "Destination",
        "Conflict:Attack|Victim": "Target",
        "Justice:Appeal|Plantiff": "Defendant",
    }

    def __init__(self, input_path: str, labels: List[str], *args, mark_trigger: bool = True, **kwargs) -> None:
        """This class converts ACE data files into a list of `a2t.tasks.EventArgumentClassificationFeatures`.

        For each event mention, one instance is created per annotated argument (labelled with its
        role after applying `label_mapping`), followed by one `"no_relation"` instance per
        remaining entity mention of the document.

        Args:
            input_path (str): The path to the input file.
            labels (List[str]): The possible label set of the dataset.
            mark_trigger (bool, optional): Whether to wrap the trigger tokens in `[[` and `]]` inside
                the context. Defaults to True.
        """
        super().__init__(labels, *args, **kwargs)

        for instance in self._load(input_path):
            tokens = instance["tokens"]
            id2ent = {ent["id"]: ent for ent in instance["entity_mentions"]}

            for event in instance["event_mentions"]:
                trigger = event["trigger"]
                # ACE event types have exactly two levels (e.g. "Life:Die"); anything else fails here.
                trigger_type, trigger_subtype = event["event_type"].replace(":", ".").split(".")
                event_type = f"{trigger_type}.{trigger_subtype}"

                # The trigger `start`/`end` are slice bounds on the token list.
                if mark_trigger:
                    start, end = trigger["start"], trigger["end"]
                    context = " ".join(tokens[:start] + ["[["] + tokens[start:end] + ["]]"] + tokens[end:])
                else:
                    context = " ".join(tokens)

                # Entities not yet used as an argument of this event. An insertion-ordered dict is
                # used instead of a set so the negatives follow document order: iterating a set of
                # string ids depends on hash randomization and changes from run to run.
                candidates = dict.fromkeys(ent["id"] for ent in instance["entity_mentions"])

                examples = []
                for argument in event["arguments"]:
                    entity_id = argument["entity_id"]
                    # Skip annotation errors (unknown entity or entity already used for this event).
                    if entity_id not in candidates:
                        continue
                    role = argument["role"]
                    # Apply label mapping to satisfy guidelines constraints.
                    role = self.label_mapping.get(f"{event['event_type']}|{role}", role)
                    examples.append((entity_id, role))
                    del candidates[entity_id]

                # Generate negative examples from the entities that are not arguments of the event.
                examples.extend((entity_id, "no_relation") for entity_id in candidates)

                for entity_id, label in examples:
                    entity = id2ent[entity_id]
                    self.append(
                        EventArgumentClassificationFeatures(
                            context=context,
                            trg=trigger["text"],
                            trg_type=trigger_type,
                            trg_subtype=trigger_subtype,
                            inst_type=f"{event_type}:{entity['entity_type']}",
                            arg=entity["text"],
                            label=label,
                        )
                    )

Ancestors

  • a2t.data.ace._ACEDataset
  • a2t.data.base.Dataset
  • builtins.list

Class variables

var label_mapping