inference

A utility module for running inference against various LLM providers.

run_batch_inference_anthropic(model_name, messages, **kwargs) async

Run batch inference using an Anthropic model asynchronously.

Source code in docprompt/utils/inference.py
import asyncio
from typing import List

from tqdm.asyncio import tqdm

async def run_batch_inference_anthropic(
    model_name: str, messages: List[List[OpenAIMessage]], **kwargs
) -> List[str]:
    """Run batch inference using an Anthropic model asynchronously."""
    retry_decorator = get_anthropic_retry_decorator()

    @retry_decorator
    async def process_message_set(msg_set):
        return await run_inference_anthropic(model_name, msg_set, **kwargs)

    tasks = [process_message_set(msg_set) for msg_set in messages]

    # tqdm.gather shows a progress bar and, unlike asyncio.as_completed,
    # preserves submission order, so responses[i] matches messages[i].
    responses: List[str] = await tqdm.gather(*tasks, desc="Processing messages")

    return responses
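
The decorator returned by get_anthropic_retry_decorator is not shown on this page. As a rough illustration of the shape such a decorator usually takes, here is a minimal tenacity-based sketch that retries the Anthropic SDK's transient error types with exponential backoff; the exception set, backoff window, and attempt count are illustrative assumptions, not docprompt's actual implementation.

import anthropic
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

def get_anthropic_retry_decorator():
    # Sketch only: retry transient API failures with jittered
    # exponential backoff, giving up after five attempts.
    return retry(
        retry=retry_if_exception_type(
            (
                anthropic.RateLimitError,
                anthropic.APIConnectionError,
                anthropic.InternalServerError,
            )
        ),
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(5),
        reraise=True,
    )

A minimal usage sketch for the batch function follows; the OpenAIMessage(role=..., content=...) constructor shape is assumed from how the messages are consumed above.

import asyncio

message_sets = [
    [OpenAIMessage(role="user", content="Summarize page one.")],
    [OpenAIMessage(role="user", content="Summarize page two.")],
]

responses = asyncio.run(
    run_batch_inference_anthropic("claude-3-haiku-20240307", message_sets)
)
for text in responses:
    print(text)

Each message set gets its own retry wrapper, so a transient failure in one conversation is retried independently of the others.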

run_inference_anthropic(model_name, messages, **kwargs) async

Run inference using an Anthropic model asynchronously.

Source code in docprompt/utils/inference.py
import os
from typing import List

async def run_inference_anthropic(
    model_name: str, messages: List[OpenAIMessage], **kwargs
) -> str:
    """Run inference using an Anthropic model asynchronously."""
    from anthropic import AsyncAnthropic

    api_key = kwargs.pop("api_key", os.environ.get("ANTHROPIC_API_KEY"))
    client = AsyncAnthropic(api_key=api_key)

    # Anthropic takes the system prompt as a top-level parameter rather
    # than as a message, so lift a leading system message out of the list.
    system = None
    if messages and messages[0].role == "system":
        system = messages[0].content
        messages = messages[1:]

    processed_messages = []
    for msg in messages:
        if isinstance(msg.content, list):
            # Convert multimodal content parts to Anthropic's format.
            processed_content = []
            for content in msg.content:
                if isinstance(content, OpenAIComplexContent):
                    processed_content.append(content.to_anthropic_message())
                # Content parts of any other type are silently skipped.

            dumped = msg.model_dump()
            dumped["content"] = processed_content
            processed_messages.append(dumped)
        else:
            processed_messages.append(msg)

    client_kwargs = {
        "model": model_name,
        "max_tokens": 2048,  # default; callers may override via kwargs
        "messages": processed_messages,
        **kwargs,
    }

    if system:
        client_kwargs["system"] = system

    response = await client.messages.create(**client_kwargs)

    # Return the text of the first content block in the response.
    return response.content[0].text
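
A minimal end-to-end sketch for a single call; it assumes OpenAIMessage accepts role and content keyword arguments (inferred from the code above). The leading system message is lifted into Anthropic's top-level system parameter, the API key falls back to the ANTHROPIC_API_KEY environment variable, and max_tokens passed through **kwargs overrides the 2048 default.

import asyncio

messages = [
    OpenAIMessage(role="system", content="You are a terse assistant."),
    OpenAIMessage(role="user", content="Summarize this utility in one sentence."),
]

text = asyncio.run(
    run_inference_anthropic(
        "claude-3-haiku-20240307",
        messages,
        max_tokens=512,  # arrives via **kwargs and overrides the 2048 default
    )
)
print(text)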