A utility file for running inference with various LLM providers.
`async run_batch_inference_anthropic(model_name, messages, **kwargs)`

Run batch inference using an Anthropic model asynchronously.
Source code in docprompt/utils/inference.py
async def run_batch_inference_anthropic(
    model_name: str, messages: List[List[OpenAIMessage]], **kwargs
) -> List[str]:
    """Run batch inference using an Anthropic model asynchronously.

    Each message set is dispatched concurrently, and every call is wrapped in
    the shared Anthropic retry decorator so transient API failures are retried
    per message set.

    Args:
        model_name: Name of the Anthropic model to invoke.
        messages: One conversation (a list of OpenAI-style messages) per request.
        **kwargs: Extra keyword arguments forwarded to ``run_inference_anthropic``.

    Returns:
        Response strings ordered to match the input ``messages``.
    """
    retry_decorator = get_anthropic_retry_decorator()

    @retry_decorator
    async def process_message_set(index: int, msg_set):
        # Carry the input index through the task: as_completed yields futures
        # in completion order, not submission order, so the index is the only
        # way to put each result back in its input slot.
        return index, await run_inference_anthropic(model_name, msg_set, **kwargs)

    tasks = [process_message_set(i, msg_set) for i, msg_set in enumerate(messages)]

    # BUG FIX: the previous version appended results in completion order, so
    # responses did not line up with their input message sets. Pre-size the
    # result list and place each response by its original index instead.
    responses: List[str] = [""] * len(tasks)
    for future in tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Processing messages"
    ):
        index, response = await future
        responses[index] = response

    return responses
`async run_inference_anthropic(model_name, messages, **kwargs)`

Run inference using an Anthropic model asynchronously.
Source code in docprompt/utils/inference.py
async def run_inference_anthropic(
    model_name: str, messages: List[OpenAIMessage], **kwargs
) -> str:
    """Run inference using an Anthropic model asynchronously.

    Args:
        model_name: Name of the Anthropic model to invoke.
        messages: OpenAI-style messages; a leading "system" message is lifted
            into Anthropic's top-level ``system`` parameter.
        **kwargs: Forwarded to ``client.messages.create`` (``api_key`` is
            popped and used to build the client instead).

    Returns:
        The text of the first content item in the model's response.
    """
    from anthropic import AsyncAnthropic

    # Prefer an explicit api_key kwarg; fall back to the environment.
    api_key = kwargs.pop("api_key", os.environ.get("ANTHROPIC_API_KEY"))
    client = AsyncAnthropic(api_key=api_key)

    # Anthropic takes the system prompt as a top-level parameter rather than
    # a leading message, so peel it off when present.
    system = None
    if messages and messages[0].role == "system":
        system = messages[0].content
        messages = messages[1:]

    processed_messages = []
    for msg in messages:
        if not isinstance(msg.content, list):
            # Plain (non-multipart) messages pass through untouched.
            processed_messages.append(msg)
            continue

        converted = []
        for part in msg.content:
            if isinstance(part, OpenAIComplexContent):
                converted.append(part.to_anthropic_message())
            else:
                # NOTE(review): parts that are not OpenAIComplexContent are
                # silently dropped (a raise here was deliberately commented
                # out upstream) — confirm this best-effort behavior is intended.
                pass

        dumped = msg.model_dump()
        dumped["content"] = converted
        processed_messages.append(dumped)

    client_kwargs = {
        "model": model_name,
        "max_tokens": 2048,  # default; callers may override via **kwargs
        "messages": processed_messages,
        **kwargs,
    }
    if system:
        client_kwargs["system"] = system

    response = await client.messages.create(**client_kwargs)
    return response.content[0].text