anthropic

The Anthropic implementation of page-level classification.

AnthropicClassificationProvider

Bases: BaseClassificationProvider

The Anthropic implementation of unscored page classification.

Source code in docprompt/tasks/classification/anthropic.py
class AnthropicClassificationProvider(BaseClassificationProvider):
    """The Anthropic implementation of unscored page classification."""

    name = "anthropic"

    async def _ainvoke(
        self, input: Iterable[bytes], config: ClassificationConfig = None, **kwargs
    ) -> List[ClassificationOutput]:
        messages = _prepare_messages(input, config)

        parser = AnthropicPageClassificationOutputParser.from_task_input(
            config, provider_name=self.name
        )

        completions = await inference.run_batch_inference_anthropic(messages)

        return [parser.parse(res) for res in completions]

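Usage sketch (illustrative, not part of the generated reference). The import paths are inferred from the "Source code in …" notes on this page, and the provider constructor arguments and the pre-built DocumentNode are assumptions; verify both against your docprompt version.

from docprompt.tasks.classification.anthropic import AnthropicClassificationProvider
from docprompt.tasks.classification.base import (
    ClassificationConfig,
    ClassificationTypes,
)

config = ClassificationConfig(
    type=ClassificationTypes.SINGLE_LABEL,
    labels=["invoice", "receipt", "contract"],
    confidence=True,
)

provider = AnthropicClassificationProvider()  # assumed: credentials come from the environment
document_node = ...  # placeholder: a DocumentNode built elsewhere

# Classify every page; returns {page_number: ClassificationOutput}, keyed 1-based.
results = provider.process_document_node(document_node, task_config=config)
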
AnthropicPageClassificationOutputParser

Bases: BasePageClassificationOutputParser

The output parser for the page classification system.

Source code in docprompt/tasks/classification/anthropic.py
class AnthropicPageClassificationOutputParser(BasePageClassificationOutputParser):
    """The output parser for the page classification system."""

    def parse(self, text: str) -> ClassificationOutput:
        """Parse the results of the classification task."""
        pattern = re.compile(r"Answer: (.+)")
        match = pattern.search(text)

        result = self.resolve_match(match)

        if self.confidence:
            conf_pattern = re.compile(r"Confidence: (.+)")
            conf_match = conf_pattern.search(text)
            conf_result = self.resolve_confidence(conf_match)

            return ClassificationOutput(
                type=self.type,
                labels=result,
                score=conf_result,
                provider_name=self.name,
            )

        return ClassificationOutput(
            type=self.type, labels=result, provider_name=self.name
        )

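Parsing sketch (illustrative; the labels and completion text are made up, and import paths are assumed as above). parse() looks for an "Answer: …" line, validates the label via resolve_match, and, because confidence is enabled here, also reads a "Confidence: …" line:

from docprompt.tasks.classification.anthropic import (
    AnthropicPageClassificationOutputParser,
)
from docprompt.tasks.classification.base import ClassificationTypes

parser = AnthropicPageClassificationOutputParser(
    name="anthropic",
    type=ClassificationTypes.SINGLE_LABEL,
    labels=["invoice", "receipt"],
    confidence=True,
)

output = parser.parse("Answer: invoice\nConfidence: high")
# output.labels -> "invoice"
# output.score  -> ConfidenceLevel.HIGH
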
parse(text)

Parse the results of the classification task.

Source code in docprompt/tasks/classification/anthropic.py
def parse(self, text: str) -> ClassificationOutput:
    """Parse the results of the classification task."""
    pattern = re.compile(r"Answer: (.+)")
    match = pattern.search(text)

    result = self.resolve_match(match)

    if self.confidence:
        conf_pattern = re.compile(r"Confidence: (.+)")
        conf_match = conf_pattern.search(text)
        conf_result = self.resolve_confidence(conf_match)

        return ClassificationOutput(
            type=self.type,
            labels=result,
            score=conf_result,
            provider_name=self.name,
        )

    return ClassificationOutput(
        type=self.type, labels=result, provider_name=self.name
    )

base

BaseClassificationProvider

Bases: AbstractPageTaskProvider[bytes, ClassificationConfig, ClassificationOutput]

The base classification provider.

Source code in docprompt/tasks/classification/base.py
class BaseClassificationProvider(
    AbstractPageTaskProvider[bytes, ClassificationConfig, ClassificationOutput]
):
    """
    The base classification provider.
    """

    capabilities = [PageLevelCapabilities.PAGE_CLASSIFICATION]

    class Meta:
        abstract = True

    def process_document_node(
        self,
        document_node: "DocumentNode",
        task_config: ClassificationConfig = None,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        contribute_to_document: bool = True,
        **kwargs,
    ):
        assert (
            task_config is not None
        ), "task_config must be provided for classification tasks"

        raster_bytes = []
        for page_number in range(start or 1, (stop or len(document_node)) + 1):
            image_bytes = document_node.page_nodes[
                page_number - 1
            ].rasterizer.rasterize("default")
            raster_bytes.append(image_bytes)

        # TODO: This is a somewhat dangerous way of requiring these kwargs to be drilled
        # through, potentially a decorator solution to be had here
        kwargs = {**self._default_invoke_kwargs, **kwargs}
        results = self._invoke(raster_bytes, config=task_config, **kwargs)

        return {
            i: res
            for i, res in zip(
                range(start or 1, (stop or len(document_node)) + 1), results
            )
        }

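Range sketch (illustrative; reuses the provider, config, and document_node assumptions from the Anthropic usage sketch above). start and stop are 1-based and inclusive, and the returned mapping is keyed by page number:

# Classify only pages 2 through 4.
page_results = provider.process_document_node(
    document_node,
    task_config=config,
    start=2,
    stop=4,
)
for page_number, output in sorted(page_results.items()):
    print(page_number, output.labels)
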
BasePageClassificationOutputParser

Bases: ABC, BaseOutputParser[ClassificationConfig, ClassificationOutput]

The output parser for the page classification system.

Source code in docprompt/tasks/classification/base.py
class BasePageClassificationOutputParser(
    ABC, BaseOutputParser[ClassificationConfig, ClassificationOutput]
):
    """The output parser for the page classification system."""

    name: str = Field(...)
    type: ClassificationTypes = Field(...)
    labels: LabelType = Field(...)
    confidence: bool = Field(False)

    @classmethod
    def from_task_input(cls, task_input: ClassificationConfig, provider_name: str):
        return cls(
            type=task_input.type,
            name=provider_name,
            labels=task_input.labels,
            confidence=task_input.confidence,
        )

    def resolve_match(self, _match: Union[re.Match, None]) -> LabelType:
        """Get the regex pattern for the output parser."""

        if not _match:
            raise ValueError("Could not find the answer in the text.")

        val = _match.group(1)
        if self.type == ClassificationTypes.BINARY:
            if val not in self.labels:
                raise ValueError(f"Invalid label: {val}")
            return val

        elif self.type == ClassificationTypes.SINGLE_LABEL:
            if val not in self.labels:
                raise ValueError(f"Invalid label: {val}")
            return val

        elif self.type == ClassificationTypes.MULTI_LABEL:
            labels = val.split(", ")
            for label in labels:
                if label not in self.labels:
                    raise ValueError(f"Invalid label: {label}")
            return labels
        else:
            raise ValueError(f"Invalid classification type: {self.type}")

    def resolve_confidence(self, _match: Union[re.Match, None]) -> Optional[ConfidenceLevel]:
        """Get the confidence level from the text."""

        if not _match:
            return None

        val = _match.group(1).lower()

        return ConfidenceLevel(val)

    @abstractmethod
    def parse(self, text: str) -> ClassificationOutput: ...

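Subclassing sketch (hypothetical; this parser does not exist in the library, and the import paths are assumptions). parse() is the only abstract hook, so a new provider's parser just turns raw completion text into a ClassificationOutput, typically delegating label validation to resolve_match:

import re

from docprompt.tasks.classification.base import (
    BasePageClassificationOutputParser,
    ClassificationOutput,
)

class JsonPageClassificationOutputParser(BasePageClassificationOutputParser):
    """Hypothetical parser for completions like '{"answer": "invoice"}'."""

    def parse(self, text: str) -> ClassificationOutput:
        match = re.search(r'"answer"\s*:\s*"([^"]+)"', text)
        result = self.resolve_match(match)  # raises ValueError on unknown labels
        return ClassificationOutput(
            type=self.type, labels=result, provider_name=self.name
        )
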
resolve_confidence(_match)

Get the confidence level from the text.

Source code in docprompt/tasks/classification/base.py
def resolve_confidence(self, _match: Union[re.Match, None]) -> Optional[ConfidenceLevel]:
    """Get the confidence level from the text."""

    if not _match:
        return None

    val = _match.group(1).lower()

    return ConfidenceLevel(val)

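Behavior sketch (illustrative; parser is any concrete parser instance, such as the Anthropic one above). The matched text is lower-cased before the enum lookup, and a missing match yields None rather than an error:

import re

parser.resolve_confidence(re.search(r"Confidence: (.+)", "Confidence: High"))
# -> ConfidenceLevel.HIGH
parser.resolve_confidence(None)
# -> None (no confidence line in the completion)
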
resolve_match(_match)

Resolve the regex match into validated label(s).

Source code in docprompt/tasks/classification/base.py
def resolve_match(self, _match: Union[re.Match, None]) -> LabelType:
    """Get the regex pattern for the output parser."""

    if not _match:
        raise ValueError("Could not find the answer in the text.")

    val = _match.group(1)
    if self.type == ClassificationTypes.BINARY:
        if val not in self.labels:
            raise ValueError(f"Invalid label: {val}")
        return val

    elif self.type == ClassificationTypes.SINGLE_LABEL:
        if val not in self.labels:
            raise ValueError(f"Invalid label: {val}")
        return val

    elif self.type == ClassificationTypes.MULTI_LABEL:
        labels = val.split(", ")
        for label in labels:
            if label not in self.labels:
                raise ValueError(f"Invalid label: {label}")
        return labels
    else:
        raise ValueError(f"Invalid classification type: {self.type}")

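Behavior sketch (illustrative; the labels are made up, imports as in the parsing sketch above). Multi-label answers are split on ", ", and every resulting label must be one of the configured labels:

import re

parser = AnthropicPageClassificationOutputParser(
    name="anthropic",
    type=ClassificationTypes.MULTI_LABEL,
    labels=["tables", "figures", "handwriting"],
    confidence=False,
)

match = re.search(r"Answer: (.+)", "Answer: tables, figures")
parser.resolve_match(match)
# -> ["tables", "figures"]

parser.resolve_match(re.search(r"Answer: (.+)", "Answer: charts"))
# raises ValueError: Invalid label: charts
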
ClassificationConfig

Bases: BaseModel

Source code in docprompt/tasks/classification/base.py
class ClassificationConfig(BaseModel):
    type: ClassificationTypes
    labels: LabelType
    descriptions: Optional[List[str]] = Field(
        None, description="The descriptions for each label (if any)."
    )

    instructions: Optional[str] = Field(
        None,
        description="Additional instructions to pass to the LLM for the task. Required for Binary Classification.",
    )

    confidence: bool = Field(False)

    @model_validator(mode="before")
    def validate_label_bindings(cls, data: Any) -> Any:
        """Validate the the label/description bindings based on the type."""

        classification_type = data.get("type", None)
        if classification_type == ClassificationTypes.SINGLE_LABEL:
            labels = data.get("labels", None)
            if not labels:
                raise ValueError(
                    "labels must be provided for single_label classification"
                )
            return data

        elif classification_type == ClassificationTypes.BINARY:
            instructions = data.get("instructions", None)
            if not instructions:
                raise ValueError(
                    "instructions must be provided for binary classification"
                )
            data["labels"] = ["YES", "NO"]
            return data

        elif classification_type == ClassificationTypes.MULTI_LABEL:
            labels = data.get("labels", None)
            if not labels:
                raise ValueError(
                    "labels must be provided for multi_label classification"
                )
            return data

        else:
            raise ValueError(f"Invalid classification type: {classification_type}")

    @model_validator(mode="after")
    def validate_descriptions_length(self):
        if self.descriptions is not None:
            labels = self.labels
            if labels is not None and len(self.descriptions) != len(labels):
                raise ValueError("descriptions must have the same length as labels")
        return self

    @property
    def formatted_labels(self):
        """Produce the formatted labels for the prompt template."""
        raw_labels = self.labels
        if self.descriptions:
            for label, description in zip(raw_labels, self.descriptions):
                yield f"{label}: {description}"
        else:
            yield from raw_labels

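Configuration sketch (illustrative): the three shapes the validators accept. Binary configs must supply instructions and get their labels filled in automatically; descriptions, when given, must pair one-to-one with labels:

single = ClassificationConfig(
    type=ClassificationTypes.SINGLE_LABEL,
    labels=["invoice", "receipt", "contract"],
)

binary = ClassificationConfig(
    type=ClassificationTypes.BINARY,
    instructions="Answer YES if the page contains a signature, otherwise NO.",
)
# binary.labels -> ["YES", "NO"], set by the before-validator

multi = ClassificationConfig(
    type=ClassificationTypes.MULTI_LABEL,
    labels=["tables", "figures"],
    descriptions=["Contains tables", "Contains figures"],
)
list(multi.formatted_labels)
# -> ['tables: Contains tables', 'figures: Contains figures']
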
formatted_labels property

Produce the formatted labels for the prompt template.

validate_label_bindings(data)

Validate the label/description bindings based on the type.

Source code in docprompt/tasks/classification/base.py
@model_validator(mode="before")
def validate_label_bindings(cls, data: Any) -> Any:
    """Validate the the label/description bindings based on the type."""

    classification_type = data.get("type", None)
    if classification_type == ClassificationTypes.SINGLE_LABEL:
        labels = data.get("labels", None)
        if not labels:
            raise ValueError(
                "labels must be provided for single_label classification"
            )
        return data

    elif classification_type == ClassificationTypes.BINARY:
        instructions = data.get("instructions", None)
        if not instructions:
            raise ValueError(
                "instructions must be provided for binary classification"
            )
        data["labels"] = ["YES", "NO"]
        return data

    elif classification_type == ClassificationTypes.MULTI_LABEL:
        labels = data.get("labels", None)
        if not labels:
            raise ValueError(
                "labels must be provided for multi_label classification"
            )
        return data

    else:
        raise ValueError(f"Invalid classification type: {classification_type}")

ConfidenceLevel

Bases: str, Enum

The confidence level of the classification.

Source code in docprompt/tasks/classification/base.py
class ConfidenceLevel(str, Enum):
    """The confidence level of the classification."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"