search

DocumentProvenanceLocator dataclass

Source code in docprompt/provenance/search.py
@dataclass
class DocumentProvenanceLocator:
    document_name: str
    search_index: "tantivy.Index"
    block_mapping: Dict[int, OcrPageResult] = field(repr=False)
    geo_index: DocumentProvenanceGeoMap = field(repr=False)

    @classmethod
    def from_document_node(cls, document_node: "DocumentNode"):
        # TODO: See if we can remove the ocr_results attribute from the
        # PageNode and just use the metadata.task_result["<provider>_ocr"],
        # result of the OCR task instead.

        index = create_tantivy_document_wise_block_index()
        block_mapping_dict = {}
        geo_index_dict: DocumentProvenanceGeoMap = {}

        writer = index.writer()

        for page_node in document_node.page_nodes:
            if (
                not page_node.ocr_results.result
                or not page_node.ocr_results.result.block_level_blocks
            ):
                continue

            ocr_result = page_node.ocr_results.result

            for idx, text_block in enumerate(ocr_result.block_level_blocks):
                writer.add_document(
                    tantivy.Document(
                        page_number=page_node.page_number,
                        block_type=text_block.type,
                        block_page_idx=idx,
                        content=text_block.text,
                    )
                )

            for granularity in ["word", "line", "block"]:
                text_blocks = getattr(ocr_result, f"{granularity}_level_blocks", [])

                bounding_boxes = [text_block.bounding_box for text_block in text_blocks]

                if bounding_boxes:
                    r_tree = RTreeIndex(
                        insert_generator(bounding_boxes), fill_factor=0.9
                    )
                else:
                    r_tree = RTreeIndex()

                if page_node.page_number not in geo_index_dict:
                    geo_index_dict[page_node.page_number] = {}

                geo_index_dict[page_node.page_number][granularity] = r_tree  # type: ignore

            block_mapping_dict[page_node.page_number] = ocr_result

        writer.commit()
        index.reload()

        return cls(
            document_name=document_node.document.name,
            search_index=index,
            block_mapping=block_mapping_dict,
            geo_index=geo_index_dict,
        )

    def _construct_tantivy_query(
        self, query: str, page_number: Optional[int] = None
    ) -> tantivy.Query:
        query = preprocess_query_text(query)

        if page_number is None:
            return self.search_index.parse_query(f'content:"{query}"')
        else:
            return self.search_index.parse_query(
                f'(page_number:{page_number}) AND content:"{query}"'
            )

    def get_k_nearest_blocks(
        self,
        bbox: NormBBox,
        page_number: int,
        k: int,
        granularity: BlockGranularity = "block",
    ) -> List[TextBlock]:
        """
        Get the k nearest text blocks to a given bounding box
        """
        search_tuple = construct_valid_rtree_tuple(bbox)

        nearest_bbox_indices = list(
            self.geo_index[page_number][granularity].nearest(
                search_tuple, num_results=k
            )
        )

        block_mapping = self.block_mapping[page_number]

        nearest_blocks = [
            getattr(block_mapping, f"{granularity}_level_blocks")[idx]
            for idx in nearest_bbox_indices
        ]

        nearest_blocks.sort(key=lambda x: (x.bounding_box.top, x.bounding_box.x0))

        return [x for x in nearest_blocks if x.bounding_box != bbox]

    def get_overlapping_blocks(
        self, bbox: NormBBox, page_number: int, granularity: BlockGranularity = "block"
    ) -> List[TextBlock]:
        """
        Get the text blocks that overlap with a given bounding box
        """
        search_tuple = construct_valid_rtree_tuple(bbox)

        bbox_indices = list(
            self.geo_index[page_number][granularity].intersection(search_tuple)
        )

        block_mapping = self.block_mapping[page_number]

        overlapping_blocks = [
            getattr(block_mapping, f"{granularity}_level_blocks")[idx]
            for idx in bbox_indices
        ]

        overlapping_blocks.sort(key=lambda x: (x.bounding_box.top, x.bounding_box.x0))

        return [x for x in overlapping_blocks if x.bounding_box != bbox]

    def search_raw(self, raw_query: str) -> List[str]:
        """
        Search for a piece of text using a raw query

        Args:
            raw_query: The raw tantivy query string to parse and execute
        """
        parsed_query = self.search_index.parse_query(raw_query)

        searcher = self.search_index.searcher()

        search_results = searcher.search(parsed_query, limit=100)

        results = []

        for score, doc_address in search_results.hits:
            doc = searcher.doc(doc_address)

            result_page_number = doc["page_number"][0]
            result_block_page_idx = doc["block_page_idx"][0]
            block_mapping = self.block_mapping[result_page_number]

            source_block: TextBlock = block_mapping.block_level_blocks[
                result_block_page_idx
            ]

            results.append(source_block.text)

        return results

    def refine_query_to_word_level(
        self, query: str, page_number: int, enclosing_block: TextBlock
    ):
        """
        Refine a query to the word level
        """
        search_tuple = construct_valid_rtree_tuple(enclosing_block.bounding_box)

        word_level_bbox_indices = list(
            self.geo_index[page_number]["word"].intersection(search_tuple)
        )
        word_level_blocks_in_original_bbox = [
            self.block_mapping[page_number].word_level_blocks[idx]
            for idx in word_level_bbox_indices
        ]

        refine_result = refine_block_to_word_level(
            source_block=enclosing_block,
            intersecting_word_level_blocks=word_level_blocks_in_original_bbox,
            query=query,
        )

        return refine_result

    def search(
        self,
        query: str,
        page_number: Optional[int] = None,
        *,
        refine_to_word: bool = True,
        require_exact_match: bool = True,
    ) -> List[ProvenanceSource]:
        """
        Search for a piece of text in the document and return the source of it

        Args:
            query: The text to search for
            page_number: The page number to search on
            refine_to_word: Whether to refine the search to the word level
            require_exact_match: Whether to skip a result when `refine_to_word` is True and no exact word-level match is found
        """
        search_query = self._construct_tantivy_query(query, page_number)

        searcher = self.search_index.searcher()

        search_results = searcher.search(search_query, limit=100)

        results = []

        for score, doc_address in search_results.hits:
            doc = searcher.doc(doc_address)

            result_page_number = doc["page_number"][0]
            result_block_page_idx = doc["block_page_idx"][0]
            block_mapping = self.block_mapping[result_page_number]

            source_block: TextBlock = block_mapping.block_level_blocks[
                result_block_page_idx
            ]

            source_blocks = [source_block]
            principal_block = source_block

            if refine_to_word:
                refine_result = self.refine_query_to_word_level(
                    query=query,
                    page_number=result_page_number,
                    enclosing_block=source_block,
                )

                if refine_result is not None:
                    principal_block, source_blocks = refine_result
                elif require_exact_match:
                    continue

            source = ProvenanceSource(
                document_name=self.document_name,
                page_number=result_page_number,
                text_location=PageTextLocation(
                    source_blocks=source_blocks,
                    text=query,
                    score=score,
                    granularity="block",
                    merged_source_block=principal_block,
                ),
            )
            results.append(source)

        results.sort(key=lambda x: x.page_number)

        return results

    def search_n_best(
        self, query: str, n: int = 3, mode: SearchBestModes = "shortest_text"
    ) -> List[ProvenanceSource]:
        results = self.search(query)

        if not results:
            return []

        if mode == "shortest_text":
            score_func = lambda x: len(x.source_block.text)  # noqa: E731
        elif mode == "longest_text":
            score_func = lambda x: -len(x[0].source_block.text)  # noqa: E731
        elif mode == "highest_score":
            score_func = lambda x: x[1]  # noqa: E731
        else:
            raise ValueError(f"Unknown mode {mode}")

        results.sort(key=score_func)

        return results[:n]
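
A minimal end-to-end sketch, assuming the import paths below and a document whose pages already carry OCR results (the file name is hypothetical):

    from docprompt import load_document  # assumed import path
    from docprompt.schema.pipeline import DocumentNode  # assumed import path
    from docprompt.provenance.search import DocumentProvenanceLocator

    document = load_document("invoice.pdf")  # hypothetical file
    document_node = DocumentNode.from_document(document)  # assumes OCR results are attached

    locator = DocumentProvenanceLocator.from_document_node(document_node)

    for source in locator.search("Total amount due"):
        print(source.page_number, source.text)

The examples below reuse this locator object.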

get_k_nearest_blocks(bbox, page_number, k, granularity='block')

Get the k nearest text blocks to a given bounding box

Source code in docprompt/provenance/search.py
def get_k_nearest_blocks(
    self,
    bbox: NormBBox,
    page_number: int,
    k: int,
    granularity: BlockGranularity = "block",
) -> List[TextBlock]:
    """
    Get the k nearest text blocks to a given bounding box
    """
    search_tuple = construct_valid_rtree_tuple(bbox)

    nearest_bbox_indices = list(
        self.geo_index[page_number][granularity].nearest(
            search_tuple, num_results=k
        )
    )

    block_mapping = self.block_mapping[page_number]

    nearest_blocks = [
        getattr(block_mapping, f"{granularity}_level_blocks")[idx]
        for idx in nearest_bbox_indices
    ]

    nearest_blocks.sort(key=lambda x: (x.bounding_box.top, x.bounding_box.x0))

    return [x for x in nearest_blocks if x.bounding_box != bbox]
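
For example, reusing the locator from the earlier sketch, one can fetch the blocks laid out around a search hit:

    sources = locator.search("Total amount due", page_number=1)
    if sources:
        bbox = sources[0].source_block.bounding_box
        # The 3 blocks closest to the hit, returned in reading order.
        neighbors = locator.get_k_nearest_blocks(bbox, page_number=1, k=3)
        for block in neighbors:
            print(block.text)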

get_overlapping_blocks(bbox, page_number, granularity='block')

Get the text blocks that overlap with a given bounding box

Source code in docprompt/provenance/search.py
def get_overlapping_blocks(
    self, bbox: NormBBox, page_number: int, granularity: BlockGranularity = "block"
) -> List[TextBlock]:
    """
    Get the text blocks that overlap with a given bounding box
    """
    search_tuple = construct_valid_rtree_tuple(bbox)

    bbox_indices = list(
        self.geo_index[page_number][granularity].intersection(search_tuple)
    )

    block_mapping = self.block_mapping[page_number]

    overlapping_blocks = [
        getattr(block_mapping, f"{granularity}_level_blocks")[idx]
        for idx in bbox_indices
    ]

    overlapping_blocks.sort(key=lambda x: (x.bounding_box.top, x.bounding_box.x0))

    return [x for x in overlapping_blocks if x.bounding_box != bbox]
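
A sketch of querying by region instead; the NormBBox import path is assumed, and the field names are inferred from how bounding boxes are used above (coordinates are normalized to the page):

    from docprompt.schema.layout import NormBBox  # assumed import path

    # A hypothetical region covering the top half of page 1.
    region = NormBBox(x0=0.0, top=0.0, x1=1.0, bottom=0.5)
    lines = locator.get_overlapping_blocks(region, page_number=1, granularity="line")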

refine_query_to_word_level(query, page_number, enclosing_block)

Refine a query to the word level

Source code in docprompt/provenance/search.py
def refine_query_to_word_level(
    self, query: str, page_number: int, enclosing_block: TextBlock
):
    """
    Refine a query to the word level
    """
    search_tuple = construct_valid_rtree_tuple(enclosing_block.bounding_box)

    word_level_bbox_indices = list(
        self.geo_index[page_number]["word"].intersection(search_tuple)
    )
    word_level_blocks_in_original_bbox = [
        self.block_mapping[page_number].word_level_blocks[idx]
        for idx in word_level_bbox_indices
    ]

    refine_result = refine_block_to_word_level(
        source_block=enclosing_block,
        intersecting_word_level_blocks=word_level_blocks_in_original_bbox,
        query=query,
    )

    return refine_result
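
This is what search does internally when refine_to_word is True, but it can also be called directly to tighten a block-level hit to just the matching words (reusing the earlier locator):

    hits = locator.search("Total amount due", refine_to_word=False)
    if hits:
        hit = hits[0]
        refined = locator.refine_query_to_word_level(
            "Total amount due",
            page_number=hit.page_number,
            enclosing_block=hit.source_block,
        )
        if refined is not None:
            merged_block, word_blocks = refined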

search(query, page_number=None, *, refine_to_word=True, require_exact_match=True)

Search for a piece of text in the document and return the source of it

Parameters:

    query (str): The text to search for. Required.
    page_number (Optional[int]): The page number to search on. Default: None.
    refine_to_word (bool): Whether to refine the search to the word level. Default: True.
    require_exact_match (bool): Whether to skip a result when refine_to_word is True and no exact word-level match is found. Default: True.

Source code in docprompt/provenance/search.py
def search(
    self,
    query: str,
    page_number: Optional[int] = None,
    *,
    refine_to_word: bool = True,
    require_exact_match: bool = True,
) -> List[ProvenanceSource]:
    """
    Search for a piece of text in the document and return the source of it

    Args:
        query: The text to search for
        page_number: The page number to search on
        refine_to_word: Whether to refine the search to the word level
        require_exact_match: Whether to skip a result when `refine_to_word` is True and no exact word-level match is found
    """
    search_query = self._construct_tantivy_query(query, page_number)

    searcher = self.search_index.searcher()

    search_results = searcher.search(search_query, limit=100)

    results = []

    for score, doc_address in search_results.hits:
        doc = searcher.doc(doc_address)

        result_page_number = doc["page_number"][0]
        result_block_page_idx = doc["block_page_idx"][0]
        block_mapping = self.block_mapping[result_page_number]

        source_block: TextBlock = block_mapping.block_level_blocks[
            result_block_page_idx
        ]

        source_blocks = [source_block]
        principal_block = source_block

        if refine_to_word:
            refine_result = self.refine_query_to_word_level(
                query=query,
                page_number=result_page_number,
                enclosing_block=source_block,
            )

            if refine_result is not None:
                principal_block, source_blocks = refine_result
            elif require_exact_match:
                continue

        source = ProvenanceSource(
            document_name=self.document_name,
            page_number=result_page_number,
            text_location=PageTextLocation(
                source_blocks=source_blocks,
                text=query,
                score=score,
                granularity="block",
                merged_source_block=principal_block,
            ),
        )
        results.append(source)

    results.sort(key=lambda x: x.page_number)

    return results
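
The keyword-only flags trade precision for recall; a sketch of both ends, reusing the earlier locator:

    # Defaults: refine to word level, and drop hits that have no
    # exact word-level match.
    exact = locator.search("Total amount due")

    # Keep block-level hits on page 2 even when word-level refinement
    # finds no exact match.
    loose = locator.search("Total amount due", page_number=2, require_exact_match=False)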

search_raw(raw_query)

Search for a piece of text using a raw query

Parameters:

    raw_query (str): The raw tantivy query string to parse and execute. Required.

Source code in docprompt/provenance/search.py
def search_raw(self, raw_query: str) -> List[str]:
    """
    Search for a piece of text using a raw query

    Args:
        raw_query: The raw tantivy query string to parse and execute
    """
    parsed_query = self.search_index.parse_query(raw_query)

    searcher = self.search_index.searcher()

    search_results = searcher.search(parsed_query, limit=100)

    results = []

    for score, doc_address in search_results.hits:
        doc = searcher.doc(doc_address)

        result_page_number = doc["page_number"][0]
        result_block_page_idx = doc["block_page_idx"][0]
        block_mapping = self.block_mapping[result_page_number]

        source_block: TextBlock = block_mapping.block_level_blocks[
            result_block_page_idx
        ]

        results.append(source_block.text)

    return results
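
The index exposes the page_number, block_type, block_page_idx, and content fields, so a raw query can combine them using tantivy's query syntax (reusing the earlier locator):

    # All block texts on page 2 that phrase-match "amount due".
    texts = locator.search_raw('page_number:2 AND content:"amount due"')
    for text in texts:
        print(text)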

source

PageTextLocation

Bases: BaseModel

Specifies the location of a piece of text in a page

Source code in docprompt/provenance/source.py
class PageTextLocation(BaseModel):
    """
    Specifies the location of a piece of text in a page
    """

    source_blocks: List[TextBlock] = Field(
        description="The source text blocks", repr=False
    )
    text: str  # The matched text; may be shorter than the full text of the source blocks
    score: float
    granularity: Literal["word", "line", "block"] = "block"

    merged_source_block: Optional[TextBlock] = Field(default=None)

ProvenanceSource

Bases: BaseModel

Specifies exactly where a piece of verbatim text came from in a document; intended to be bundled with the data extracted from it.

Source code in docprompt/provenance/source.py
class ProvenanceSource(BaseModel):
    """
    Specifies exactly where a piece of verbatim text came from in a document;
    intended to be bundled with the data extracted from it.
    """

    document_name: str
    page_number: PositiveInt
    text_location: Optional[PageTextLocation] = None

    @computed_field  # type: ignore
    @property
    def source_block(self) -> Optional[TextBlock]:
        if self.text_location:
            if self.text_location.merged_source_block:
                return self.text_location.merged_source_block
            if self.text_location.source_blocks:
                return self.text_location.source_blocks[0]

        return None

    @property
    def text(self) -> str:
        if self.text_location:
            return "\n".join([block.text for block in self.text_location.source_blocks])

        return ""

util

insert_generator(bboxes, data=None)

Make an iterator that yields (id, bbox, data) tuples for bulk insertion into an R-tree index; bulk loading this way is dramatically faster than inserting entries one at a time.

Source code in docprompt/provenance/util.py
def insert_generator(bboxes: List[NormBBox], data: Optional[Iterable[Any]] = None):
    """
    Make an iterator that yields (id, bbox, data) tuples for bulk insertion into
    an R-tree index; bulk loading this way is dramatically faster than
    inserting entries one at a time.
    """
    data = data or [None] * len(bboxes)

    for idx, (bbox, data_item) in enumerate(zip(bboxes, data)):
        yield (idx, construct_valid_rtree_tuple(bbox), data_item)
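
A self-contained sketch of the same bulk-loading pattern using the rtree package directly; the boxes are hypothetical, given as (x0, top, x1, bottom):

    from rtree.index import Index as RTreeIndex
    from rtree.index import Property

    boxes = [(0.1, 0.10, 0.4, 0.15), (0.1, 0.20, 0.6, 0.25)]

    def stream():
        for i, box in enumerate(boxes):
            yield (i, box, None)

    # Passing a generator to the constructor bulk-loads the tree, which is
    # far faster than calling insert() once per box.
    tree = RTreeIndex(stream(), properties=Property(fill_factor=0.9))
    print(list(tree.intersection((0.0, 0.0, 0.5, 0.18))))  # -> [0]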

preprocess_query_text(text)

Improve matching ability by applying some preprocessing to the query text.

Source code in docprompt/provenance/util.py
def preprocess_query_text(text: str) -> str:
    """
    Improve matching ability by applying some preprocessing to the query text.
    """
    for regex in _prefix_regexs:
        text = regex.sub("", text)

    text = text.strip()

    text = text.replace('"', "")

    return text
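
For example, the quote stripping keeps the text safe to embed in a tantivy phrase query (assuming none of the _prefix_regexs fire on this input):

    preprocess_query_text('  "Total amount due"  ')  # -> 'Total amount due'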

refine_block_to_word_level(source_block, intersecting_word_level_blocks, query)

Create a new text block by merging the intersecting word-level blocks that match the query.

Source code in docprompt/provenance/util.py
def refine_block_to_word_level(
    source_block: TextBlock,
    intersecting_word_level_blocks: List[TextBlock],
    query: str,
):
    """
    Create a new text block by merging the intersecting word-level blocks that
    match the query.
    """
    intersecting_word_level_blocks.sort(
        key=lambda x: (x.bounding_box.top, x.bounding_box.x0)
    )

    tokenized_query = word_tokenize(query)

    if len(tokenized_query) == 1:
        fuzzified = default_process(tokenized_query[0])
        for word_level_block in intersecting_word_level_blocks:
            if fuzz.ratio(fuzzified, default_process(word_level_block.text)) > 87.5:
                return word_level_block, [word_level_block]
    else:
        fuzzified_word_level_texts = [
            default_process(word_level_block.text)
            for word_level_block in intersecting_word_level_blocks
        ]

        # Populate the block mapping
        token_block_mapping = defaultdict(set)

        first_word = tokenized_query[0]
        last_word = tokenized_query[-1]

        for token in tokenized_query:
            fuzzified_token = default_process(token)
            for i, word_level_block in enumerate(intersecting_word_level_blocks):
                if fuzz.ratio(fuzzified_token, fuzzified_word_level_texts[i]) > 87.5:
                    token_block_mapping[token].add(i)

        graph = networkx.DiGraph()
        prev = tokenized_query[0]

        for i in token_block_mapping[prev]:
            graph.add_node(i)

        for token in tokenized_query[1:]:
            for prev_block in token_block_mapping[prev]:
                for block in sorted(token_block_mapping[token]):
                    if block > prev_block:
                        weight = (
                            (block - prev_block) ** 2
                        )  # Square the distance to penalize large jumps, which encourages reading order
                        graph.add_edge(prev_block, block, weight=weight)

            prev = token

        # Get every combination of first and last word
        first_word_blocks = token_block_mapping[first_word]
        last_word_blocks = token_block_mapping[last_word]

        combinations = sorted(
            [(x, y) for x in first_word_blocks for y in last_word_blocks if x < y],
            key=lambda x: abs(x[1] - x[0]),
        )

        for start, end in combinations:
            try:
                path = networkx.shortest_path(graph, start, end, weight="weight")
            except networkx.NetworkXNoPath:
                continue
            except Exception:
                continue

            matching_blocks = [intersecting_word_level_blocks[i] for i in path]

            merged_bbox = NormBBox.combine(
                *[word_level_block.bounding_box for word_level_block in matching_blocks]
            )

            merged_text = ""

            for word_level_block in matching_blocks:
                merged_text += word_level_block.text
                if not word_level_block.text.endswith(" "):
                    merged_text += " "  # Ensure there is a space between words

            return (
                TextBlock(
                    text=merged_text,
                    type="block",
                    bounding_box=merged_bbox,
                    metadata=source_block.metadata,
                ),
                matching_blocks,
            )
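
A toy illustration of the path construction, independent of docprompt: edges only run forward through the word indices, and the squared jump penalty makes the shortest weighted path prefer a contiguous, in-order run of words:

    import networkx

    # Query "amount due now": "amount" matches word blocks {1, 4},
    # "due" matches {5}, "now" matches {3, 6}; the real phrase sits at 4, 5, 6.
    graph = networkx.DiGraph()
    graph.add_edge(1, 5, weight=(5 - 1) ** 2)  # stray "amount", weight 16
    graph.add_edge(4, 5, weight=(5 - 4) ** 2)  # adjacent words, weight 1
    graph.add_edge(5, 6, weight=(6 - 5) ** 2)  # adjacent words, weight 1

    # No forward edge 5 -> 3 exists (3 < 5), so the (1, 3) combination yields
    # no path, and the contiguous run 4 -> 5 -> 6 is recovered.
    print(networkx.shortest_path(graph, 4, 6, weight="weight"))  # [4, 5, 6]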

word_tokenize(text)

Tokenize a string into words.

Source code in docprompt/provenance/util.py
def word_tokenize(text: str) -> List[str]:
    """
    Tokenize a string into words.
    """
    return re.split(r"\s+", text)
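
For example:

    word_tokenize("Total  amount\ndue")  # -> ['Total', 'amount', 'due']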