2025-04-27 21:22:28 +01:00
parent 05f6f149ad
commit 5399169b11
5193 changed files with 843837 additions and 0 deletions


@@ -0,0 +1,4 @@
File generated from our OpenAPI spec by Stainless.
This directory can be used to store custom files to expand the SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.


@@ -0,0 +1,2 @@
from ._tools import pydantic_function_tool as pydantic_function_tool
from ._parsing import ResponseFormatT as ResponseFormatT


@@ -0,0 +1,72 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing_extensions import override
from .._utils import LazyProxy
from .._exceptions import OpenAIError
INSTRUCTIONS = """
You tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.
You can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface.
Alternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`
A detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742
"""
class APIRemovedInV1(OpenAIError):
def __init__(self, *, symbol: str) -> None:
super().__init__(INSTRUCTIONS.format(symbol=symbol))
class APIRemovedInV1Proxy(LazyProxy[Any]):
def __init__(self, *, symbol: str) -> None:
super().__init__()
self._symbol = symbol
@override
def __load__(self) -> Any:
# return the proxy until it is eventually called so that
# we don't break people that are just checking the attributes
# of a module
return self
def __call__(self, *_args: Any, **_kwargs: Any) -> Any:
raise APIRemovedInV1(symbol=self._symbol)
SYMBOLS = [
"Edit",
"File",
"Audio",
"Image",
"Model",
"Engine",
"Customer",
"FineTune",
"Embedding",
"Completion",
"Deployment",
"Moderation",
"ErrorObject",
"FineTuningJob",
"ChatCompletion",
]
# we explicitly tell type checkers that nothing is exported
# from this file so that when we re-export the old symbols
# in `openai/__init__.py` they aren't added to the auto-complete
# suggestions given by editors
if TYPE_CHECKING:
__all__: list[str] = []
else:
__all__ = SYMBOLS
__locals = locals()
for symbol in SYMBOLS:
__locals[symbol] = APIRemovedInV1Proxy(symbol=symbol)
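
A minimal usage sketch of the proxy above (the symbol name is illustrative; in practice these proxies are re-exported from the top-level openai package): referencing a removed symbol is harmless, while calling it raises APIRemovedInV1 with the migration instructions.

proxy = APIRemovedInV1Proxy(symbol="Completion")
try:
    proxy()  # the call always raises, regardless of arguments
except APIRemovedInV1 as exc:
    print(exc)  # the INSTRUCTIONS text, formatted with symbol="Completion"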


@@ -0,0 +1,12 @@
from ._completions import (
ResponseFormatT as ResponseFormatT,
has_parseable_input,
has_parseable_input as has_parseable_input,
maybe_parse_content as maybe_parse_content,
validate_input_tools as validate_input_tools,
parse_chat_completion as parse_chat_completion,
get_input_tool_by_name as get_input_tool_by_name,
solve_response_format_t as solve_response_format_t,
parse_function_tool_arguments as parse_function_tool_arguments,
type_to_response_format_param as type_to_response_format_param,
)


@@ -0,0 +1,264 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never
import pydantic
from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
ParsedChoice,
ChatCompletion,
ParsedFunction,
ParsedChatCompletion,
ChatCompletionMessage,
ParsedFunctionToolCall,
ChatCompletionToolParam,
ParsedChatCompletionMessage,
completion_create_params,
)
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.shared_params import FunctionDefinition
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_tool_call import Function
ResponseFormatT = TypeVar(
"ResponseFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
_default_response_format: None = None
def validate_input_tools(
tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> None:
if not is_given(tools):
return
for tool in tools:
if tool["type"] != "function":
raise ValueError(
f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
)
strict = tool["function"].get("strict")
if strict is not True:
raise ValueError(
f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
)
def parse_chat_completion(
*,
response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
if is_given(input_tools):
input_tools = [t for t in input_tools]
else:
input_tools = []
choices: list[ParsedChoice[ResponseFormatT]] = []
for choice in chat_completion.choices:
if choice.finish_reason == "length":
raise LengthFinishReasonError(completion=chat_completion)
if choice.finish_reason == "content_filter":
raise ContentFilterFinishReasonError()
message = choice.message
tool_calls: list[ParsedFunctionToolCall] = []
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.type == "function":
tool_call_dict = tool_call.to_dict()
tool_calls.append(
construct_type_unchecked(
value={
**tool_call_dict,
"function": {
**cast(Any, tool_call_dict["function"]),
"parsed_arguments": parse_function_tool_arguments(
input_tools=input_tools, function=tool_call.function
),
},
},
type_=ParsedFunctionToolCall,
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call)
else:
tool_calls.append(tool_call)
choices.append(
construct_type_unchecked(
type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
value={
**choice.to_dict(),
"message": {
**message.to_dict(),
"parsed": maybe_parse_content(
response_format=response_format,
message=message,
),
"tool_calls": tool_calls if tool_calls else None,
},
},
)
)
return cast(
ParsedChatCompletion[ResponseFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
value={
**chat_completion.to_dict(),
"choices": choices,
},
),
)
def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
return next((t for t in input_tools if t.get("function", {}).get("name") == name), None)
def parse_function_tool_arguments(
*, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
) -> object:
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
if not input_tool:
return None
input_fn = cast(object, input_tool.get("function"))
if isinstance(input_fn, PydanticFunctionTool):
return model_parse_json(input_fn.model, function.arguments)
input_fn = cast(FunctionDefinition, input_fn)
if not input_fn.get("strict"):
return None
return json.loads(function.arguments)
def maybe_parse_content(
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
if has_rich_response_format(response_format) and message.content and not message.refusal:
return _parse_content(response_format, message.content)
return None
def solve_response_format_t(
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
"""Return the runtime type for the given response format.
If no response format is given, or if we won't auto-parse the response format
then we default to `None`.
"""
if has_rich_response_format(response_format):
return response_format
return cast("type[ResponseFormatT]", _default_response_format)
def has_parseable_input(
*,
response_format: type | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
if has_rich_response_format(response_format):
return True
for input_tool in input_tools or []:
if is_parseable_tool(input_tool):
return True
return False
def has_rich_response_format(
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
if not is_given(response_format):
return False
if is_response_format_param(response_format):
return False
return True
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
return is_dict(response_format)
def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
input_fn = cast(object, input_tool.get("function"))
if isinstance(input_fn, PydanticFunctionTool):
return True
return cast(FunctionDefinition, input_fn).get("strict") or False
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
if is_basemodel_type(response_format):
return cast(ResponseFormatT, model_parse_json(response_format, content))
if is_dataclass_like_type(response_format):
if not PYDANTIC_V2:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
return pydantic.TypeAdapter(response_format).validate_json(content)
raise TypeError(f"Unable to automatically parse response format type {response_format}")
def type_to_response_format_param(
response_format: type | completion_create_params.ResponseFormat | NotGiven,
) -> ResponseFormatParam | NotGiven:
if not is_given(response_format):
return NOT_GIVEN
if is_response_format_param(response_format):
return response_format
# type checkers don't narrow the negation of a `TypeGuard` as it isn't
# a safe default behaviour but we know that at this point the `response_format`
# can only be a `type`
response_format = cast(type, response_format)
json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
if is_basemodel_type(response_format):
name = response_format.__name__
json_schema_type = response_format
elif is_dataclass_like_type(response_format):
name = response_format.__name__
json_schema_type = pydantic.TypeAdapter(response_format)
else:
raise TypeError(f"Unsupported response_format type - {response_format}")
return {
"type": "json_schema",
"json_schema": {
"schema": to_strict_json_schema(json_schema_type),
"name": name,
"strict": True,
},
}
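
A minimal sketch of how `type_to_response_format_param` above turns a Pydantic model into the strict `json_schema` response format param; the model and its fields are illustrative only.

import pydantic

class Weather(pydantic.BaseModel):
    city: str
    temperature_c: float

param = type_to_response_format_param(Weather)
# {"type": "json_schema",
#  "json_schema": {"schema": {...strict JSON schema...}, "name": "Weather", "strict": True}}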


@@ -0,0 +1,168 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, List, Iterable, cast
from typing_extensions import TypeVar, assert_never
import pydantic
from .._tools import ResponsesPydanticFunctionTool
from ..._types import NotGiven
from ..._utils import is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from ._completions import solve_response_format_t, type_to_response_format_param
from ...types.responses import (
Response,
ToolParam,
ParsedContent,
ParsedResponse,
FunctionToolParam,
ParsedResponseOutputItem,
ParsedResponseOutputText,
ResponseFunctionToolCall,
ParsedResponseOutputMessage,
ResponseFormatTextConfigParam,
ParsedResponseFunctionToolCall,
)
from ...types.chat.completion_create_params import ResponseFormat
TextFormatT = TypeVar(
"TextFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
response_format_dict = type_to_response_format_param(type_)
assert is_given(response_format_dict)
response_format_dict = cast(ResponseFormat, response_format_dict) # pyright: ignore[reportUnnecessaryCast]
assert response_format_dict["type"] == "json_schema"
assert "schema" in response_format_dict["json_schema"]
return {
"type": "json_schema",
"strict": True,
"name": response_format_dict["json_schema"]["name"],
"schema": response_format_dict["json_schema"]["schema"],
}
def parse_response(
*,
text_format: type[TextFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven | None,
response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
solved_t = solve_response_format_t(text_format)
output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
for output in response.output:
if output.type == "message":
content_list: List[ParsedContent[TextFormatT]] = []
for item in output.content:
if item.type != "output_text":
content_list.append(item)
continue
content_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputText)[solved_t],
value={
**item.to_dict(),
"parsed": parse_text(item.text, text_format=text_format),
},
)
)
output_list.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
value={
**output.to_dict(),
"content": content_list,
},
)
)
elif output.type == "function_call":
output_list.append(
construct_type_unchecked(
type_=ParsedResponseFunctionToolCall,
value={
**output.to_dict(),
"parsed_arguments": parse_function_tool_arguments(
input_tools=input_tools, function_call=output
),
},
)
)
elif (
output.type == "computer_call"
or output.type == "file_search_call"
or output.type == "web_search_call"
or output.type == "reasoning"
):
output_list.append(output)
elif TYPE_CHECKING: # type: ignore
assert_never(output)
else:
output_list.append(output)
return cast(
ParsedResponse[TextFormatT],
construct_type_unchecked(
type_=cast(Any, ParsedResponse)[solved_t],
value={
**response.to_dict(),
"output": output_list,
},
),
)
def parse_text(text: str, text_format: type[TextFormatT] | NotGiven) -> TextFormatT | None:
if not is_given(text_format):
return None
if is_basemodel_type(text_format):
return cast(TextFormatT, model_parse_json(text_format, text))
if is_dataclass_like_type(text_format):
if not PYDANTIC_V2:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")
return pydantic.TypeAdapter(text_format).validate_json(text)
raise TypeError(f"Unable to automatically parse response format type {text_format}")
def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
for tool in input_tools:
if tool["type"] == "function" and tool.get("name") == name:
return tool
return None
def parse_function_tool_arguments(
*,
input_tools: Iterable[ToolParam] | NotGiven | None,
function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
) -> object:
if input_tools is None or not is_given(input_tools):
return None
input_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
if not input_tool:
return None
tool = cast(object, input_tool)
if isinstance(tool, ResponsesPydanticFunctionTool):
return model_parse_json(tool.model, function_call.arguments)
if not input_tool.get("strict"):
return None
return json.loads(function_call.arguments)
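
A minimal sketch of `parse_text` and `type_to_text_format_param` above on an illustrative Pydantic model; the JSON payload is made up for the example.

import pydantic

class Step(pydantic.BaseModel):
    explanation: str
    output: str

parsed = parse_text('{"explanation": "2 + 2", "output": "4"}', text_format=Step)
assert isinstance(parsed, Step) and parsed.output == "4"

text_format = type_to_text_format_param(Step)
# {"type": "json_schema", "strict": True, "name": "Step", "schema": {...}}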


@@ -0,0 +1,155 @@
from __future__ import annotations
import inspect
from typing import Any, TypeVar
from typing_extensions import TypeGuard
import pydantic
from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
from .._compat import PYDANTIC_V2, model_json_schema
_T = TypeVar("_T")
def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
if inspect.isclass(model) and is_basemodel_type(model):
schema = model_json_schema(model)
elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
schema = model.json_schema()
else:
raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")
return _ensure_strict_json_schema(schema, path=(), root=schema)
def _ensure_strict_json_schema(
json_schema: object,
*,
path: tuple[str, ...],
root: dict[str, object],
) -> dict[str, Any]:
"""Mutates the given JSON schema to ensure it conforms to the `strict` standard
that the API expects.
"""
if not is_dict(json_schema):
raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
defs = json_schema.get("$defs")
if is_dict(defs):
for def_name, def_schema in defs.items():
_ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)
definitions = json_schema.get("definitions")
if is_dict(definitions):
for definition_name, definition_schema in definitions.items():
_ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)
typ = json_schema.get("type")
if typ == "object" and "additionalProperties" not in json_schema:
json_schema["additionalProperties"] = False
# object types
# { 'type': 'object', 'properties': { 'a': {...} } }
properties = json_schema.get("properties")
if is_dict(properties):
json_schema["required"] = [prop for prop in properties.keys()]
json_schema["properties"] = {
key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
for key, prop_schema in properties.items()
}
# arrays
# { 'type': 'array', 'items': {...} }
items = json_schema.get("items")
if is_dict(items):
json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
# unions
any_of = json_schema.get("anyOf")
if is_list(any_of):
json_schema["anyOf"] = [
_ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
for i, variant in enumerate(any_of)
]
# intersections
all_of = json_schema.get("allOf")
if is_list(all_of):
if len(all_of) == 1:
json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
json_schema.pop("allOf")
else:
json_schema["allOf"] = [
_ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
for i, entry in enumerate(all_of)
]
# strip `None` defaults as there's no meaningful distinction here
# the schema will still be `nullable` and the model will default
# to using `None` anyway
if json_schema.get("default", NOT_GIVEN) is None:
json_schema.pop("default")
# we can't use `$ref`s if there are also other properties defined, e.g.
# `{"$ref": "...", "description": "my description"}`
#
# so we unravel the ref
# `{"type": "string", "description": "my description"}`
ref = json_schema.get("$ref")
if ref and has_more_than_n_keys(json_schema, 1):
assert isinstance(ref, str), f"Received non-string $ref - {ref}"
resolved = resolve_ref(root=root, ref=ref)
if not is_dict(resolved):
raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")
# properties from the json schema take priority over the ones on the `$ref`
json_schema.update({**resolved, **json_schema})
json_schema.pop("$ref")
# Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
# we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
return _ensure_strict_json_schema(json_schema, path=path, root=root)
return json_schema
def resolve_ref(*, root: dict[str, object], ref: str) -> object:
if not ref.startswith("#/"):
raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
path = ref[2:].split("/")
resolved = root
for key in path:
value = resolved[key]
assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
resolved = value
return resolved
def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
if not inspect.isclass(typ):
return False
return issubclass(typ, pydantic.BaseModel)
def is_dataclass_like_type(typ: type) -> bool:
"""Returns True if the given type likely used `@pydantic.dataclass`"""
return hasattr(typ, "__pydantic_config__")
def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
# just pretend that we know there are only `str` keys
# as that check is not worth the performance cost
return _is_dict(obj)
def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
i = 0
for _ in obj.keys():
i += 1
if i > n:
return True
return False
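
A minimal sketch of `to_strict_json_schema` above on an illustrative model: every property becomes required, `additionalProperties` is set to false, and `None` defaults are stripped.

from typing import Optional
import pydantic

class Address(pydantic.BaseModel):
    street: str
    city: Optional[str] = None  # the `None` default is removed from the schema

schema = to_strict_json_schema(Address)
assert schema["additionalProperties"] is False
assert schema["required"] == ["street", "city"]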


@@ -0,0 +1,66 @@
from __future__ import annotations
from typing import Any, Dict, cast
import pydantic
from ._pydantic import to_strict_json_schema
from ..types.chat import ChatCompletionToolParam
from ..types.shared_params import FunctionDefinition
from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam
class PydanticFunctionTool(Dict[str, Any]):
"""Dictionary wrapper so we can pass the given base model
throughout the entire request stack without having to special
case it.
"""
model: type[pydantic.BaseModel]
def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
super().__init__(defn)
self.model = model
def cast(self) -> FunctionDefinition:
return cast(FunctionDefinition, self)
class ResponsesPydanticFunctionTool(Dict[str, Any]):
model: type[pydantic.BaseModel]
def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
super().__init__(tool)
self.model = model
def cast(self) -> ResponsesFunctionToolParam:
return cast(ResponsesFunctionToolParam, self)
def pydantic_function_tool(
model: type[pydantic.BaseModel],
*,
name: str | None = None, # inferred from class name by default
description: str | None = None, # inferred from class docstring by default
) -> ChatCompletionToolParam:
if description is None:
# note: we intentionally don't use `.getdoc()` to avoid
# including pydantic's docstrings
description = model.__doc__
function = PydanticFunctionTool(
{
"name": name or model.__name__,
"strict": True,
"parameters": to_strict_json_schema(model),
},
model,
).cast()
if description is not None:
function["description"] = description
return {
"type": "function",
"function": function,
}
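
A minimal sketch of `pydantic_function_tool` above with an illustrative model; the tool name defaults to the class name and the description to the class docstring.

import pydantic

class GetWeather(pydantic.BaseModel):
    """Look up the current weather for a city."""
    city: str

tool = pydantic_function_tool(GetWeather)
# {"type": "function",
#  "function": {"name": "GetWeather", "strict": True,
#               "parameters": {...strict JSON schema...},
#               "description": "Look up the current weather for a city."}}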


@@ -0,0 +1,809 @@
# pyright: basic
from __future__ import annotations
import os
import sys
from typing import Any, TypeVar, Callable, Optional, NamedTuple
from typing_extensions import TypeAlias
from .._extras import pandas as pd
class Remediation(NamedTuple):
name: str
immediate_msg: Optional[str] = None
necessary_msg: Optional[str] = None
necessary_fn: Optional[Callable[[Any], Any]] = None
optional_msg: Optional[str] = None
optional_fn: Optional[Callable[[Any], Any]] = None
error_msg: Optional[str] = None
OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")
def num_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will only print out the number of examples and recommend that the user increase the number of examples if there are fewer than 100.
"""
MIN_EXAMPLES = 100
optional_suggestion = (
""
if len(df) >= MIN_EXAMPLES
else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
)
immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
return Remediation(name="num_examples", immediate_msg=immediate_msg)
def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
"""
This validator will ensure that the necessary column is present in the dataframe.
"""
def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
cols = [c for c in df.columns if str(c).lower() == column]
df.rename(columns={cols[0]: column.lower()}, inplace=True)
return df
immediate_msg = None
necessary_fn = None
necessary_msg = None
error_msg = None
if necessary_column not in df.columns:
if necessary_column in [str(c).lower() for c in df.columns]:
def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
return lower_case_column(df, necessary_column)
necessary_fn = lower_case_column_creator
immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
necessary_msg = f"Lower case column name to `{necessary_column}`"
else:
error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"
return Remediation(
name="necessary_column",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
error_msg=error_msg,
)
def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will remove additional columns from the dataframe.
"""
additional_columns = []
necessary_msg = None
immediate_msg = None
necessary_fn = None # type: ignore
if len(df.columns) > 2:
additional_columns = [c for c in df.columns if c not in fields]
warn_message = ""
for ac in additional_columns:
dups = [c for c in additional_columns if ac in c]
if len(dups) > 0:
warn_message += f"\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
necessary_msg = f"Remove additional columns/keys: {additional_columns}"
def necessary_fn(x: Any) -> Any:
return x[fields]
return Remediation(
name="additional_column",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
)
def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
"""
This validator will ensure that no completion is empty.
"""
necessary_msg = None
necessary_fn = None # type: ignore
immediate_msg = None
if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
empty_rows = (df[field] == "") | (df[field].isnull())
empty_indexes = df.reset_index().index[empty_rows].tolist()
immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"
def necessary_fn(x: Any) -> Any:
return x[x[field] != ""].dropna(subset=[field])
necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"
return Remediation(
name=f"empty_{field}",
immediate_msg=immediate_msg,
necessary_msg=necessary_msg,
necessary_fn=necessary_fn,
)
def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
"""
This validator will suggest to the user to remove duplicate rows if they exist.
"""
duplicated_rows = df.duplicated(subset=fields)
duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
if len(duplicated_indexes) > 0:
immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"
def optional_fn(x: Any) -> Any:
return x.drop_duplicates(subset=fields)
return Remediation(
name="duplicated_rows",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def long_examples_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to the user to remove examples that are too long.
"""
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type != "open-ended generation":
def get_long_indexes(d: pd.DataFrame) -> Any:
long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
return d.reset_index().index[long_examples].tolist()
long_indexes = get_long_indexes(df)
if len(long_indexes) > 0:
immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
optional_msg = f"Remove {len(long_indexes)} long examples"
def optional_fn(x: Any) -> Any:
long_indexes_to_drop = get_long_indexes(x)
if long_indexes != long_indexes_to_drop:
sys.stdout.write(
f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
)
return x.drop(long_indexes_to_drop)
return Remediation(
name="long_examples",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
# Find a suffix which is not contained within the prompt otherwise
suggested_suffix = "\n\n### =>\n\n"
suffix_options = [
" ->",
"\n\n###\n\n",
"\n\n===\n\n",
"\n\n---\n\n",
"\n\n===>\n\n",
"\n\n--->\n\n",
]
for suffix_option in suffix_options:
if suffix_option == " ->":
if df.prompt.str.contains("\n").any():
continue
if df.prompt.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
ft_type = infer_task_type(df)
if ft_type == "open-ended generation":
return Remediation(name="common_suffix")
def add_suffix(x: Any, suffix: Any) -> Any:
x["prompt"] += suffix
return x
common_suffix = get_common_xfix(df.prompt, xfix="suffix")
if (df.prompt == common_suffix).all():
error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
return Remediation(name="common_suffix", error_msg=error_msg)
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"
else:
immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"
if common_suffix == "":
optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"
def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
name="common_completion_suffix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
error_msg=error_msg,
)
def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to remove a common prefix from the prompt if a long one exists.
"""
MAX_PREFIX_LEN = 12
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
common_prefix = get_common_xfix(df.prompt, xfix="prefix")
if common_prefix == "":
return Remediation(name="common_prefix")
def remove_common_prefix(x: Any, prefix: Any) -> Any:
x["prompt"] = x["prompt"].str[len(prefix) :]
return x
if (df.prompt == common_prefix).all():
# already handled by common_suffix_validator
return Remediation(name="common_prefix")
if common_prefix != "":
immediate_msg = f"\n- All prompts start with prefix `{common_prefix}`"
if MAX_PREFIX_LEN < len(common_prefix):
immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
optional_msg = f"Remove prefix `{common_prefix}` from all prompts"
def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix)
return Remediation(
name="common_prompt_prefix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to remove a common prefix from the completion if a long one exists.
"""
MAX_PREFIX_LEN = 5
common_prefix = get_common_xfix(df.completion, xfix="prefix")
ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
if len(common_prefix) < MAX_PREFIX_LEN:
return Remediation(name="common_prefix")
def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
x["completion"] = x["completion"].str[len(prefix) :]
if ws_prefix:
# keep the single whitespace as prefix
x["completion"] = f" {x['completion']}"
return x
if (df.completion == common_prefix).all():
# already handled by common_suffix_validator
return Remediation(name="common_prefix")
immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
optional_msg = f"Remove prefix `{common_prefix}` from all completions"
def optional_fn(x: Any) -> Any:
return remove_common_prefix(x, common_prefix, ws_prefix)
return Remediation(
name="common_completion_prefix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
"""
error_msg = None
immediate_msg = None
optional_msg = None
optional_fn = None # type: ignore
ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")
common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)
# Find a suffix which is not contained within the completion otherwise
suggested_suffix = " [END]"
suffix_options = [
"\n",
".",
" END",
"***",
"+++",
"&&&",
"$$$",
"@@@",
"%%%",
]
for suffix_option in suffix_options:
if df.completion.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")
def add_suffix(x: Any, suffix: Any) -> Any:
x["completion"] += suffix
return x
if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"
else:
immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
if common_suffix == "":
optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"
def optional_fn(x: Any) -> Any:
return add_suffix(x, suggested_suffix)
return Remediation(
name="common_completion_suffix",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
error_msg=error_msg,
)
def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.
"""
def add_space_start(x: Any) -> Any:
x["completion"] = x["completion"].apply(lambda s: ("" if s.startswith(" ") else " ") + s)
return x
optional_msg = None
optional_fn = None
immediate_msg = None
if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
optional_msg = "Add a whitespace character to the beginning of the completion"
optional_fn = add_space_start
return Remediation(
name="completion_space_start",
immediate_msg=immediate_msg,
optional_msg=optional_msg,
optional_fn=optional_fn,
)
def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
"""
This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.
"""
def lower_case(x: Any) -> Any:
x[column] = x[column].str.lower()
return x
count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()
count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()
if count_upper * 2 > count_lower:
return Remediation(
name="lower_case",
immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
optional_msg=f"Lowercase all your data in column/key `{column}`",
optional_fn=lower_case,
)
return None
def read_any_format(
fname: str, fields: list[str] = ["prompt", "completion"]
) -> tuple[pd.DataFrame | None, Remediation]:
"""
This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
- for .xlsx it will read the first sheet
- for .txt it will assume completions and split on newline
"""
remediation = None
necessary_msg = None
immediate_msg = None
error_msg = None
df = None
if os.path.isfile(fname):
try:
if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
immediate_msg = (
f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
)
necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
elif fname.lower().endswith(".xlsx"):
immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
xls = pd.ExcelFile(fname)
sheets = xls.sheet_names
if len(sheets) > 1:
immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
df = pd.read_excel(fname, dtype=str).fillna("")
elif fname.lower().endswith(".txt"):
immediate_msg = "\n- Based on your file extension, you provided a text file"
necessary_msg = "Your format `TXT` will be converted to `JSONL`"
with open(fname, "r") as f:
content = f.read()
df = pd.DataFrame(
[["", line] for line in content.split("\n")],
columns=fields,
dtype=str,
).fillna("")
elif fname.lower().endswith(".jsonl"):
df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
if len(df) == 1: # type: ignore
# this is NOT what we expect for a .jsonl file
immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
necessary_msg = "Your format `JSON` will be converted to `JSONL`"
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
pass # this is what we expect for a .jsonl file
elif fname.lower().endswith(".json"):
try:
# to handle case where .json file is actually a .jsonl file
df = pd.read_json(fname, lines=True, dtype=str).fillna("") # type: ignore
if len(df) == 1: # type: ignore
# this code path corresponds to a .json file that has one line
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
# this is NOT what we expect for a .json file
immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
necessary_msg = "Your format `JSON` will be converted to `JSONL`"
except ValueError:
# this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
df = pd.read_json(fname, dtype=str).fillna("") # type: ignore
else:
error_msg = (
"Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
)
if "." in fname:
error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
else:
error_msg += f" Your file `{fname}` is missing a file extension."
except (ValueError, TypeError):
file_extension_str = fname.split(".")[-1].upper()
error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."
else:
error_msg = f"File {fname} does not exist."
remediation = Remediation(
name="read_any_format",
necessary_msg=necessary_msg,
immediate_msg=immediate_msg,
error_msg=error_msg,
)
return df, remediation
def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
"""
This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
It will also suggest to use ada and explain train/validation split benefits.
"""
ft_type = infer_task_type(df)
immediate_msg = None
if ft_type == "classification":
immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
return Remediation(name="num_examples", immediate_msg=immediate_msg)
def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
"""
This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
"""
if remediation.error_msg is not None:
sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
sys.exit(1)
if remediation.immediate_msg is not None:
sys.stdout.write(remediation.immediate_msg)
if remediation.necessary_fn is not None:
df = remediation.necessary_fn(df)
return df
def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
sys.stdout.write(input_text)
if auto_accept:
sys.stdout.write("Y\n")
return True
return input().lower() != "n"
def apply_optional_remediation(
df: pd.DataFrame, remediation: Remediation, auto_accept: bool
) -> tuple[pd.DataFrame, bool]:
"""
This function will apply an optional remediation to a dataframe, based on the user input.
"""
optional_applied = False
input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
if remediation.optional_msg is not None:
if accept_suggestion(input_text, auto_accept):
assert remediation.optional_fn is not None
df = remediation.optional_fn(df)
optional_applied = True
if remediation.necessary_msg is not None:
sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
return df, optional_applied
def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
"""
Estimate the time it'll take to fine-tune the dataset
"""
ft_format = infer_task_type(df)
expected_time = 1.0
if ft_format == "classification":
num_examples = len(df)
expected_time = num_examples * 1.44
else:
size = df.memory_usage(index=True).sum()
expected_time = size * 0.0515
def format_time(time: float) -> str:
if time < 60:
return f"{round(time, 2)} seconds"
elif time < 3600:
return f"{round(time / 60, 2)} minutes"
elif time < 86400:
return f"{round(time / 3600, 2)} hours"
else:
return f"{round(time / 86400, 2)} days"
time_string = format_time(expected_time + 140)
sys.stdout.write(
f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
)
def get_outfnames(fname: str, split: bool) -> list[str]:
suffixes = ["_train", "_valid"] if split else [""]
i = 0
while True:
index_suffix = f" ({i})" if i > 0 else ""
candidate_fnames = [f"{os.path.splitext(fname)[0]}_prepared{suffix}{index_suffix}.jsonl" for suffix in suffixes]
if not any(os.path.isfile(f) for f in candidate_fnames):
return candidate_fnames
i += 1
def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
n_classes = df.completion.nunique()
pos_class = None
if n_classes == 2:
pos_class = df.completion.value_counts().index[0]
return n_classes, pos_class
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
"""
This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
"""
ft_format = infer_task_type(df)
common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")
split = False
input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
if ft_format == "classification":
if accept_suggestion(input_text, auto_accept):
split = True
additional_params = ""
common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
optional_ending_string = (
f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts end at the expected place.'
if len(common_completion_suffix_new_line_handled) > 0
else ""
)
input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "
if not any_remediations and not split:
sys.stdout.write(
f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
)
estimate_fine_tuning_time(df)
elif accept_suggestion(input_text, auto_accept):
fnames = get_outfnames(fname, split)
if split:
assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
MAX_VALID_EXAMPLES = 1000
n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
df_train = df.sample(n=n_train, random_state=42)
df_valid = df.drop(df_train.index)
df_train[["prompt", "completion"]].to_json( # type: ignore
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)
df_valid[["prompt", "completion"]].to_json(
fnames[1], lines=True, orient="records", force_ascii=False, indent=None
)
n_classes, pos_class = get_classification_hyperparams(df)
additional_params += " --compute_classification_metrics"
if n_classes == 2:
additional_params += f' --classification_positive_class "{pos_class}"'
else:
additional_params += f" --classification_n_classes {n_classes}"
else:
assert len(fnames) == 1
df[["prompt", "completion"]].to_json(
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)
# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
valid_string = f' -v "{fnames[1]}"' if split else ""
separator_reminder = (
""
if len(common_prompt_suffix_new_line_handled) == 0
else f"After youve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
)
sys.stdout.write(
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
)
estimate_fine_tuning_time(df)
else:
sys.stdout.write("Aborting... did not write the file\n")
def infer_task_type(df: pd.DataFrame) -> str:
"""
Infer the likely fine-tuning task type from the data
"""
CLASSIFICATION_THRESHOLD = 3 # min_average instances of each class
if sum(df.prompt.str.len()) == 0:
return "open-ended generation"
if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
return "classification"
return "conditional generation"
def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
"""
Finds the longest common suffix or prefix of all the values in a series
"""
common_xfix = ""
while True:
common_xfixes = (
series.str[-(len(common_xfix) + 1) :] if xfix == "suffix" else series.str[: len(common_xfix) + 1]
) # first few or last few characters
if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore
break
elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row
break
else: # the first or last few characters are still common across all rows - let's try to add one more
common_xfix = common_xfixes.values[0]
return common_xfix
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
def get_validators() -> list[Validator]:
return [
num_examples_validator,
lambda x: necessary_column_validator(x, "prompt"),
lambda x: necessary_column_validator(x, "completion"),
additional_column_validator,
non_empty_field_validator,
format_inferrer_validator,
duplicated_rows_validator,
long_examples_validator,
lambda x: lower_case_validator(x, "prompt"),
lambda x: lower_case_validator(x, "completion"),
common_prompt_suffix_validator,
common_prompt_prefix_validator,
common_completion_prefix_validator,
common_completion_suffix_validator,
completions_space_start_validator,
]
def apply_validators(
df: pd.DataFrame,
fname: str,
remediation: Remediation | None,
validators: list[Validator],
auto_accept: bool,
write_out_file_func: Callable[..., Any],
) -> None:
optional_remediations: list[Remediation] = []
if remediation is not None:
optional_remediations.append(remediation)
for validator in validators:
remediation = validator(df)
if remediation is not None:
optional_remediations.append(remediation)
df = apply_necessary_remediation(df, remediation)
any_optional_or_necessary_remediations = any(
[
remediation
for remediation in optional_remediations
if remediation.optional_msg is not None or remediation.necessary_msg is not None
]
)
any_necessary_applied = any(
[remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
)
any_optional_applied = False
if any_optional_or_necessary_remediations:
sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
for remediation in optional_remediations:
df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
any_optional_applied = any_optional_applied or optional_applied
else:
sys.stdout.write("\n\nNo remediations found.\n")
any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied
write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)


@@ -0,0 +1,632 @@
from __future__ import annotations
import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing_extensions import Self, override
import httpx
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
from .._utils import is_given, is_mapping
from .._client import OpenAI, AsyncOpenAI
from .._compat import model_copy
from .._models import FinalRequestOptions
from .._streaming import Stream, AsyncStream
from .._exceptions import OpenAIError
from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
_deployments_endpoints = set(
[
"/completions",
"/chat/completions",
"/embeddings",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
]
)
AzureADTokenProvider = Callable[[], str]
AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
# we need to use a sentinel API key value for Azure AD
# as we don't want to make the `api_key` in the main client Optional
# and Azure AD tokens may be retrieved on a per-request basis
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
class MutuallyExclusiveAuthError(OpenAIError):
def __init__(self) -> None:
super().__init__(
"The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
)
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
_azure_endpoint: httpx.URL | None
_azure_deployment: str | None
@override
def _build_request(
self,
options: FinalRequestOptions,
*,
retries_taken: int = 0,
) -> httpx.Request:
if options.url in _deployments_endpoints and is_mapping(options.json_data):
model = options.json_data.get("model")
if model is not None and "/deployments" not in str(self.base_url.path):
options.url = f"/deployments/{model}{options.url}"
return super()._build_request(options, retries_taken=retries_taken)
@override
def _prepare_url(self, url: str) -> httpx.URL:
"""Adjust the URL if the client was configured with an Azure endpoint + deployment
and the API feature being called is **not** a deployments-based endpoint
(i.e. requires /deployments/deployment-name in the URL path).
"""
if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
merge_url = httpx.URL(url)
if merge_url.is_relative_url:
merge_raw_path = (
self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
)
return self._azure_endpoint.copy_with(raw_path=merge_raw_path)
return merge_url
return super()._prepare_url(url)
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
def __init__(
self,
*,
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
base_url: str,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None: ...
def __init__(
self,
*,
api_version: str | None = None,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
base_url: str | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.Client | None = None,
_strict_response_validation: bool = False,
) -> None:
"""Construct a new synchronous azure openai client instance.
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Args:
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
azure_ad_token: Your Azure Active Directory token; see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if azure_ad_token is None:
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
raise OpenAIError(
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
)
if api_version is None:
api_version = os.environ.get("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)
if default_query is None:
default_query = {"api-version": api_version}
else:
default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)
if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
if api_key is None:
# define a sentinel value to avoid any typing issues
api_key = API_KEY_SENTINEL
super().__init__(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
websocket_base_url=websocket_base_url,
_strict_response_validation=_strict_response_validation,
)
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
@override
def copy(
self,
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
return super().copy(
api_key=api_key,
organization=organization,
project=project,
websocket_base_url=websocket_base_url,
base_url=base_url,
timeout=timeout,
http_client=http_client,
max_retries=max_retries,
default_headers=default_headers,
set_default_headers=set_default_headers,
default_query=default_query,
set_default_query=set_default_query,
_extra_kwargs={
"api_version": api_version or self._api_version,
"azure_ad_token": azure_ad_token or self._azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
**_extra_kwargs,
},
)
with_options = copy
def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
provider = self._azure_ad_token_provider
if provider is not None:
token = provider()
if not token or not isinstance(token, str): # pyright: ignore[reportUnnecessaryIsInstance]
raise ValueError(
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
)
return token
return None
@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
options = model_copy(options)
options.headers = headers
azure_ad_token = self._get_azure_ad_token()
if azure_ad_token is not None:
if headers.get("Authorization") is None:
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif self.api_key is not API_KEY_SENTINEL:
if headers.get("api-key") is None:
headers["api-key"] = self.api_key
else:
# should never be hit
raise ValueError("Unable to handle auth")
return options
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")
url = realtime_url.copy_with(params={**query})
return url, auth_headers
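# --- Illustrative usage sketch (editorial addition, not part of the SDK) ---
# A minimal synchronous example of the client above; the endpoint, deployment,
# API version and prompt are placeholders, and the API key is assumed to come
# from the `AZURE_OPENAI_API_KEY` environment variable as documented in `__init__`.
def _example_sync_azure_usage() -> None:  # hypothetical helper, never called by the SDK
    client = AzureOpenAI(
        azure_endpoint="https://example-resource.azure.openai.com",
        azure_deployment="example-deployment",  # placeholder deployment name
        api_version="2024-06-01",  # placeholder; any supported Azure OpenAI API version
    )
    completion = client.chat.completions.create(
        model="example-deployment",  # Azure routes the request by deployment name
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(completion.choices[0].message.content)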
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
@overload
def __init__(
self,
*,
azure_endpoint: str,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
@overload
def __init__(
self,
*,
base_url: str,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None: ...
def __init__(
self,
*,
azure_endpoint: str | None = None,
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
max_retries: int = DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
http_client: httpx.AsyncClient | None = None,
_strict_response_validation: bool = False,
) -> None:
"""Construct a new asynchronous azure openai client instance.
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
- `api_key` from `AZURE_OPENAI_API_KEY`
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
- `api_version` from `OPENAI_API_VERSION`
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
Args:
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
azure_ad_token: Your Azure Active Directory token; see https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
Not supported with Assistants APIs.
"""
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if azure_ad_token is None:
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
raise OpenAIError(
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
)
if api_version is None:
api_version = os.environ.get("OPENAI_API_VERSION")
if api_version is None:
raise ValueError(
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
)
if default_query is None:
default_query = {"api-version": api_version}
else:
default_query = {**default_query, "api-version": api_version}
if base_url is None:
if azure_endpoint is None:
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
if azure_endpoint is None:
raise ValueError(
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
)
if azure_deployment is not None:
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
else:
base_url = f"{azure_endpoint.rstrip('/')}/openai"
else:
if azure_endpoint is not None:
raise ValueError("base_url and azure_endpoint are mutually exclusive")
if api_key is None:
# define a sentinel value to avoid any typing issues
api_key = API_KEY_SENTINEL
super().__init__(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=http_client,
websocket_base_url=websocket_base_url,
_strict_response_validation=_strict_response_validation,
)
self._api_version = api_version
self._azure_ad_token = azure_ad_token
self._azure_ad_token_provider = azure_ad_token_provider
self._azure_deployment = azure_deployment if azure_endpoint else None
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
@override
def copy(
self,
*,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
api_version: str | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Create a new client instance re-using the same options given to the current client with optional overriding.
"""
return super().copy(
api_key=api_key,
organization=organization,
project=project,
websocket_base_url=websocket_base_url,
base_url=base_url,
timeout=timeout,
http_client=http_client,
max_retries=max_retries,
default_headers=default_headers,
set_default_headers=set_default_headers,
default_query=default_query,
set_default_query=set_default_query,
_extra_kwargs={
"api_version": api_version or self._api_version,
"azure_ad_token": azure_ad_token or self._azure_ad_token,
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
**_extra_kwargs,
},
)
with_options = copy
async def _get_azure_ad_token(self) -> str | None:
if self._azure_ad_token is not None:
return self._azure_ad_token
provider = self._azure_ad_token_provider
if provider is not None:
token = provider()
if inspect.isawaitable(token):
token = await token
if not token or not isinstance(cast(Any, token), str):
raise ValueError(
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
)
return str(token)
return None
@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
options = model_copy(options)
options.headers = headers
azure_ad_token = await self._get_azure_ad_token()
if azure_ad_token is not None:
if headers.get("Authorization") is None:
headers["Authorization"] = f"Bearer {azure_ad_token}"
elif self.api_key is not API_KEY_SENTINEL:
if headers.get("api-key") is None:
headers["api-key"] = self.api_key
else:
# should never be hit
raise ValueError("Unable to handle auth")
return options
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
auth_headers = {}
query = {
**extra_query,
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
if token:
auth_headers = {"Authorization": f"Bearer {token}"}
if self.websocket_base_url is not None:
base_url = httpx.URL(self.websocket_base_url)
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
else:
base_url = self._prepare_url("/realtime")
realtime_url = base_url.copy_with(scheme="wss")
url = realtime_url.copy_with(params={**query})
return url, auth_headers
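# --- Illustrative usage sketch (editorial addition, not part of the SDK) ---
# A minimal async example; the endpoint, API version, deployment name and token
# value are placeholders, and a real `azure_ad_token_provider` would typically
# fetch a Microsoft Entra ID token (e.g. via the `azure-identity` package).
def _example_async_azure_usage() -> None:  # hypothetical helper, never called by the SDK
    import asyncio
    def example_token_provider() -> str:
        return "example-entra-id-token"  # placeholder token value
    client = AsyncAzureOpenAI(
        azure_endpoint="https://example-resource.azure.openai.com",
        api_version="2024-06-01",  # placeholder; any supported Azure OpenAI API version
        azure_ad_token_provider=example_token_provider,
    )
    async def main() -> None:
        completion = await client.chat.completions.create(
            model="example-deployment",  # the Azure deployment name to route to
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(completion.choices[0].message.content)
    asyncio.run(main())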

View File

@@ -0,0 +1,8 @@
from ._assistants import (
AssistantEventHandler as AssistantEventHandler,
AssistantEventHandlerT as AssistantEventHandlerT,
AssistantStreamManager as AssistantStreamManager,
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
AsyncAssistantEventHandlerT as AsyncAssistantEventHandlerT,
AsyncAssistantStreamManager as AsyncAssistantStreamManager,
)

File diff suppressed because it is too large

View File

@@ -0,0 +1,64 @@
from __future__ import annotations
from ..._utils import is_dict, is_list
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
for key, delta_value in delta.items():
if key not in acc:
acc[key] = delta_value
continue
acc_value = acc[key]
if acc_value is None:
acc[key] = delta_value
continue
# the `index` property is used in arrays of objects so it should
# not be accumulated like other values e.g.
# [{'foo': 'bar', 'index': 0}]
#
# the same applies to `type` properties as they're used for
# discriminated unions
if key == "index" or key == "type":
acc[key] = delta_value
continue
if isinstance(acc_value, str) and isinstance(delta_value, str):
acc_value += delta_value
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
acc_value += delta_value
elif is_dict(acc_value) and is_dict(delta_value):
acc_value = accumulate_delta(acc_value, delta_value)
elif is_list(acc_value) and is_list(delta_value):
# for lists of non-dictionary items we'll only ever get new entries
# in the array, existing entries will never be changed
if all(isinstance(x, (str, int, float)) for x in acc_value):
acc_value.extend(delta_value)
continue
for delta_entry in delta_value:
if not is_dict(delta_entry):
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
try:
index = delta_entry["index"]
except KeyError as exc:
raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
if not isinstance(index, int):
raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
try:
acc_entry = acc_value[index]
except IndexError:
acc_value.insert(index, delta_entry)
else:
if not is_dict(acc_entry):
raise TypeError("not handled yet")
acc_value[index] = accumulate_delta(acc_entry, delta_entry)
acc[key] = acc_value
return acc
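# --- Illustrative sketch (editorial addition, not part of the SDK) ---
# Demonstrates the accumulation rules documented above: plain strings are
# concatenated, `index` keys are overwritten, and nested dictionaries inside
# list entries are merged recursively. The values are made-up chunk fragments.
if __name__ == "__main__":
    acc: dict[object, object] = {
        "content": "Hel",
        "tool_calls": [{"index": 0, "type": "function", "function": {"arguments": '{"city": "Par'}}],
    }
    delta: dict[object, object] = {
        "content": "lo",
        "tool_calls": [{"index": 0, "function": {"arguments": 'is"}'}}],
    }
    merged = accumulate_delta(acc, delta)
    print(merged["content"])  # -> Hello
    print(merged["tool_calls"])  # -> [{'index': 0, 'type': 'function', 'function': {'arguments': '{"city": "Paris"}'}}]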

View File

@@ -0,0 +1,27 @@
from ._types import (
ParsedChoiceSnapshot as ParsedChoiceSnapshot,
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
ParsedChatCompletionMessageSnapshot as ParsedChatCompletionMessageSnapshot,
)
from ._events import (
ChunkEvent as ChunkEvent,
ContentDoneEvent as ContentDoneEvent,
RefusalDoneEvent as RefusalDoneEvent,
ContentDeltaEvent as ContentDeltaEvent,
RefusalDeltaEvent as RefusalDeltaEvent,
LogprobsContentDoneEvent as LogprobsContentDoneEvent,
LogprobsRefusalDoneEvent as LogprobsRefusalDoneEvent,
ChatCompletionStreamEvent as ChatCompletionStreamEvent,
LogprobsContentDeltaEvent as LogprobsContentDeltaEvent,
LogprobsRefusalDeltaEvent as LogprobsRefusalDeltaEvent,
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
FunctionToolCallArgumentsDoneEvent as FunctionToolCallArgumentsDoneEvent,
FunctionToolCallArgumentsDeltaEvent as FunctionToolCallArgumentsDeltaEvent,
)
from ._completions import (
ChatCompletionStream as ChatCompletionStream,
AsyncChatCompletionStream as AsyncChatCompletionStream,
ChatCompletionStreamState as ChatCompletionStreamState,
ChatCompletionStreamManager as ChatCompletionStreamManager,
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
)

View File

@@ -0,0 +1,768 @@
from __future__ import annotations
import inspect
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
from typing_extensions import Self, Iterator, assert_never
from jiter import from_json
from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
from ._events import (
ChunkEvent,
ContentDoneEvent,
RefusalDoneEvent,
ContentDeltaEvent,
RefusalDeltaEvent,
LogprobsContentDoneEvent,
LogprobsRefusalDoneEvent,
ChatCompletionStreamEvent,
LogprobsContentDeltaEvent,
LogprobsRefusalDeltaEvent,
FunctionToolCallArgumentsDoneEvent,
FunctionToolCallArgumentsDeltaEvent,
)
from .._deltas import accumulate_delta
from ...._types import NOT_GIVEN, IncEx, NotGiven
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._compat import model_dump
from ...._models import build, construct_type
from ..._parsing import (
ResponseFormatT,
has_parseable_input,
maybe_parse_content,
parse_chat_completion,
get_input_tool_by_name,
solve_response_format_t,
parse_function_tool_arguments,
)
from ...._streaming import Stream, AsyncStream
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolParam
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ....types.chat.chat_completion import ChoiceLogprobs
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
class ChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatically parsing
responses & tool calls and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
"""
def __init__(
self,
*,
raw_stream: Stream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return self._iterator.__next__()
def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for item in self._iterator:
yield item
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self._response.close()
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
"""
self.until_done()
return self._state.get_final_completion()
def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
consume_sync_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
for sse_event in self._raw_stream:
if not _is_valid_chat_completion_chunk_weak(sse_event):
continue
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `ChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
with client.beta.chat.completions.stream(...) as stream:
for event in stream:
...
```
"""
def __init__(
self,
api_request: Callable[[], Stream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
raw_stream = self.__api_request()
self.__stream = ChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
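# --- Illustrative usage sketch (editorial addition, not part of the SDK) ---
# A minimal consumer of the context manager above; `client` is assumed to be an
# `openai.OpenAI` instance and the model name and prompt are placeholders.
def _example_stream_events(client: Any) -> None:  # hypothetical helper, never called by the SDK
    with client.beta.chat.completions.stream(
        model="gpt-4o",  # placeholder model name
        messages=[{"role": "user", "content": "Write a haiku about streams."}],
    ) as stream:
        for event in stream:
            if event.type == "content.delta":
                print(event.delta, end="", flush=True)
            elif event.type == "content.done":
                print()  # the full content has now been received
        completion = stream.get_final_completion()
        # the accumulated message is also available on the final completion object
        print(completion.choices[0].message.content)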
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
"""Wrapper over the Chat Completions streaming API that adds helpful
events such as `content.done`, supports automatically parsing
responses & tool calls and accumulates a `ChatCompletion` object
from each individual chunk.
https://platform.openai.com/docs/api-reference/streaming
"""
def __init__(
self,
*,
raw_stream: AsyncStream[ChatCompletionChunk],
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for item in self._iterator:
yield item
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self._response.aclose()
async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedChatCompletion` object.
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
property will be the content deserialised into that class, if there was any content returned
by the API.
"""
await self.until_done()
return self._state.get_final_completion()
async def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
await consume_async_iterator(self)
return self
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
return self._state.current_completion_snapshot
async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
async for sse_event in self._raw_stream:
if not _is_valid_chat_completion_chunk_weak(sse_event):
continue
events_to_fire = self._state.handle_chunk(sse_event)
for event in events_to_fire:
yield event
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
"""Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.
This context manager ensures the response cannot be leaked if you don't read
the stream to completion.
Usage:
```py
async with client.beta.chat.completions.stream(...) as stream:
async for event in stream:
...
```
"""
def __init__(
self,
api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
*,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
) -> None:
self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
self.__api_request = api_request
self.__response_format = response_format
self.__input_tools = input_tools
async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
raw_stream = await self.__api_request
self.__stream = AsyncChatCompletionStream(
raw_stream=raw_stream,
response_format=self.__response_format,
input_tools=self.__input_tools,
)
return self.__stream
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
class ChatCompletionStreamState(Generic[ResponseFormatT]):
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
This is useful in cases where you can't always use the `.stream()` method, e.g.
```py
from openai.lib.streaming.chat import ChatCompletionStreamState
state = ChatCompletionStreamState()
stream = client.chat.completions.create(..., stream=True)
for chunk in stream:
state.handle_chunk(chunk)
# can also access the accumulated `ChatCompletion` mid-stream
state.current_completion_snapshot
print(state.get_final_completion())
```
"""
def __init__(
self,
*,
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
) -> None:
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
self.__choice_event_states: list[ChoiceEventState] = []
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
self._response_format = response_format
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
"""Parse the final completion object.
Note this does not provide any guarantees that the stream has actually finished; you must
only call this method when the stream is finished.
"""
return parse_chat_completion(
chat_completion=self.current_completion_snapshot,
response_format=self._rich_response_format,
input_tools=self._input_tools,
)
@property
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
assert self.__current_completion_snapshot is not None
return self.__current_completion_snapshot
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
"""Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
return self._build_events(
chunk=chunk,
completion_snapshot=self.__current_completion_snapshot,
)
def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
try:
return self.__choice_event_states[choice.index]
except IndexError:
choice_state = ChoiceEventState(input_tools=self._input_tools)
self.__choice_event_states.append(choice_state)
return choice_state
def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
completion_snapshot = self.__current_completion_snapshot
if completion_snapshot is None:
return _convert_initial_chunk_into_snapshot(chunk)
for choice in chunk.choices:
try:
choice_snapshot = completion_snapshot.choices[choice.index]
previous_tool_calls = choice_snapshot.message.tool_calls or []
choice_snapshot.message = cast(
ParsedChatCompletionMessageSnapshot,
construct_type(
type_=ParsedChatCompletionMessageSnapshot,
value=accumulate_delta(
cast(
"dict[object, object]",
model_dump(
choice_snapshot.message,
# we don't want to serialise / deserialise our custom properties
# as they won't appear in the delta and we don't want to have to
# continuously reparse the content
exclude=cast(
# cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
IncEx,
{
"parsed": True,
"tool_calls": {
idx: {"function": {"parsed_arguments": True}}
for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
},
},
),
),
),
cast("dict[object, object]", choice.delta.to_dict()),
),
),
)
# ensure tools that have already been parsed are added back into the newly
# constructed message snapshot
for tool_index, prev_tool in enumerate(previous_tool_calls):
new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
if prev_tool.type == "function":
assert new_tool.type == "function"
new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(prev_tool)
except IndexError:
choice_snapshot = cast(
ParsedChoiceSnapshot,
construct_type(
type_=ParsedChoiceSnapshot,
value={
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
},
),
)
completion_snapshot.choices.append(choice_snapshot)
if choice.finish_reason:
choice_snapshot.finish_reason = choice.finish_reason
if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
if choice.finish_reason == "length":
# at the time of writing, `.usage` will always be `None` but
# we include it here in case that is changed in the future
raise LengthFinishReasonError(completion=completion_snapshot)
if choice.finish_reason == "content_filter":
raise ContentFilterFinishReasonError()
if (
choice_snapshot.message.content
and not choice_snapshot.message.refusal
and is_given(self._rich_response_format)
):
choice_snapshot.message.parsed = from_json(
bytes(choice_snapshot.message.content, "utf-8"),
partial_mode=True,
)
for tool_call_chunk in choice.delta.tool_calls or []:
tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
if tool_call_snapshot.type == "function":
input_tool = get_input_tool_by_name(
input_tools=self._input_tools, name=tool_call_snapshot.function.name
)
if (
input_tool
and input_tool.get("function", {}).get("strict")
and tool_call_snapshot.function.arguments
):
tool_call_snapshot.function.parsed_arguments = from_json(
bytes(tool_call_snapshot.function.arguments, "utf-8"),
partial_mode=True,
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
if choice.logprobs is not None:
if choice_snapshot.logprobs is None:
choice_snapshot.logprobs = build(
ChoiceLogprobs,
content=choice.logprobs.content,
refusal=choice.logprobs.refusal,
)
else:
if choice.logprobs.content:
if choice_snapshot.logprobs.content is None:
choice_snapshot.logprobs.content = []
choice_snapshot.logprobs.content.extend(choice.logprobs.content)
if choice.logprobs.refusal:
if choice_snapshot.logprobs.refusal is None:
choice_snapshot.logprobs.refusal = []
choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
completion_snapshot.usage = chunk.usage
completion_snapshot.system_fingerprint = chunk.system_fingerprint
return completion_snapshot
def _build_events(
self,
*,
chunk: ChatCompletionChunk,
completion_snapshot: ParsedChatCompletionSnapshot,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
events_to_fire.append(
build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
)
for choice in chunk.choices:
choice_state = self._get_choice_state(choice)
choice_snapshot = completion_snapshot.choices[choice.index]
if choice.delta.content is not None and choice_snapshot.message.content is not None:
events_to_fire.append(
build(
ContentDeltaEvent,
type="content.delta",
delta=choice.delta.content,
snapshot=choice_snapshot.message.content,
parsed=choice_snapshot.message.parsed,
)
)
if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
events_to_fire.append(
build(
RefusalDeltaEvent,
type="refusal.delta",
delta=choice.delta.refusal,
snapshot=choice_snapshot.message.refusal,
)
)
if choice.delta.tool_calls:
tool_calls = choice_snapshot.message.tool_calls
assert tool_calls is not None
for tool_call_delta in choice.delta.tool_calls:
tool_call = tool_calls[tool_call_delta.index]
if tool_call.type == "function":
assert tool_call_delta.function is not None
events_to_fire.append(
build(
FunctionToolCallArgumentsDeltaEvent,
type="tool_calls.function.arguments.delta",
name=tool_call.function.name,
index=tool_call_delta.index,
arguments=tool_call.function.arguments,
parsed_arguments=tool_call.function.parsed_arguments,
arguments_delta=tool_call_delta.function.arguments or "",
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call)
if choice.logprobs is not None and choice_snapshot.logprobs is not None:
if choice.logprobs.content and choice_snapshot.logprobs.content:
events_to_fire.append(
build(
LogprobsContentDeltaEvent,
type="logprobs.content.delta",
content=choice.logprobs.content,
snapshot=choice_snapshot.logprobs.content,
),
)
if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
events_to_fire.append(
build(
LogprobsRefusalDeltaEvent,
type="logprobs.refusal.delta",
refusal=choice.logprobs.refusal,
snapshot=choice_snapshot.logprobs.refusal,
),
)
events_to_fire.extend(
choice_state.get_done_events(
choice_chunk=choice,
choice_snapshot=choice_snapshot,
response_format=self._response_format,
)
)
return events_to_fire
class ChoiceEventState:
def __init__(self, *, input_tools: list[ChatCompletionToolParam]) -> None:
self._input_tools = input_tools
self._content_done = False
self._refusal_done = False
self._logprobs_content_done = False
self._logprobs_refusal_done = False
self._done_tool_calls: set[int] = set()
self.__current_tool_call_index: int | None = None
def get_done_events(
self,
*,
choice_chunk: ChoiceChunk,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.finish_reason:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if (
self.__current_tool_call_index is not None
and self.__current_tool_call_index not in self._done_tool_calls
):
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
for tool_call in choice_chunk.delta.tool_calls or []:
if self.__current_tool_call_index != tool_call.index:
events_to_fire.extend(
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
)
if self.__current_tool_call_index is not None:
self._add_tool_done_event(
events_to_fire=events_to_fire,
choice_snapshot=choice_snapshot,
tool_index=self.__current_tool_call_index,
)
self.__current_tool_call_index = tool_call.index
return events_to_fire
def _content_done_events(
self,
*,
choice_snapshot: ParsedChoiceSnapshot,
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
if choice_snapshot.message.content and not self._content_done:
self._content_done = True
parsed = maybe_parse_content(
response_format=response_format,
message=choice_snapshot.message,
)
# update the parsed content to now use the richer `response_format`
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
choice_snapshot.message.parsed = parsed
events_to_fire.append(
build(
# we do this dance so that when the `ContentDoneEvent` instance
# is printed at runtime the class name will include the solved
# type variable, e.g. `ContentDoneEvent[MyModelType]`
cast( # pyright: ignore[reportUnnecessaryCast]
"type[ContentDoneEvent[ResponseFormatT]]",
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
),
type="content.done",
content=choice_snapshot.message.content,
parsed=parsed,
),
)
if choice_snapshot.message.refusal is not None and not self._refusal_done:
self._refusal_done = True
events_to_fire.append(
build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.content is not None
and not self._logprobs_content_done
):
self._logprobs_content_done = True
events_to_fire.append(
build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
)
if (
choice_snapshot.logprobs is not None
and choice_snapshot.logprobs.refusal is not None
and not self._logprobs_refusal_done
):
self._logprobs_refusal_done = True
events_to_fire.append(
build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
)
return events_to_fire
def _add_tool_done_event(
self,
*,
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
choice_snapshot: ParsedChoiceSnapshot,
tool_index: int,
) -> None:
if tool_index in self._done_tool_calls:
return
self._done_tool_calls.add(tool_index)
assert choice_snapshot.message.tool_calls is not None
tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
if tool_call_snapshot.type == "function":
parsed_arguments = parse_function_tool_arguments(
input_tools=self._input_tools, function=tool_call_snapshot.function
)
# update the parsed content to potentially use a richer type
# as opposed to the raw JSON-parsed object as the content is now
# complete and can be fully validated.
tool_call_snapshot.function.parsed_arguments = parsed_arguments
events_to_fire.append(
build(
FunctionToolCallArgumentsDoneEvent,
type="tool_calls.function.arguments.done",
index=tool_index,
name=tool_call_snapshot.function.name,
arguments=tool_call_snapshot.function.arguments,
parsed_arguments=parsed_arguments,
)
)
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(tool_call_snapshot)
def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
data = chunk.to_dict()
choices = cast("list[object]", data["choices"])
for choice in chunk.choices:
choices[choice.index] = {
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
"message": choice.delta.to_dict(),
}
return cast(
ParsedChatCompletionSnapshot,
construct_type(
type_=ParsedChatCompletionSnapshot,
value={
"system_fingerprint": None,
**data,
"object": "chat.completion",
},
),
)
def _is_valid_chat_completion_chunk_weak(sse_event: ChatCompletionChunk) -> bool:
# Although the `_raw_stream` is supposed to contain only objects adhering to the
# ChatCompletionChunk schema, Azure OpenAI breaks this invariant when the
# Asynchronous Filter is enabled.
# An easy filter is to check the `object` property:
# - it is "chat.completion.chunk" for a genuine ChatCompletionChunk;
# - it is an empty string for Asynchronous Filter events.
return sse_event.object == "chat.completion.chunk" # type: ignore # pylance reports this as a useless check
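# --- Illustrative usage sketch (editorial addition, not part of the SDK) ---
# A hedged sketch of the structured-output parsing path described above, using a
# pydantic model as the `response_format`; `client` is assumed to be an
# `openai.OpenAI` instance and the `Weather` model, model name and prompt are placeholders.
def _example_structured_stream(client: Any) -> None:  # hypothetical helper, never called by the SDK
    import pydantic
    class Weather(pydantic.BaseModel):
        city: str
        temperature_c: float
    with client.beta.chat.completions.stream(
        model="gpt-4o",  # placeholder model name
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        response_format=Weather,
    ) as stream:
        for event in stream:
            if event.type == "content.done":
                # once content is complete, `event.parsed` holds the validated `Weather` instance (or None)
                print(event.parsed)
        completion = stream.get_final_completion()
        print(completion.choices[0].message.parsed)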

View File

@@ -0,0 +1,123 @@
from typing import List, Union, Generic, Optional
from typing_extensions import Literal
from ._types import ParsedChatCompletionSnapshot
from ...._models import BaseModel, GenericModel
from ..._parsing import ResponseFormatT
from ....types.chat import ChatCompletionChunk, ChatCompletionTokenLogprob
class ChunkEvent(BaseModel):
type: Literal["chunk"]
chunk: ChatCompletionChunk
snapshot: ParsedChatCompletionSnapshot
class ContentDeltaEvent(BaseModel):
"""This event is yielded for every chunk with `choice.delta.content` data."""
type: Literal["content.delta"]
delta: str
snapshot: str
parsed: Optional[object] = None
class ContentDoneEvent(GenericModel, Generic[ResponseFormatT]):
type: Literal["content.done"]
content: str
parsed: Optional[ResponseFormatT] = None
class RefusalDeltaEvent(BaseModel):
type: Literal["refusal.delta"]
delta: str
snapshot: str
class RefusalDoneEvent(BaseModel):
type: Literal["refusal.done"]
refusal: str
class FunctionToolCallArgumentsDeltaEvent(BaseModel):
type: Literal["tool_calls.function.arguments.delta"]
name: str
index: int
arguments: str
"""Accumulated raw JSON string"""
parsed_arguments: object
"""The parsed arguments so far"""
arguments_delta: str
"""The JSON string delta"""
class FunctionToolCallArgumentsDoneEvent(BaseModel):
type: Literal["tool_calls.function.arguments.done"]
name: str
index: int
arguments: str
"""Accumulated raw JSON string"""
parsed_arguments: object
"""The parsed arguments"""
class LogprobsContentDeltaEvent(BaseModel):
type: Literal["logprobs.content.delta"]
content: List[ChatCompletionTokenLogprob]
snapshot: List[ChatCompletionTokenLogprob]
class LogprobsContentDoneEvent(BaseModel):
type: Literal["logprobs.content.done"]
content: List[ChatCompletionTokenLogprob]
class LogprobsRefusalDeltaEvent(BaseModel):
type: Literal["logprobs.refusal.delta"]
refusal: List[ChatCompletionTokenLogprob]
snapshot: List[ChatCompletionTokenLogprob]
class LogprobsRefusalDoneEvent(BaseModel):
type: Literal["logprobs.refusal.done"]
refusal: List[ChatCompletionTokenLogprob]
ChatCompletionStreamEvent = Union[
ChunkEvent,
ContentDeltaEvent,
ContentDoneEvent[ResponseFormatT],
RefusalDeltaEvent,
RefusalDoneEvent,
FunctionToolCallArgumentsDeltaEvent,
FunctionToolCallArgumentsDoneEvent,
LogprobsContentDeltaEvent,
LogprobsContentDoneEvent,
LogprobsRefusalDeltaEvent,
LogprobsRefusalDoneEvent,
]
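# --- Illustrative sketch (editorial addition, not part of the SDK) ---
# Shows how the discriminated union above is narrowed on its `type` literal;
# the handling below is arbitrary and purely for illustration.
def _example_handle_event(event: ChatCompletionStreamEvent) -> None:  # hypothetical helper
    if event.type == "chunk":
        _ = event.snapshot  # the accumulated completion alongside the raw chunk
    elif event.type == "tool_calls.function.arguments.delta":
        print(f"{event.name}: {event.arguments_delta}", end="")
    elif event.type == "tool_calls.function.arguments.done":
        print(f"{event.name} called with {event.parsed_arguments!r}")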

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from typing_extensions import TypeAlias
from ....types.chat import ParsedChoice, ParsedChatCompletion, ParsedChatCompletionMessage
ParsedChatCompletionSnapshot: TypeAlias = ParsedChatCompletion[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedChatCompletion` object.
"""
ParsedChatCompletionMessageSnapshot: TypeAlias = ParsedChatCompletionMessage[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedChatCompletionMessage` object.
If the content has been fully accumulated, the `.parsed` content will be
the `response_format` instance, otherwise it'll be the raw JSON parsed version.
"""
ParsedChoiceSnapshot: TypeAlias = ParsedChoice[object]

View File

@@ -0,0 +1,13 @@
from ._events import (
ResponseTextDoneEvent as ResponseTextDoneEvent,
ResponseTextDeltaEvent as ResponseTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent as ResponseFunctionCallArgumentsDeltaEvent,
)
from ._responses import (
ResponseStream as ResponseStream,
AsyncResponseStream as AsyncResponseStream,
ResponseStreamEvent as ResponseStreamEvent,
ResponseStreamState as ResponseStreamState,
ResponseStreamManager as ResponseStreamManager,
AsyncResponseStreamManager as AsyncResponseStreamManager,
)

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
from typing import Optional
from typing_extensions import Union, Generic, TypeVar, Annotated, TypeAlias
from ...._utils import PropertyInfo
from ...._compat import GenericModel
from ....types.responses import (
ParsedResponse,
ResponseErrorEvent,
ResponseFailedEvent,
ResponseCreatedEvent,
ResponseTextDoneEvent as RawResponseTextDoneEvent,
ResponseAudioDoneEvent,
ResponseCompletedEvent as RawResponseCompletedEvent,
ResponseTextDeltaEvent as RawResponseTextDeltaEvent,
ResponseAudioDeltaEvent,
ResponseIncompleteEvent,
ResponseInProgressEvent,
ResponseRefusalDoneEvent,
ResponseRefusalDeltaEvent,
ResponseOutputItemDoneEvent,
ResponseContentPartDoneEvent,
ResponseOutputItemAddedEvent,
ResponseContentPartAddedEvent,
ResponseAudioTranscriptDoneEvent,
ResponseTextAnnotationDeltaEvent,
ResponseAudioTranscriptDeltaEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallSearchingEvent,
ResponseFileSearchCallCompletedEvent,
ResponseFileSearchCallSearchingEvent,
ResponseWebSearchCallInProgressEvent,
ResponseFileSearchCallInProgressEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionCallArgumentsDeltaEvent as RawResponseFunctionCallArgumentsDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
)
TextFormatT = TypeVar(
"TextFormatT",
# if it isn't given then we don't do any parsing
default=None,
)
class ResponseTextDeltaEvent(RawResponseTextDeltaEvent):
snapshot: str
class ResponseTextDoneEvent(RawResponseTextDoneEvent, GenericModel, Generic[TextFormatT]):
parsed: Optional[TextFormatT] = None
class ResponseFunctionCallArgumentsDeltaEvent(RawResponseFunctionCallArgumentsDeltaEvent):
snapshot: str
class ResponseCompletedEvent(RawResponseCompletedEvent, GenericModel, Generic[TextFormatT]):
response: ParsedResponse[TextFormatT] # type: ignore[assignment]
ResponseStreamEvent: TypeAlias = Annotated[
Union[
# wrappers with snapshots added on
ResponseTextDeltaEvent,
ResponseTextDoneEvent[TextFormatT],
ResponseFunctionCallArgumentsDeltaEvent,
ResponseCompletedEvent[TextFormatT],
# the same as the non-accumulated API
ResponseAudioDeltaEvent,
ResponseAudioDoneEvent,
ResponseAudioTranscriptDeltaEvent,
ResponseAudioTranscriptDoneEvent,
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
ResponseCodeInterpreterCallCompletedEvent,
ResponseCodeInterpreterCallInProgressEvent,
ResponseCodeInterpreterCallInterpretingEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseErrorEvent,
ResponseFileSearchCallCompletedEvent,
ResponseFileSearchCallInProgressEvent,
ResponseFileSearchCallSearchingEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseInProgressEvent,
ResponseFailedEvent,
ResponseIncompleteEvent,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseRefusalDeltaEvent,
ResponseRefusalDoneEvent,
ResponseTextAnnotationDeltaEvent,
ResponseTextDoneEvent,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
],
PropertyInfo(discriminator="type"),
]

View File

@@ -0,0 +1,354 @@
from __future__ import annotations
import inspect
from types import TracebackType
from typing import Any, List, Generic, Iterable, Awaitable, cast
from typing_extensions import Self, Callable, Iterator, AsyncIterator
from ._types import ParsedResponseSnapshot
from ._events import (
ResponseStreamEvent,
ResponseTextDoneEvent,
ResponseCompletedEvent,
ResponseTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent,
)
from ...._types import NOT_GIVEN, NotGiven
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
from ...._models import build, construct_type_unchecked
from ...._streaming import Stream, AsyncStream
from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
from ..._parsing._responses import TextFormatT, parse_text, parse_response
from ....types.responses.tool_param import ToolParam
from ....types.responses.parsed_response import (
ParsedContent,
ParsedResponseOutputMessage,
ParsedResponseFunctionToolCall,
)
class ResponseStream(Generic[TextFormatT]):
def __init__(
self,
*,
raw_stream: Stream[RawResponseStreamEvent],
text_format: type[TextFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
def __next__(self) -> ResponseStreamEvent[TextFormatT]:
return self._iterator.__next__()
def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
for item in self._iterator:
yield item
def __enter__(self) -> Self:
return self
def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
for sse_event in self._raw_stream:
events_to_fire = self._state.handle_event(sse_event)
for event in events_to_fire:
yield event
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
self._response.close()
def get_final_response(self) -> ParsedResponse[TextFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedResponse` object.
"""
self.until_done()
response = self._state._completed_response
if not response:
raise RuntimeError("Didn't receive a `response.completed` event.")
return response
def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
consume_sync_iterator(self)
return self
class ResponseStreamManager(Generic[TextFormatT]):
def __init__(
self,
api_request: Callable[[], Stream[RawResponseStreamEvent]],
*,
text_format: type[TextFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven,
) -> None:
self.__stream: ResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
def __enter__(self) -> ResponseStream[TextFormatT]:
raw_stream = self.__api_request()
self.__stream = ResponseStream(
raw_stream=raw_stream,
text_format=self.__text_format,
input_tools=self.__input_tools,
)
return self.__stream
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()
class AsyncResponseStream(Generic[TextFormatT]):
def __init__(
self,
*,
raw_stream: AsyncStream[RawResponseStreamEvent],
text_format: type[TextFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven,
) -> None:
self._raw_stream = raw_stream
self._response = raw_stream.response
self._iterator = self.__stream__()
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
return await self._iterator.__anext__()
async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
async for item in self._iterator:
yield item
async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
async for sse_event in self._raw_stream:
events_to_fire = self._state.handle_event(sse_event)
for event in events_to_fire:
yield event
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
"""
Close the response and release the connection.
Automatically called if the response body is read to completion.
"""
await self._response.aclose()
async def get_final_response(self) -> ParsedResponse[TextFormatT]:
"""Waits until the stream has been read to completion and returns
the accumulated `ParsedResponse` object.
"""
await self.until_done()
response = self._state._completed_response
if not response:
raise RuntimeError("Didn't receive a `response.completed` event.")
return response
async def until_done(self) -> Self:
"""Blocks until the stream has been consumed."""
await consume_async_iterator(self)
return self
class AsyncResponseStreamManager(Generic[TextFormatT]):
def __init__(
self,
api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
*,
text_format: type[TextFormatT] | NotGiven,
input_tools: Iterable[ToolParam] | NotGiven,
) -> None:
self.__stream: AsyncResponseStream[TextFormatT] | None = None
self.__api_request = api_request
self.__text_format = text_format
self.__input_tools = input_tools
async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
raw_stream = await self.__api_request
self.__stream = AsyncResponseStream(
raw_stream=raw_stream,
text_format=self.__text_format,
input_tools=self.__input_tools,
)
return self.__stream
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
await self.__stream.close()
class ResponseStreamState(Generic[TextFormatT]):
def __init__(
self,
*,
input_tools: Iterable[ToolParam] | NotGiven,
text_format: type[TextFormatT] | NotGiven,
) -> None:
self.__current_snapshot: ParsedResponseSnapshot | None = None
self._completed_response: ParsedResponse[TextFormatT] | None = None
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
self._text_format = text_format
self._rich_text_format: type | NotGiven = text_format if inspect.isclass(text_format) else NOT_GIVEN
def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
self.__current_snapshot = snapshot = self.accumulate_event(event)
events: List[ResponseStreamEvent[TextFormatT]] = []
if event.type == "response.output_text.delta":
output = snapshot.output[event.output_index]
assert output.type == "message"
content = output.content[event.content_index]
assert content.type == "output_text"
events.append(
build(
ResponseTextDeltaEvent,
content_index=event.content_index,
delta=event.delta,
item_id=event.item_id,
output_index=event.output_index,
type="response.output_text.delta",
snapshot=content.text,
)
)
elif event.type == "response.output_text.done":
output = snapshot.output[event.output_index]
assert output.type == "message"
content = output.content[event.content_index]
assert content.type == "output_text"
events.append(
build(
ResponseTextDoneEvent[TextFormatT],
content_index=event.content_index,
item_id=event.item_id,
output_index=event.output_index,
type="response.output_text.done",
text=event.text,
parsed=parse_text(event.text, text_format=self._text_format),
)
)
elif event.type == "response.function_call_arguments.delta":
output = snapshot.output[event.output_index]
assert output.type == "function_call"
events.append(
build(
ResponseFunctionCallArgumentsDeltaEvent,
delta=event.delta,
item_id=event.item_id,
output_index=event.output_index,
type="response.function_call_arguments.delta",
snapshot=output.arguments,
)
)
elif event.type == "response.completed":
response = self._completed_response
assert response is not None
events.append(
build(
ResponseCompletedEvent,
type="response.completed",
response=response,
)
)
else:
events.append(event)
return events
def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
snapshot = self.__current_snapshot
if snapshot is None:
return self._create_initial_response(event)
if event.type == "response.output_item.added":
if event.item.type == "function_call":
snapshot.output.append(
construct_type_unchecked(
type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
)
)
elif event.item.type == "message":
snapshot.output.append(
construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
)
else:
snapshot.output.append(event.item)
elif event.type == "response.content_part.added":
output = snapshot.output[event.output_index]
if output.type == "message":
output.content.append(
construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
)
elif event.type == "response.output_text.delta":
output = snapshot.output[event.output_index]
if output.type == "message":
content = output.content[event.content_index]
assert content.type == "output_text"
content.text += event.delta
elif event.type == "response.function_call_arguments.delta":
output = snapshot.output[event.output_index]
if output.type == "function_call":
output.arguments += event.delta
elif event.type == "response.completed":
self._completed_response = parse_response(
text_format=self._text_format,
response=event.response,
input_tools=self._input_tools,
)
return snapshot
def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
if event.type != "response.created":
raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
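# --- Illustrative usage sketch (editorial addition, not part of the SDK) ---
# Assumes the streaming helper is exposed on the Responses resource as
# `client.responses.stream(...)`, mirroring the chat completions helper above;
# `client` is an `openai.OpenAI` instance and the model name and prompt are placeholders.
def _example_response_stream(client: Any) -> None:  # hypothetical helper, never called by the SDK
    with client.responses.stream(
        model="gpt-4o",  # placeholder model name
        input="Write a one-sentence story about streams.",
    ) as stream:
        for event in stream:
            if event.type == "response.output_text.delta":
                # `event.snapshot` holds the text accumulated so far for this content part
                print(event.delta, end="", flush=True)
        response = stream.get_final_response()
        print()
        print(response.id)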

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
from typing_extensions import TypeAlias
from ....types.responses import ParsedResponse
ParsedResponseSnapshot: TypeAlias = ParsedResponse[object]
"""Snapshot type representing an in-progress accumulation of
a `ParsedResponse` object.
"""