This commit is contained in:

4
venv/lib/python3.11/site-packages/openai/lib/.keep
Normal file
@@ -0,0 +1,4 @@
File generated from our OpenAPI spec by Stainless.

This directory can be used to store custom files to expand the SDK.
It is ignored by Stainless code generation and its content (other than this keep file) won't be touched.
2
venv/lib/python3.11/site-packages/openai/lib/__init__.py
Normal file
@@ -0,0 +1,2 @@
from ._tools import pydantic_function_tool as pydantic_function_tool
from ._parsing import ResponseFormatT as ResponseFormatT
72
venv/lib/python3.11/site-packages/openai/lib/_old_api.py
Normal file
@@ -0,0 +1,72 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing_extensions import override

from .._utils import LazyProxy
from .._exceptions import OpenAIError

INSTRUCTIONS = """

You tried to access openai.{symbol}, but this is no longer supported in openai>=1.0.0 - see the README at https://github.com/openai/openai-python for the API.

You can run `openai migrate` to automatically upgrade your codebase to use the 1.0.0 interface.

Alternatively, you can pin your installation to the old version, e.g. `pip install openai==0.28`

A detailed migration guide is available here: https://github.com/openai/openai-python/discussions/742
"""


class APIRemovedInV1(OpenAIError):
    def __init__(self, *, symbol: str) -> None:
        super().__init__(INSTRUCTIONS.format(symbol=symbol))


class APIRemovedInV1Proxy(LazyProxy[Any]):
    def __init__(self, *, symbol: str) -> None:
        super().__init__()
        self._symbol = symbol

    @override
    def __load__(self) -> Any:
        # return the proxy until it is eventually called so that
        # we don't break people that are just checking the attributes
        # of a module
        return self

    def __call__(self, *_args: Any, **_kwargs: Any) -> Any:
        raise APIRemovedInV1(symbol=self._symbol)


SYMBOLS = [
    "Edit",
    "File",
    "Audio",
    "Image",
    "Model",
    "Engine",
    "Customer",
    "FineTune",
    "Embedding",
    "Completion",
    "Deployment",
    "Moderation",
    "ErrorObject",
    "FineTuningJob",
    "ChatCompletion",
]

# we explicitly tell type checkers that nothing is exported
# from this file so that when we re-export the old symbols
# in `openai/__init__.py` they aren't added to the auto-complete
# suggestions given by editors
if TYPE_CHECKING:
    __all__: list[str] = []
else:
    __all__ = SYMBOLS


__locals = locals()
for symbol in SYMBOLS:
    __locals[symbol] = APIRemovedInV1Proxy(symbol=symbol)
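For reference, a minimal sketch of how these proxies behave at runtime (assuming `openai>=1.0.0` is importable; the symbol used is illustrative):

import openai

proxy = openai.Completion  # attribute access returns the lazy proxy without raising
try:
    proxy()  # any call on a removed symbol raises APIRemovedInV1 with the migration instructions
except Exception as exc:
    print(type(exc).__name__, exc)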
@@ -0,0 +1,12 @@
from ._completions import (
    ResponseFormatT as ResponseFormatT,
    has_parseable_input,
    has_parseable_input as has_parseable_input,
    maybe_parse_content as maybe_parse_content,
    validate_input_tools as validate_input_tools,
    parse_chat_completion as parse_chat_completion,
    get_input_tool_by_name as get_input_tool_by_name,
    solve_response_format_t as solve_response_format_t,
    parse_function_tool_arguments as parse_function_tool_arguments,
    type_to_response_format_param as type_to_response_format_param,
)
@@ -0,0 +1,264 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Iterable, cast
from typing_extensions import TypeVar, TypeGuard, assert_never

import pydantic

from .._tools import PydanticFunctionTool
from ..._types import NOT_GIVEN, NotGiven
from ..._utils import is_dict, is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
from ...types.chat import (
    ParsedChoice,
    ChatCompletion,
    ParsedFunction,
    ParsedChatCompletion,
    ChatCompletionMessage,
    ParsedFunctionToolCall,
    ChatCompletionToolParam,
    ParsedChatCompletionMessage,
    completion_create_params,
)
from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
from ...types.shared_params import FunctionDefinition
from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
from ...types.chat.chat_completion_message_tool_call import Function

ResponseFormatT = TypeVar(
    "ResponseFormatT",
    # if it isn't given then we don't do any parsing
    default=None,
)
_default_response_format: None = None


def validate_input_tools(
    tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> None:
    if not is_given(tools):
        return

    for tool in tools:
        if tool["type"] != "function":
            raise ValueError(
                f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
            )

        strict = tool["function"].get("strict")
        if strict is not True:
            raise ValueError(
                f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
            )


def parse_chat_completion(
    *,
    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
    chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
    if is_given(input_tools):
        input_tools = [t for t in input_tools]
    else:
        input_tools = []

    choices: list[ParsedChoice[ResponseFormatT]] = []
    for choice in chat_completion.choices:
        if choice.finish_reason == "length":
            raise LengthFinishReasonError(completion=chat_completion)

        if choice.finish_reason == "content_filter":
            raise ContentFilterFinishReasonError()

        message = choice.message

        tool_calls: list[ParsedFunctionToolCall] = []
        if message.tool_calls:
            for tool_call in message.tool_calls:
                if tool_call.type == "function":
                    tool_call_dict = tool_call.to_dict()
                    tool_calls.append(
                        construct_type_unchecked(
                            value={
                                **tool_call_dict,
                                "function": {
                                    **cast(Any, tool_call_dict["function"]),
                                    "parsed_arguments": parse_function_tool_arguments(
                                        input_tools=input_tools, function=tool_call.function
                                    ),
                                },
                            },
                            type_=ParsedFunctionToolCall,
                        )
                    )
                elif TYPE_CHECKING:  # type: ignore[unreachable]
                    assert_never(tool_call)
                else:
                    tool_calls.append(tool_call)

        choices.append(
            construct_type_unchecked(
                type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
                value={
                    **choice.to_dict(),
                    "message": {
                        **message.to_dict(),
                        "parsed": maybe_parse_content(
                            response_format=response_format,
                            message=message,
                        ),
                        "tool_calls": tool_calls if tool_calls else None,
                    },
                },
            )
        )

    return cast(
        ParsedChatCompletion[ResponseFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
            value={
                **chat_completion.to_dict(),
                "choices": choices,
            },
        ),
    )


def get_input_tool_by_name(*, input_tools: list[ChatCompletionToolParam], name: str) -> ChatCompletionToolParam | None:
    return next((t for t in input_tools if t.get("function", {}).get("name") == name), None)


def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolParam], function: Function | ParsedFunction
) -> object:
    input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not input_tool:
        return None

    input_fn = cast(object, input_tool.get("function"))
    if isinstance(input_fn, PydanticFunctionTool):
        return model_parse_json(input_fn.model, function.arguments)

    input_fn = cast(FunctionDefinition, input_fn)

    if not input_fn.get("strict"):
        return None

    return json.loads(function.arguments)


def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    if has_rich_response_format(response_format) and message.content and not message.refusal:
        return _parse_content(response_format, message.content)

    return None


def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> type[ResponseFormatT]:
    """Return the runtime type for the given response format.

    If no response format is given, or if we won't auto-parse the response format
    then we default to `None`.
    """
    if has_rich_response_format(response_format):
        return response_format

    return cast("type[ResponseFormatT]", _default_response_format)


def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | NotGiven,
    input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
) -> bool:
    if has_rich_response_format(response_format):
        return True

    for input_tool in input_tools or []:
        if is_parseable_tool(input_tool):
            return True

    return False


def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
) -> TypeGuard[type[ResponseFormatT]]:
    if not is_given(response_format):
        return False

    if is_response_format_param(response_format):
        return False

    return True


def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    return is_dict(response_format)


def is_parseable_tool(input_tool: ChatCompletionToolParam) -> bool:
    input_fn = cast(object, input_tool.get("function"))
    if isinstance(input_fn, PydanticFunctionTool):
        return True

    return cast(FunctionDefinition, input_fn).get("strict") or False


def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    if is_basemodel_type(response_format):
        return cast(ResponseFormatT, model_parse_json(response_format, content))

    if is_dataclass_like_type(response_format):
        if not PYDANTIC_V2:
            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

        return pydantic.TypeAdapter(response_format).validate_json(content)

    raise TypeError(f"Unable to automatically parse response format type {response_format}")


def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | NotGiven,
) -> ResponseFormatParam | NotGiven:
    if not is_given(response_format):
        return NOT_GIVEN

    if is_response_format_param(response_format):
        return response_format

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour but we know that at this point the `response_format`
    # can only be a `type`
    response_format = cast(type, response_format)

    json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

    if is_basemodel_type(response_format):
        name = response_format.__name__
        json_schema_type = response_format
    elif is_dataclass_like_type(response_format):
        name = response_format.__name__
        json_schema_type = pydantic.TypeAdapter(response_format)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": to_strict_json_schema(json_schema_type),
            "name": name,
            "strict": True,
        },
    }
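For reference, a minimal sketch of what `type_to_response_format_param` produces for a Pydantic model (the model below is illustrative, not part of the diff):

import pydantic

class Weather(pydantic.BaseModel):
    city: str
    temperature_c: float

param = type_to_response_format_param(Weather)
# {"type": "json_schema", "json_schema": {"schema": {...}, "name": "Weather", "strict": True}}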
@@ -0,0 +1,168 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, List, Iterable, cast
from typing_extensions import TypeVar, assert_never

import pydantic

from .._tools import ResponsesPydanticFunctionTool
from ..._types import NotGiven
from ..._utils import is_given
from ..._compat import PYDANTIC_V2, model_parse_json
from ..._models import construct_type_unchecked
from .._pydantic import is_basemodel_type, is_dataclass_like_type
from ._completions import solve_response_format_t, type_to_response_format_param
from ...types.responses import (
    Response,
    ToolParam,
    ParsedContent,
    ParsedResponse,
    FunctionToolParam,
    ParsedResponseOutputItem,
    ParsedResponseOutputText,
    ResponseFunctionToolCall,
    ParsedResponseOutputMessage,
    ResponseFormatTextConfigParam,
    ParsedResponseFunctionToolCall,
)
from ...types.chat.completion_create_params import ResponseFormat

TextFormatT = TypeVar(
    "TextFormatT",
    # if it isn't given then we don't do any parsing
    default=None,
)


def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
    response_format_dict = type_to_response_format_param(type_)
    assert is_given(response_format_dict)
    response_format_dict = cast(ResponseFormat, response_format_dict)  # pyright: ignore[reportUnnecessaryCast]
    assert response_format_dict["type"] == "json_schema"
    assert "schema" in response_format_dict["json_schema"]

    return {
        "type": "json_schema",
        "strict": True,
        "name": response_format_dict["json_schema"]["name"],
        "schema": response_format_dict["json_schema"]["schema"],
    }


def parse_response(
    *,
    text_format: type[TextFormatT] | NotGiven,
    input_tools: Iterable[ToolParam] | NotGiven | None,
    response: Response | ParsedResponse[object],
) -> ParsedResponse[TextFormatT]:
    solved_t = solve_response_format_t(text_format)
    output_list: List[ParsedResponseOutputItem[TextFormatT]] = []

    for output in response.output:
        if output.type == "message":
            content_list: List[ParsedContent[TextFormatT]] = []
            for item in output.content:
                if item.type != "output_text":
                    content_list.append(item)
                    continue

                content_list.append(
                    construct_type_unchecked(
                        type_=cast(Any, ParsedResponseOutputText)[solved_t],
                        value={
                            **item.to_dict(),
                            "parsed": parse_text(item.text, text_format=text_format),
                        },
                    )
                )

            output_list.append(
                construct_type_unchecked(
                    type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
                    value={
                        **output.to_dict(),
                        "content": content_list,
                    },
                )
            )
        elif output.type == "function_call":
            output_list.append(
                construct_type_unchecked(
                    type_=ParsedResponseFunctionToolCall,
                    value={
                        **output.to_dict(),
                        "parsed_arguments": parse_function_tool_arguments(
                            input_tools=input_tools, function_call=output
                        ),
                    },
                )
            )
        elif (
            output.type == "computer_call"
            or output.type == "file_search_call"
            or output.type == "web_search_call"
            or output.type == "reasoning"
        ):
            output_list.append(output)
        elif TYPE_CHECKING:  # type: ignore
            assert_never(output)
        else:
            output_list.append(output)

    return cast(
        ParsedResponse[TextFormatT],
        construct_type_unchecked(
            type_=cast(Any, ParsedResponse)[solved_t],
            value={
                **response.to_dict(),
                "output": output_list,
            },
        ),
    )


def parse_text(text: str, text_format: type[TextFormatT] | NotGiven) -> TextFormatT | None:
    if not is_given(text_format):
        return None

    if is_basemodel_type(text_format):
        return cast(TextFormatT, model_parse_json(text_format, text))

    if is_dataclass_like_type(text_format):
        if not PYDANTIC_V2:
            raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")

        return pydantic.TypeAdapter(text_format).validate_json(text)

    raise TypeError(f"Unable to automatically parse response format type {text_format}")


def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
    for tool in input_tools:
        if tool["type"] == "function" and tool.get("name") == name:
            return tool

    return None


def parse_function_tool_arguments(
    *,
    input_tools: Iterable[ToolParam] | NotGiven | None,
    function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
) -> object:
    if input_tools is None or not is_given(input_tools):
        return None

    input_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
    if not input_tool:
        return None

    tool = cast(object, input_tool)
    if isinstance(tool, ResponsesPydanticFunctionTool):
        return model_parse_json(tool.model, function_call.arguments)

    if not input_tool.get("strict"):
        return None

    return json.loads(function_call.arguments)
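For reference, a minimal sketch of `parse_text` with a Pydantic text format (the model and JSON payload are illustrative):

import pydantic

class Answer(pydantic.BaseModel):
    value: int

parsed = parse_text('{"value": 3}', text_format=Answer)
# parsed is an Answer instance with parsed.value == 3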
155
venv/lib/python3.11/site-packages/openai/lib/_pydantic.py
Normal file
@@ -0,0 +1,155 @@
from __future__ import annotations

import inspect
from typing import Any, TypeVar
from typing_extensions import TypeGuard

import pydantic

from .._types import NOT_GIVEN
from .._utils import is_dict as _is_dict, is_list
from .._compat import PYDANTIC_V2, model_json_schema

_T = TypeVar("_T")


def to_strict_json_schema(model: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any]) -> dict[str, Any]:
    if inspect.isclass(model) and is_basemodel_type(model):
        schema = model_json_schema(model)
    elif PYDANTIC_V2 and isinstance(model, pydantic.TypeAdapter):
        schema = model.json_schema()
    else:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {model}")

    return _ensure_strict_json_schema(schema, path=(), root=schema)


def _ensure_strict_json_schema(
    json_schema: object,
    *,
    path: tuple[str, ...],
    root: dict[str, object],
) -> dict[str, Any]:
    """Mutates the given JSON schema to ensure it conforms to the `strict` standard
    that the API expects.
    """
    if not is_dict(json_schema):
        raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")

    defs = json_schema.get("$defs")
    if is_dict(defs):
        for def_name, def_schema in defs.items():
            _ensure_strict_json_schema(def_schema, path=(*path, "$defs", def_name), root=root)

    definitions = json_schema.get("definitions")
    if is_dict(definitions):
        for definition_name, definition_schema in definitions.items():
            _ensure_strict_json_schema(definition_schema, path=(*path, "definitions", definition_name), root=root)

    typ = json_schema.get("type")
    if typ == "object" and "additionalProperties" not in json_schema:
        json_schema["additionalProperties"] = False

    # object types
    # { 'type': 'object', 'properties': { 'a': {...} } }
    properties = json_schema.get("properties")
    if is_dict(properties):
        json_schema["required"] = [prop for prop in properties.keys()]
        json_schema["properties"] = {
            key: _ensure_strict_json_schema(prop_schema, path=(*path, "properties", key), root=root)
            for key, prop_schema in properties.items()
        }

    # arrays
    # { 'type': 'array', 'items': {...} }
    items = json_schema.get("items")
    if is_dict(items):
        json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)

    # unions
    any_of = json_schema.get("anyOf")
    if is_list(any_of):
        json_schema["anyOf"] = [
            _ensure_strict_json_schema(variant, path=(*path, "anyOf", str(i)), root=root)
            for i, variant in enumerate(any_of)
        ]

    # intersections
    all_of = json_schema.get("allOf")
    if is_list(all_of):
        if len(all_of) == 1:
            json_schema.update(_ensure_strict_json_schema(all_of[0], path=(*path, "allOf", "0"), root=root))
            json_schema.pop("allOf")
        else:
            json_schema["allOf"] = [
                _ensure_strict_json_schema(entry, path=(*path, "allOf", str(i)), root=root)
                for i, entry in enumerate(all_of)
            ]

    # strip `None` defaults as there's no meaningful distinction here
    # the schema will still be `nullable` and the model will default
    # to using `None` anyway
    if json_schema.get("default", NOT_GIVEN) is None:
        json_schema.pop("default")

    # we can't use `$ref`s if there are also other properties defined, e.g.
    # `{"$ref": "...", "description": "my description"}`
    #
    # so we unravel the ref
    # `{"type": "string", "description": "my description"}`
    ref = json_schema.get("$ref")
    if ref and has_more_than_n_keys(json_schema, 1):
        assert isinstance(ref, str), f"Received non-string $ref - {ref}"

        resolved = resolve_ref(root=root, ref=ref)
        if not is_dict(resolved):
            raise ValueError(f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}")

        # properties from the json schema take priority over the ones on the `$ref`
        json_schema.update({**resolved, **json_schema})
        json_schema.pop("$ref")
        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied,
        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid.
        return _ensure_strict_json_schema(json_schema, path=path, root=root)

    return json_schema


def resolve_ref(*, root: dict[str, object], ref: str) -> object:
    if not ref.startswith("#/"):
        raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")

    path = ref[2:].split("/")
    resolved = root
    for key in path:
        value = resolved[key]
        assert is_dict(value), f"encountered non-dictionary entry while resolving {ref} - {resolved}"
        resolved = value

    return resolved


def is_basemodel_type(typ: type) -> TypeGuard[type[pydantic.BaseModel]]:
    if not inspect.isclass(typ):
        return False
    return issubclass(typ, pydantic.BaseModel)


def is_dataclass_like_type(typ: type) -> bool:
    """Returns True if the given type likely used `@pydantic.dataclass`"""
    return hasattr(typ, "__pydantic_config__")


def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
    # just pretend that we know there are only `str` keys
    # as that check is not worth the performance cost
    return _is_dict(obj)


def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
    i = 0
    for _ in obj.keys():
        i += 1
        if i > n:
            return True
    return False
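For reference, a minimal sketch of the effect of `to_strict_json_schema` on a Pydantic model (the model is illustrative):

import pydantic

class Point(pydantic.BaseModel):
    x: int
    y: int

schema = to_strict_json_schema(Point)
# every object level now carries "additionalProperties": False and lists all of its
# properties under "required", as the API's `strict` mode expects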
66
venv/lib/python3.11/site-packages/openai/lib/_tools.py
Normal file
@@ -0,0 +1,66 @@
from __future__ import annotations

from typing import Any, Dict, cast

import pydantic

from ._pydantic import to_strict_json_schema
from ..types.chat import ChatCompletionToolParam
from ..types.shared_params import FunctionDefinition
from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam


class PydanticFunctionTool(Dict[str, Any]):
    """Dictionary wrapper so we can pass the given base model
    throughout the entire request stack without having to special
    case it.
    """

    model: type[pydantic.BaseModel]

    def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
        super().__init__(defn)
        self.model = model

    def cast(self) -> FunctionDefinition:
        return cast(FunctionDefinition, self)


class ResponsesPydanticFunctionTool(Dict[str, Any]):
    model: type[pydantic.BaseModel]

    def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
        super().__init__(tool)
        self.model = model

    def cast(self) -> ResponsesFunctionToolParam:
        return cast(ResponsesFunctionToolParam, self)


def pydantic_function_tool(
    model: type[pydantic.BaseModel],
    *,
    name: str | None = None,  # inferred from class name by default
    description: str | None = None,  # inferred from class docstring by default
) -> ChatCompletionToolParam:
    if description is None:
        # note: we intentionally don't use `.getdoc()` to avoid
        # including pydantic's docstrings
        description = model.__doc__

    function = PydanticFunctionTool(
        {
            "name": name or model.__name__,
            "strict": True,
            "parameters": to_strict_json_schema(model),
        },
        model,
    ).cast()

    if description is not None:
        function["description"] = description

    return {
        "type": "function",
        "function": function,
    }
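For reference, a minimal sketch of `pydantic_function_tool` (the model is illustrative):

import pydantic

class GetWeather(pydantic.BaseModel):
    """Look up the current weather for a city."""
    city: str

tool = pydantic_function_tool(GetWeather)
# {"type": "function",
#  "function": {"name": "GetWeather", "strict": True,
#               "parameters": {...strict JSON schema...},
#               "description": "Look up the current weather for a city."}}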
809
venv/lib/python3.11/site-packages/openai/lib/_validators.py
Normal file
@@ -0,0 +1,809 @@
# pyright: basic
from __future__ import annotations

import os
import sys
from typing import Any, TypeVar, Callable, Optional, NamedTuple
from typing_extensions import TypeAlias

from .._extras import pandas as pd


class Remediation(NamedTuple):
    name: str
    immediate_msg: Optional[str] = None
    necessary_msg: Optional[str] = None
    necessary_fn: Optional[Callable[[Any], Any]] = None
    optional_msg: Optional[str] = None
    optional_fn: Optional[Callable[[Any], Any]] = None
    error_msg: Optional[str] = None


OptionalDataFrameT = TypeVar("OptionalDataFrameT", bound="Optional[pd.DataFrame]")


def num_examples_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will only print out the number of examples and recommend to the user to increase the number of examples if less than 100.
    """
    MIN_EXAMPLES = 100
    optional_suggestion = (
        ""
        if len(df) >= MIN_EXAMPLES
        else ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
    )
    immediate_msg = f"\n- Your file contains {len(df)} prompt-completion pairs{optional_suggestion}"
    return Remediation(name="num_examples", immediate_msg=immediate_msg)


def necessary_column_validator(df: pd.DataFrame, necessary_column: str) -> Remediation:
    """
    This validator will ensure that the necessary column is present in the dataframe.
    """

    def lower_case_column(df: pd.DataFrame, column: Any) -> pd.DataFrame:
        cols = [c for c in df.columns if str(c).lower() == column]
        df.rename(columns={cols[0]: column.lower()}, inplace=True)
        return df

    immediate_msg = None
    necessary_fn = None
    necessary_msg = None
    error_msg = None

    if necessary_column not in df.columns:
        if necessary_column in [str(c).lower() for c in df.columns]:

            def lower_case_column_creator(df: pd.DataFrame) -> pd.DataFrame:
                return lower_case_column(df, necessary_column)

            necessary_fn = lower_case_column_creator
            immediate_msg = f"\n- The `{necessary_column}` column/key should be lowercase"
            necessary_msg = f"Lower case column name to `{necessary_column}`"
        else:
            error_msg = f"`{necessary_column}` column/key is missing. Please make sure you name your columns/keys appropriately, then retry"

    return Remediation(
        name="necessary_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
        error_msg=error_msg,
    )
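For reference, a minimal sketch of running one of these validators on a small dataframe (pandas is imported directly here for brevity; the module itself loads it lazily via `.._extras`):

import pandas as pd

df = pd.DataFrame({"prompt": ["a", "b"], "completion": [" x", " y"]})
remediation = num_examples_validator(df)
print(remediation.immediate_msg)  # notes that the file contains 2 prompt-completion pairs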
def additional_column_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
    """
    This validator will remove additional columns from the dataframe.
    """
    additional_columns = []
    necessary_msg = None
    immediate_msg = None
    necessary_fn = None  # type: ignore

    if len(df.columns) > 2:
        additional_columns = [c for c in df.columns if c not in fields]
        warn_message = ""
        for ac in additional_columns:
            dups = [c for c in additional_columns if ac in c]
            if len(dups) > 0:
                warn_message += f"\n WARNING: Some of the additional columns/keys contain `{ac}` in their name. These will be ignored, and the column/key `{ac}` will be used instead. This could also result from a duplicate column/key in the provided file."
        immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
        necessary_msg = f"Remove additional columns/keys: {additional_columns}"

        def necessary_fn(x: Any) -> Any:
            return x[fields]

    return Remediation(
        name="additional_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )


def non_empty_field_validator(df: pd.DataFrame, field: str = "completion") -> Remediation:
    """
    This validator will ensure that no completion is empty.
    """
    necessary_msg = None
    necessary_fn = None  # type: ignore
    immediate_msg = None

    if df[field].apply(lambda x: x == "").any() or df[field].isnull().any():
        empty_rows = (df[field] == "") | (df[field].isnull())
        empty_indexes = df.reset_index().index[empty_rows].tolist()
        immediate_msg = f"\n- `{field}` column/key should not contain empty strings. These are rows: {empty_indexes}"

        def necessary_fn(x: Any) -> Any:
            return x[x[field] != ""].dropna(subset=[field])

        necessary_msg = f"Remove {len(empty_indexes)} rows with empty {field}s"

    return Remediation(
        name=f"empty_{field}",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )


def duplicated_rows_validator(df: pd.DataFrame, fields: list[str] = ["prompt", "completion"]) -> Remediation:
    """
    This validator will suggest to the user to remove duplicate rows if they exist.
    """
    duplicated_rows = df.duplicated(subset=fields)
    duplicated_indexes = df.reset_index().index[duplicated_rows].tolist()
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    if len(duplicated_indexes) > 0:
        immediate_msg = f"\n- There are {len(duplicated_indexes)} duplicated {'-'.join(fields)} sets. These are rows: {duplicated_indexes}"
        optional_msg = f"Remove {len(duplicated_indexes)} duplicate rows"

        def optional_fn(x: Any) -> Any:
            return x.drop_duplicates(subset=fields)

    return Remediation(
        name="duplicated_rows",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )


def long_examples_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to the user to remove examples that are too long.
    """
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    ft_type = infer_task_type(df)
    if ft_type != "open-ended generation":

        def get_long_indexes(d: pd.DataFrame) -> Any:
            long_examples = d.apply(lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1)
            return d.reset_index().index[long_examples].tolist()

        long_indexes = get_long_indexes(df)

        if len(long_indexes) > 0:
            immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
            optional_msg = f"Remove {len(long_indexes)} long examples"

            def optional_fn(x: Any) -> Any:
                long_indexes_to_drop = get_long_indexes(x)
                if long_indexes != long_indexes_to_drop:
                    sys.stdout.write(
                        f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n"
                    )
                return x.drop(long_indexes_to_drop)

    return Remediation(
        name="long_examples",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
def common_prompt_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a common suffix to the prompt if one doesn't already exist in case of classification or conditional generation.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    # Find a suffix which is not contained within the prompt otherwise
    suggested_suffix = "\n\n### =>\n\n"
    suffix_options = [
        " ->",
        "\n\n###\n\n",
        "\n\n===\n\n",
        "\n\n---\n\n",
        "\n\n===>\n\n",
        "\n\n--->\n\n",
    ]
    for suffix_option in suffix_options:
        if suffix_option == " ->":
            if df.prompt.str.contains("\n").any():
                continue
        if df.prompt.str.contains(suffix_option, regex=False).any():
            continue
        suggested_suffix = suffix_option
        break
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    ft_type = infer_task_type(df)
    if ft_type == "open-ended generation":
        return Remediation(name="common_suffix")

    def add_suffix(x: Any, suffix: Any) -> Any:
        x["prompt"] += suffix
        return x

    common_suffix = get_common_xfix(df.prompt, xfix="suffix")
    if (df.prompt == common_suffix).all():
        error_msg = f"All prompts are identical: `{common_suffix}`\nConsider leaving the prompts blank if you want to do open-ended generation, otherwise ensure prompts are different"
        return Remediation(name="common_suffix", error_msg=error_msg)

    if common_suffix != "":
        common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
        immediate_msg = f"\n- All prompts end with suffix `{common_suffix_new_line_handled}`"
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
            immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"

    else:
        immediate_msg = "\n- Your data does not contain a common separator at the end of your prompts. Having a separator string appended to the end of the prompt makes it clearer to the fine-tuned model where the completion should begin. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples. If you intend to do open-ended generation, then you should leave the prompts empty"

    if common_suffix == "":
        optional_msg = f"Add a suffix separator `{display_suggested_suffix}` to all prompts"

        def optional_fn(x: Any) -> Any:
            return add_suffix(x, suggested_suffix)

    return Remediation(
        name="common_completion_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=error_msg,
    )


def common_prompt_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to remove a common prefix from the prompt if a long one exist.
    """
    MAX_PREFIX_LEN = 12

    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    common_prefix = get_common_xfix(df.prompt, xfix="prefix")
    if common_prefix == "":
        return Remediation(name="common_prefix")

    def remove_common_prefix(x: Any, prefix: Any) -> Any:
        x["prompt"] = x["prompt"].str[len(prefix) :]
        return x

    if (df.prompt == common_prefix).all():
        # already handled by common_suffix_validator
        return Remediation(name="common_prefix")

    if common_prefix != "":
        immediate_msg = f"\n- All prompts start with prefix `{common_prefix}`"
        if MAX_PREFIX_LEN < len(common_prefix):
            immediate_msg += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
            optional_msg = f"Remove prefix `{common_prefix}` from all prompts"

            def optional_fn(x: Any) -> Any:
                return remove_common_prefix(x, common_prefix)

    return Remediation(
        name="common_prompt_prefix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )


def common_completion_prefix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to remove a common prefix from the completion if a long one exist.
    """
    MAX_PREFIX_LEN = 5

    common_prefix = get_common_xfix(df.completion, xfix="prefix")
    ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
    if len(common_prefix) < MAX_PREFIX_LEN:
        return Remediation(name="common_prefix")

    def remove_common_prefix(x: Any, prefix: Any, ws_prefix: Any) -> Any:
        x["completion"] = x["completion"].str[len(prefix) :]
        if ws_prefix:
            # keep the single whitespace as prefix
            x["completion"] = f" {x['completion']}"
        return x

    if (df.completion == common_prefix).all():
        # already handled by common_suffix_validator
        return Remediation(name="common_prefix")

    immediate_msg = f"\n- All completions start with prefix `{common_prefix}`. Most of the time you should only add the output data into the completion, without any prefix"
    optional_msg = f"Remove prefix `{common_prefix}` from all completions"

    def optional_fn(x: Any) -> Any:
        return remove_common_prefix(x, common_prefix, ws_prefix)

    return Remediation(
        name="common_completion_prefix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )


def common_completion_suffix_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None  # type: ignore

    ft_type = infer_task_type(df)
    if ft_type == "open-ended generation" or ft_type == "classification":
        return Remediation(name="common_suffix")

    common_suffix = get_common_xfix(df.completion, xfix="suffix")
    if (df.completion == common_suffix).all():
        error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
        return Remediation(name="common_suffix", error_msg=error_msg)

    # Find a suffix which is not contained within the completion otherwise
    suggested_suffix = " [END]"
    suffix_options = [
        "\n",
        ".",
        " END",
        "***",
        "+++",
        "&&&",
        "$$$",
        "@@@",
        "%%%",
    ]
    for suffix_option in suffix_options:
        if df.completion.str.contains(suffix_option, regex=False).any():
            continue
        suggested_suffix = suffix_option
        break
    display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

    def add_suffix(x: Any, suffix: Any) -> Any:
        x["completion"] += suffix
        return x

    if common_suffix != "":
        common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
        immediate_msg = f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
        if len(common_suffix) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
        if df.completion.str[: -len(common_suffix)].str.contains(common_suffix, regex=False).any():
            immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"

    else:
        immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."

    if common_suffix == "":
        optional_msg = f"Add a suffix ending `{display_suggested_suffix}` to all completions"

        def optional_fn(x: Any) -> Any:
            return add_suffix(x, suggested_suffix)

    return Remediation(
        name="common_completion_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=error_msg,
    )
def completions_space_start_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will suggest to add a space at the start of the completion if it doesn't already exist. This helps with tokenization.
    """

    def add_space_start(x: Any) -> Any:
        x["completion"] = x["completion"].apply(lambda s: ("" if s.startswith(" ") else " ") + s)
        return x

    optional_msg = None
    optional_fn = None
    immediate_msg = None

    if df.completion.str[:1].nunique() != 1 or df.completion.values[0][0] != " ":
        immediate_msg = "\n- The completion should start with a whitespace character (` `). This tends to produce better results due to the tokenization we use. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details"
        optional_msg = "Add a whitespace character to the beginning of the completion"
        optional_fn = add_space_start
    return Remediation(
        name="completion_space_start",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )


def lower_case_validator(df: pd.DataFrame, column: Any) -> Remediation | None:
    """
    This validator will suggest to lowercase the column values, if more than a third of letters are uppercase.
    """

    def lower_case(x: Any) -> Any:
        x[column] = x[column].str.lower()
        return x

    count_upper = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.isupper())).sum()
    count_lower = df[column].apply(lambda x: sum(1 for c in x if c.isalpha() and c.islower())).sum()

    if count_upper * 2 > count_lower:
        return Remediation(
            name="lower_case",
            immediate_msg=f"\n- More than a third of your `{column}` column/key is uppercase. Uppercase {column}s tends to perform worse than a mixture of case encountered in normal language. We recommend to lower case the data if that makes sense in your domain. See https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more details",
            optional_msg=f"Lowercase all your data in column/key `{column}`",
            optional_fn=lower_case,
        )
    return None
def read_any_format(
    fname: str, fields: list[str] = ["prompt", "completion"]
) -> tuple[pd.DataFrame | None, Remediation]:
    """
    This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
    - for .xlsx it will read the first sheet
    - for .txt it will assume completions and split on newline
    """
    remediation = None
    necessary_msg = None
    immediate_msg = None
    error_msg = None
    df = None

    if os.path.isfile(fname):
        try:
            if fname.lower().endswith(".csv") or fname.lower().endswith(".tsv"):
                file_extension_str, separator = ("CSV", ",") if fname.lower().endswith(".csv") else ("TSV", "\t")
                immediate_msg = (
                    f"\n- Based on your file extension, your file is formatted as a {file_extension_str} file"
                )
                necessary_msg = f"Your format `{file_extension_str}` will be converted to `JSONL`"
                df = pd.read_csv(fname, sep=separator, dtype=str).fillna("")
            elif fname.lower().endswith(".xlsx"):
                immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
                necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
                xls = pd.ExcelFile(fname)
                sheets = xls.sheet_names
                if len(sheets) > 1:
                    immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
                df = pd.read_excel(fname, dtype=str).fillna("")
            elif fname.lower().endswith(".txt"):
                immediate_msg = "\n- Based on your file extension, you provided a text file"
                necessary_msg = "Your format `TXT` will be converted to `JSONL`"
                with open(fname, "r") as f:
                    content = f.read()
                    df = pd.DataFrame(
                        [["", line] for line in content.split("\n")],
                        columns=fields,
                        dtype=str,
                    ).fillna("")
            elif fname.lower().endswith(".jsonl"):
                df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                if len(df) == 1:  # type: ignore
                    # this is NOT what we expect for a .jsonl file
                    immediate_msg = "\n- Your JSONL file appears to be in a JSON format. Your file will be converted to JSONL format"
                    necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                    df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                else:
                    pass  # this is what we expect for a .jsonl file
            elif fname.lower().endswith(".json"):
                try:
                    # to handle case where .json file is actually a .jsonl file
                    df = pd.read_json(fname, lines=True, dtype=str).fillna("")  # type: ignore
                    if len(df) == 1:  # type: ignore
                        # this code path corresponds to a .json file that has one line
                        df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
                    else:
                        # this is NOT what we expect for a .json file
                        immediate_msg = "\n- Your JSON file appears to be in a JSONL format. Your file will be converted to JSONL format"
                        necessary_msg = "Your format `JSON` will be converted to `JSONL`"
                except ValueError:
                    # this code path corresponds to a .json file that has multiple lines (i.e. it is indented)
                    df = pd.read_json(fname, dtype=str).fillna("")  # type: ignore
            else:
                error_msg = (
                    "Your file must have one of the following extensions: .CSV, .TSV, .XLSX, .TXT, .JSON or .JSONL"
                )
                if "." in fname:
                    error_msg += f" Your file `{fname}` ends with the extension `.{fname.split('.')[-1]}` which is not supported."
                else:
                    error_msg += f" Your file `{fname}` is missing a file extension."

        except (ValueError, TypeError):
            file_extension_str = fname.split(".")[-1].upper()
            error_msg = f"Your file `{fname}` does not appear to be in valid {file_extension_str} format. Please ensure your file is formatted as a valid {file_extension_str} file."

    else:
        error_msg = f"File {fname} does not exist."

    remediation = Remediation(
        name="read_any_format",
        necessary_msg=necessary_msg,
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )
    return df, remediation
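For reference, a minimal sketch of `read_any_format` (the file name is hypothetical):

df, remediation = read_any_format("examples.jsonl")
if remediation.error_msg is None and df is not None:
    print(f"loaded {len(df)} rows")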
def format_inferrer_validator(df: pd.DataFrame) -> Remediation:
    """
    This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
    It will also suggest to use ada and explain train/validation split benefits.
    """
    ft_type = infer_task_type(df)
    immediate_msg = None
    if ft_type == "classification":
        immediate_msg = f"\n- Based on your data it seems like you're trying to fine-tune a model for {ft_type}\n- For classification, we recommend you try one of the faster and cheaper models, such as `ada`\n- For classification, you can estimate the expected model performance by keeping a held out dataset, which is not used for training"
    return Remediation(name="num_examples", immediate_msg=immediate_msg)


def apply_necessary_remediation(df: OptionalDataFrameT, remediation: Remediation) -> OptionalDataFrameT:
    """
    This function will apply a necessary remediation to a dataframe, or print an error message if one exists.
    """
    if remediation.error_msg is not None:
        sys.stderr.write(f"\n\nERROR in {remediation.name} validator: {remediation.error_msg}\n\nAborting...")
        sys.exit(1)
    if remediation.immediate_msg is not None:
        sys.stdout.write(remediation.immediate_msg)
    if remediation.necessary_fn is not None:
        df = remediation.necessary_fn(df)
    return df


def accept_suggestion(input_text: str, auto_accept: bool) -> bool:
    sys.stdout.write(input_text)
    if auto_accept:
        sys.stdout.write("Y\n")
        return True
    return input().lower() != "n"


def apply_optional_remediation(
    df: pd.DataFrame, remediation: Remediation, auto_accept: bool
) -> tuple[pd.DataFrame, bool]:
    """
    This function will apply an optional remediation to a dataframe, based on the user input.
    """
    optional_applied = False
    input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
    if remediation.optional_msg is not None:
        if accept_suggestion(input_text, auto_accept):
            assert remediation.optional_fn is not None
            df = remediation.optional_fn(df)
            optional_applied = True
    if remediation.necessary_msg is not None:
        sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
    return df, optional_applied
def estimate_fine_tuning_time(df: pd.DataFrame) -> None:
    """
    Estimate the time it'll take to fine-tune the dataset
    """
    ft_format = infer_task_type(df)
    expected_time = 1.0
    if ft_format == "classification":
        num_examples = len(df)
        expected_time = num_examples * 1.44
    else:
        size = df.memory_usage(index=True).sum()
        expected_time = size * 0.0515

    def format_time(time: float) -> str:
        if time < 60:
            return f"{round(time, 2)} seconds"
        elif time < 3600:
            return f"{round(time / 60, 2)} minutes"
        elif time < 86400:
            return f"{round(time / 3600, 2)} hours"
        else:
            return f"{round(time / 86400, 2)} days"

    time_string = format_time(expected_time + 140)
    sys.stdout.write(
        f"Once your model starts training, it'll approximately take {time_string} to train a `curie` model, and less for `ada` and `babbage`. Queue will approximately take half an hour per job ahead of you.\n"
    )


def get_outfnames(fname: str, split: bool) -> list[str]:
    suffixes = ["_train", "_valid"] if split else [""]
    i = 0
    while True:
        index_suffix = f" ({i})" if i > 0 else ""
        candidate_fnames = [f"{os.path.splitext(fname)[0]}_prepared{suffix}{index_suffix}.jsonl" for suffix in suffixes]
        if not any(os.path.isfile(f) for f in candidate_fnames):
            return candidate_fnames
        i += 1


def get_classification_hyperparams(df: pd.DataFrame) -> tuple[int, object]:
    n_classes = df.completion.nunique()
    pos_class = None
    if n_classes == 2:
        pos_class = df.completion.value_counts().index[0]
    return n_classes, pos_class
def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_accept: bool) -> None:
    """
    This function will write out a dataframe to a file, if the user would like to proceed, and also offer a fine-tuning command with the newly created file.
    For classification it will optionally ask the user if they would like to split the data into train/valid files, and modify the suggested command to include the valid set.
    """
    ft_format = infer_task_type(df)
    common_prompt_suffix = get_common_xfix(df.prompt, xfix="suffix")
    common_completion_suffix = get_common_xfix(df.completion, xfix="suffix")

    split = False
    input_text = "- [Recommended] Would you like to split into training and validation set? [Y/n]: "
    if ft_format == "classification":
        if accept_suggestion(input_text, auto_accept):
            split = True

    additional_params = ""
    common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n")
    common_completion_suffix_new_line_handled = common_completion_suffix.replace("\n", "\\n")
    optional_ending_string = (
        f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.'
        if len(common_completion_suffix_new_line_handled) > 0
        else ""
    )

    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

    if not any_remediations and not split:
        sys.stdout.write(
            f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)

    elif accept_suggestion(input_text, auto_accept):
        fnames = get_outfnames(fname, split)
        if split:
            assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1]
            MAX_VALID_EXAMPLES = 1000
            n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8))
            df_train = df.sample(n=n_train, random_state=42)
            df_valid = df.drop(df_train.index)
            df_train[["prompt", "completion"]].to_json(  # type: ignore
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )
            df_valid[["prompt", "completion"]].to_json(
                fnames[1], lines=True, orient="records", force_ascii=False, indent=None
            )

            n_classes, pos_class = get_classification_hyperparams(df)
            additional_params += " --compute_classification_metrics"
            if n_classes == 2:
                additional_params += f' --classification_positive_class "{pos_class}"'
            else:
                additional_params += f" --classification_n_classes {n_classes}"
        else:
            assert len(fnames) == 1
            df[["prompt", "completion"]].to_json(
                fnames[0], lines=True, orient="records", force_ascii=False, indent=None
            )

        # Add -v VALID_FILE if we split the file into train / valid
        files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
        valid_string = f' -v "{fnames[1]}"' if split else ""
        separator_reminder = (
            ""
            if len(common_prompt_suffix_new_line_handled) == 0
            else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
        )
        sys.stdout.write(
            f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{fnames[0]}"{valid_string}{additional_params}\n\n{separator_reminder}{optional_ending_string}\n'
        )
        estimate_fine_tuning_time(df)
    else:
        sys.stdout.write("Aborting... did not write the file\n")
def infer_task_type(df: pd.DataFrame) -> str:
|
||||
"""
|
||||
Infer the likely fine-tuning task type from the data
|
||||
"""
|
||||
CLASSIFICATION_THRESHOLD = 3  # minimum average number of instances of each class
|
||||
if sum(df.prompt.str.len()) == 0:
|
||||
return "open-ended generation"
|
||||
|
||||
if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
|
||||
return "classification"
|
||||
|
||||
return "conditional generation"
|
||||
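A minimal illustration of the rule above on hypothetical data; pandas is already a dependency of this module.

import pandas as pd

df = pd.DataFrame(
    {
        "prompt": ["great movie", "terrible plot", "loved it"] * 3,
        "completion": [" positive", " negative", " positive"] * 3,
    }
)
# 2 unique completions over 9 rows is below len(df) / CLASSIFICATION_THRESHOLD,
# so the heuristic reports a classification task.
print(infer_task_type(df))  # -> "classification"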
|
||||
|
||||
def get_common_xfix(series: Any, xfix: str = "suffix") -> str:
|
||||
"""
|
||||
Finds the longest common suffix or prefix of all the values in a series
|
||||
"""
|
||||
common_xfix = ""
|
||||
while True:
|
||||
common_xfixes = (
|
||||
series.str[-(len(common_xfix) + 1) :] if xfix == "suffix" else series.str[: len(common_xfix) + 1]
|
||||
) # first few or last few characters
|
||||
if common_xfixes.nunique() != 1: # we found the character at which we don't have a unique xfix anymore
|
||||
break
|
||||
elif common_xfix == common_xfixes.values[0]: # the entire first row is a prefix of every other row
|
||||
break
|
||||
else: # the first or last few characters are still common across all rows - let's try to add one more
|
||||
common_xfix = common_xfixes.values[0]
|
||||
return common_xfix
|
||||
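A short sketch of the suffix search on a toy series (hypothetical values): the loop grows the candidate one character at a time until it is no longer shared by every row.

import pandas as pd

completions = pd.Series([" positive ->", " negative ->", " neutral ->"])
print(repr(get_common_xfix(completions, xfix="suffix")))  # -> ' ->'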
|
||||
|
||||
Validator: TypeAlias = "Callable[[pd.DataFrame], Remediation | None]"
|
||||
|
||||
|
||||
def get_validators() -> list[Validator]:
|
||||
return [
|
||||
num_examples_validator,
|
||||
lambda x: necessary_column_validator(x, "prompt"),
|
||||
lambda x: necessary_column_validator(x, "completion"),
|
||||
additional_column_validator,
|
||||
non_empty_field_validator,
|
||||
format_inferrer_validator,
|
||||
duplicated_rows_validator,
|
||||
long_examples_validator,
|
||||
lambda x: lower_case_validator(x, "prompt"),
|
||||
lambda x: lower_case_validator(x, "completion"),
|
||||
common_prompt_suffix_validator,
|
||||
common_prompt_prefix_validator,
|
||||
common_completion_prefix_validator,
|
||||
common_completion_suffix_validator,
|
||||
completions_space_start_validator,
|
||||
]
|
||||
|
||||
|
||||
def apply_validators(
|
||||
df: pd.DataFrame,
|
||||
fname: str,
|
||||
remediation: Remediation | None,
|
||||
validators: list[Validator],
|
||||
auto_accept: bool,
|
||||
write_out_file_func: Callable[..., Any],
|
||||
) -> None:
|
||||
optional_remediations: list[Remediation] = []
|
||||
if remediation is not None:
|
||||
optional_remediations.append(remediation)
|
||||
for validator in validators:
|
||||
remediation = validator(df)
|
||||
if remediation is not None:
|
||||
optional_remediations.append(remediation)
|
||||
df = apply_necessary_remediation(df, remediation)
|
||||
|
||||
any_optional_or_necessary_remediations = any(
|
||||
[
|
||||
remediation
|
||||
for remediation in optional_remediations
|
||||
if remediation.optional_msg is not None or remediation.necessary_msg is not None
|
||||
]
|
||||
)
|
||||
any_necessary_applied = any(
|
||||
[remediation for remediation in optional_remediations if remediation.necessary_msg is not None]
|
||||
)
|
||||
any_optional_applied = False
|
||||
|
||||
if any_optional_or_necessary_remediations:
|
||||
sys.stdout.write("\n\nBased on the analysis we will perform the following actions:\n")
|
||||
for remediation in optional_remediations:
|
||||
df, optional_applied = apply_optional_remediation(df, remediation, auto_accept)
|
||||
any_optional_applied = any_optional_applied or optional_applied
|
||||
else:
|
||||
sys.stdout.write("\n\nNo remediations found.\n")
|
||||
|
||||
any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied
|
||||
|
||||
write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)
|
||||
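A hedged sketch of how the pieces above fit together; the JSONL-loading step is an assumption for illustration, since the CLI code that normally builds the DataFrame is not part of this excerpt.

import pandas as pd

fname = "training_data.jsonl"  # hypothetical input file
df = pd.read_json(fname, lines=True)

apply_validators(
    df,
    fname,
    remediation=None,  # no file-format remediation in this sketch
    validators=get_validators(),
    auto_accept=False,
    write_out_file_func=write_out_file,
)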
632
venv/lib/python3.11/site-packages/openai/lib/azure.py
Normal file
@@ -0,0 +1,632 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import inspect
|
||||
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
|
||||
from typing_extensions import Self, override
|
||||
|
||||
import httpx
|
||||
|
||||
from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
|
||||
from .._utils import is_given, is_mapping
|
||||
from .._client import OpenAI, AsyncOpenAI
|
||||
from .._compat import model_copy
|
||||
from .._models import FinalRequestOptions
|
||||
from .._streaming import Stream, AsyncStream
|
||||
from .._exceptions import OpenAIError
|
||||
from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
|
||||
|
||||
_deployments_endpoints = set(
|
||||
[
|
||||
"/completions",
|
||||
"/chat/completions",
|
||||
"/embeddings",
|
||||
"/audio/transcriptions",
|
||||
"/audio/translations",
|
||||
"/audio/speech",
|
||||
"/images/generations",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AzureADTokenProvider = Callable[[], str]
|
||||
AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]
|
||||
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
|
||||
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
|
||||
|
||||
|
||||
# we need to use a sentinel API key value for Azure AD
|
||||
# as we don't want to make the `api_key` in the main client Optional
|
||||
# and Azure AD tokens may be retrieved on a per-request basis
|
||||
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
|
||||
|
||||
|
||||
class MutuallyExclusiveAuthError(OpenAIError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
"The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
|
||||
)
|
||||
|
||||
|
||||
class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
|
||||
_azure_endpoint: httpx.URL | None
|
||||
_azure_deployment: str | None
|
||||
|
||||
@override
|
||||
def _build_request(
|
||||
self,
|
||||
options: FinalRequestOptions,
|
||||
*,
|
||||
retries_taken: int = 0,
|
||||
) -> httpx.Request:
|
||||
if options.url in _deployments_endpoints and is_mapping(options.json_data):
|
||||
model = options.json_data.get("model")
|
||||
if model is not None and "/deployments" not in str(self.base_url.path):
|
||||
options.url = f"/deployments/{model}{options.url}"
|
||||
|
||||
return super()._build_request(options, retries_taken=retries_taken)
|
||||
|
||||
@override
|
||||
def _prepare_url(self, url: str) -> httpx.URL:
|
||||
"""Adjust the URL if the client was configured with an Azure endpoint + deployment
|
||||
and the API feature being called is **not** a deployments-based endpoint
|
||||
(i.e. requires /deployments/deployment-name in the URL path).
|
||||
"""
|
||||
if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
|
||||
merge_url = httpx.URL(url)
|
||||
if merge_url.is_relative_url:
|
||||
merge_raw_path = (
|
||||
self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
|
||||
)
|
||||
return self._azure_endpoint.copy_with(raw_path=merge_raw_path)
|
||||
|
||||
return merge_url
|
||||
|
||||
return super()._prepare_url(url)
|
||||
|
||||
|
||||
class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
azure_endpoint: str,
|
||||
azure_deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.Client | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
azure_deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.Client | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.Client | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_version: str | None = None,
|
||||
azure_endpoint: str | None = None,
|
||||
azure_deployment: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.Client | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
"""Construct a new synchronous azure openai client instance.
|
||||
|
||||
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
|
||||
- `api_key` from `AZURE_OPENAI_API_KEY`
|
||||
- `organization` from `OPENAI_ORG_ID`
|
||||
- `project` from `OPENAI_PROJECT_ID`
|
||||
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
|
||||
- `api_version` from `OPENAI_API_VERSION`
|
||||
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
|
||||
|
||||
Args:
|
||||
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
|
||||
|
||||
azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
|
||||
azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
|
||||
|
||||
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
|
||||
Not supported with Assistants APIs.
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
|
||||
if azure_ad_token is None:
|
||||
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
|
||||
|
||||
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
|
||||
raise OpenAIError(
|
||||
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
|
||||
)
|
||||
|
||||
if api_version is None:
|
||||
api_version = os.environ.get("OPENAI_API_VERSION")
|
||||
|
||||
if api_version is None:
|
||||
raise ValueError(
|
||||
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
|
||||
)
|
||||
|
||||
if default_query is None:
|
||||
default_query = {"api-version": api_version}
|
||||
else:
|
||||
default_query = {**default_query, "api-version": api_version}
|
||||
|
||||
if base_url is None:
|
||||
if azure_endpoint is None:
|
||||
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
|
||||
if azure_endpoint is None:
|
||||
raise ValueError(
|
||||
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
|
||||
)
|
||||
|
||||
if azure_deployment is not None:
|
||||
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
|
||||
else:
|
||||
base_url = f"{azure_endpoint.rstrip('/')}/openai"
|
||||
else:
|
||||
if azure_endpoint is not None:
|
||||
raise ValueError("base_url and azure_endpoint are mutually exclusive")
|
||||
|
||||
if api_key is None:
|
||||
# define a sentinel value to avoid any typing issues
|
||||
api_key = API_KEY_SENTINEL
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
project=project,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
default_headers=default_headers,
|
||||
default_query=default_query,
|
||||
http_client=http_client,
|
||||
websocket_base_url=websocket_base_url,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
self._api_version = api_version
|
||||
self._azure_ad_token = azure_ad_token
|
||||
self._azure_ad_token_provider = azure_ad_token_provider
|
||||
self._azure_deployment = azure_deployment if azure_endpoint else None
|
||||
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
|
||||
|
||||
@override
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
api_version: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AzureADTokenProvider | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
http_client: httpx.Client | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
set_default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
set_default_query: Mapping[str, object] | None = None,
|
||||
_extra_kwargs: Mapping[str, Any] = {},
|
||||
) -> Self:
|
||||
"""
|
||||
Create a new client instance re-using the same options given to the current client with optional overriding.
|
||||
"""
|
||||
return super().copy(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
project=project,
|
||||
websocket_base_url=websocket_base_url,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
max_retries=max_retries,
|
||||
default_headers=default_headers,
|
||||
set_default_headers=set_default_headers,
|
||||
default_query=default_query,
|
||||
set_default_query=set_default_query,
|
||||
_extra_kwargs={
|
||||
"api_version": api_version or self._api_version,
|
||||
"azure_ad_token": azure_ad_token or self._azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
|
||||
**_extra_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
with_options = copy
|
||||
|
||||
def _get_azure_ad_token(self) -> str | None:
|
||||
if self._azure_ad_token is not None:
|
||||
return self._azure_ad_token
|
||||
|
||||
provider = self._azure_ad_token_provider
|
||||
if provider is not None:
|
||||
token = provider()
|
||||
if not token or not isinstance(token, str): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
raise ValueError(
|
||||
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
|
||||
)
|
||||
return token
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
|
||||
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
|
||||
|
||||
options = model_copy(options)
|
||||
options.headers = headers
|
||||
|
||||
azure_ad_token = self._get_azure_ad_token()
|
||||
if azure_ad_token is not None:
|
||||
if headers.get("Authorization") is None:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
elif self.api_key is not API_KEY_SENTINEL:
|
||||
if headers.get("api-key") is None:
|
||||
headers["api-key"] = self.api_key
|
||||
else:
|
||||
# should never be hit
|
||||
raise ValueError("Unable to handle auth")
|
||||
|
||||
return options
|
||||
|
||||
def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
|
||||
auth_headers = {}
|
||||
query = {
|
||||
**extra_query,
|
||||
"api-version": self._api_version,
|
||||
"deployment": self._azure_deployment or model,
|
||||
}
|
||||
if self.api_key != "<missing API key>":
|
||||
auth_headers = {"api-key": self.api_key}
|
||||
else:
|
||||
token = self._get_azure_ad_token()
|
||||
if token:
|
||||
auth_headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
if self.websocket_base_url is not None:
|
||||
base_url = httpx.URL(self.websocket_base_url)
|
||||
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
|
||||
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
|
||||
else:
|
||||
base_url = self._prepare_url("/realtime")
|
||||
realtime_url = base_url.copy_with(scheme="wss")
|
||||
|
||||
url = realtime_url.copy_with(params={**query})
|
||||
return url, auth_headers
|
||||
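A minimal usage sketch for the synchronous client defined above; the endpoint, deployment, key and API version are placeholders, and the chat.completions call is the standard surface inherited from OpenAI.

from openai import AzureOpenAI

client = AzureOpenAI(
    azure_endpoint="https://example-resource.azure.openai.com/",  # placeholder
    azure_deployment="my-deployment",  # placeholder deployment name
    api_version="2024-06-01",  # placeholder; normally read from OPENAI_API_VERSION
    api_key="...",  # placeholder; normally read from AZURE_OPENAI_API_KEY
)
completion = client.chat.completions.create(
    model="my-deployment",  # with Azure this is the deployment name
    messages=[{"role": "user", "content": "Say hello"}],
)
print(completion.choices[0].message.content)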
|
||||
|
||||
class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
azure_endpoint: str,
|
||||
azure_deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
azure_deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
azure_endpoint: str | None = None,
|
||||
azure_deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
base_url: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
_strict_response_validation: bool = False,
|
||||
) -> None:
|
||||
"""Construct a new asynchronous azure openai client instance.
|
||||
|
||||
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
|
||||
- `api_key` from `AZURE_OPENAI_API_KEY`
|
||||
- `organization` from `OPENAI_ORG_ID`
|
||||
- `project` from `OPENAI_PROJECT_ID`
|
||||
- `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
|
||||
- `api_version` from `OPENAI_API_VERSION`
|
||||
- `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
|
||||
|
||||
Args:
|
||||
azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
|
||||
|
||||
azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
|
||||
|
||||
azure_ad_token_provider: A function that returns an Azure Active Directory token; it will be invoked on every request.
|
||||
|
||||
azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
|
||||
Not supported with Assistants APIs.
|
||||
"""
|
||||
if api_key is None:
|
||||
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
|
||||
if azure_ad_token is None:
|
||||
azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
|
||||
|
||||
if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
|
||||
raise OpenAIError(
|
||||
"Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
|
||||
)
|
||||
|
||||
if api_version is None:
|
||||
api_version = os.environ.get("OPENAI_API_VERSION")
|
||||
|
||||
if api_version is None:
|
||||
raise ValueError(
|
||||
"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
|
||||
)
|
||||
|
||||
if default_query is None:
|
||||
default_query = {"api-version": api_version}
|
||||
else:
|
||||
default_query = {**default_query, "api-version": api_version}
|
||||
|
||||
if base_url is None:
|
||||
if azure_endpoint is None:
|
||||
azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
|
||||
if azure_endpoint is None:
|
||||
raise ValueError(
|
||||
"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
|
||||
)
|
||||
|
||||
if azure_deployment is not None:
|
||||
base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
|
||||
else:
|
||||
base_url = f"{azure_endpoint.rstrip('/')}/openai"
|
||||
else:
|
||||
if azure_endpoint is not None:
|
||||
raise ValueError("base_url and azure_endpoint are mutually exclusive")
|
||||
|
||||
if api_key is None:
|
||||
# define a sentinel value to avoid any typing issues
|
||||
api_key = API_KEY_SENTINEL
|
||||
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
project=project,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
default_headers=default_headers,
|
||||
default_query=default_query,
|
||||
http_client=http_client,
|
||||
websocket_base_url=websocket_base_url,
|
||||
_strict_response_validation=_strict_response_validation,
|
||||
)
|
||||
self._api_version = api_version
|
||||
self._azure_ad_token = azure_ad_token
|
||||
self._azure_ad_token_provider = azure_ad_token_provider
|
||||
self._azure_deployment = azure_deployment if azure_endpoint else None
|
||||
self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
|
||||
|
||||
@override
|
||||
def copy(
|
||||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
websocket_base_url: str | httpx.URL | None = None,
|
||||
api_version: str | None = None,
|
||||
azure_ad_token: str | None = None,
|
||||
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
|
||||
base_url: str | httpx.URL | None = None,
|
||||
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
max_retries: int | NotGiven = NOT_GIVEN,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
set_default_headers: Mapping[str, str] | None = None,
|
||||
default_query: Mapping[str, object] | None = None,
|
||||
set_default_query: Mapping[str, object] | None = None,
|
||||
_extra_kwargs: Mapping[str, Any] = {},
|
||||
) -> Self:
|
||||
"""
|
||||
Create a new client instance re-using the same options given to the current client with optional overriding.
|
||||
"""
|
||||
return super().copy(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
project=project,
|
||||
websocket_base_url=websocket_base_url,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
max_retries=max_retries,
|
||||
default_headers=default_headers,
|
||||
set_default_headers=set_default_headers,
|
||||
default_query=default_query,
|
||||
set_default_query=set_default_query,
|
||||
_extra_kwargs={
|
||||
"api_version": api_version or self._api_version,
|
||||
"azure_ad_token": azure_ad_token or self._azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
|
||||
**_extra_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
with_options = copy
|
||||
|
||||
async def _get_azure_ad_token(self) -> str | None:
|
||||
if self._azure_ad_token is not None:
|
||||
return self._azure_ad_token
|
||||
|
||||
provider = self._azure_ad_token_provider
|
||||
if provider is not None:
|
||||
token = provider()
|
||||
if inspect.isawaitable(token):
|
||||
token = await token
|
||||
if not token or not isinstance(cast(Any, token), str):
|
||||
raise ValueError(
|
||||
f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
|
||||
)
|
||||
return str(token)
|
||||
|
||||
return None
|
||||
|
||||
@override
|
||||
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
|
||||
headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
|
||||
|
||||
options = model_copy(options)
|
||||
options.headers = headers
|
||||
|
||||
azure_ad_token = await self._get_azure_ad_token()
|
||||
if azure_ad_token is not None:
|
||||
if headers.get("Authorization") is None:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
elif self.api_key is not API_KEY_SENTINEL:
|
||||
if headers.get("api-key") is None:
|
||||
headers["api-key"] = self.api_key
|
||||
else:
|
||||
# should never be hit
|
||||
raise ValueError("Unable to handle auth")
|
||||
|
||||
return options
|
||||
|
||||
async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
|
||||
auth_headers = {}
|
||||
query = {
|
||||
**extra_query,
|
||||
"api-version": self._api_version,
|
||||
"deployment": self._azure_deployment or model,
|
||||
}
|
||||
if self.api_key != "<missing API key>":
|
||||
auth_headers = {"api-key": self.api_key}
|
||||
else:
|
||||
token = await self._get_azure_ad_token()
|
||||
if token:
|
||||
auth_headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
if self.websocket_base_url is not None:
|
||||
base_url = httpx.URL(self.websocket_base_url)
|
||||
merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
|
||||
realtime_url = base_url.copy_with(raw_path=merge_raw_path)
|
||||
else:
|
||||
base_url = self._prepare_url("/realtime")
|
||||
realtime_url = base_url.copy_with(scheme="wss")
|
||||
|
||||
url = realtime_url.copy_with(params={**query})
|
||||
return url, auth_headers
|
||||
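An async counterpart using an Entra ID (Azure AD) token provider instead of an API key, as a sketch: any zero-argument callable returning a token string (or an awaitable of one) satisfies AsyncAzureADTokenProvider, and the provider below is a hypothetical stand-in for a real identity-library call.

import asyncio

from openai import AsyncAzureOpenAI


def fetch_entra_token() -> str:
    # Hypothetical stand-in: in practice this would mint a fresh bearer token,
    # e.g. via the azure-identity package.
    return "<bearer-token>"


async def main() -> None:
    client = AsyncAzureOpenAI(
        azure_endpoint="https://example-resource.azure.openai.com/",  # placeholder
        api_version="2024-06-01",  # placeholder
        azure_ad_token_provider=fetch_entra_token,  # invoked on every request
    )
    completion = await client.chat.completions.create(
        model="my-deployment",  # placeholder deployment name
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(completion.choices[0].message.content)


asyncio.run(main())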
@@ -0,0 +1,8 @@
|
||||
from ._assistants import (
|
||||
AssistantEventHandler as AssistantEventHandler,
|
||||
AssistantEventHandlerT as AssistantEventHandlerT,
|
||||
AssistantStreamManager as AssistantStreamManager,
|
||||
AsyncAssistantEventHandler as AsyncAssistantEventHandler,
|
||||
AsyncAssistantEventHandlerT as AsyncAssistantEventHandlerT,
|
||||
AsyncAssistantStreamManager as AsyncAssistantStreamManager,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..._utils import is_dict, is_list
|
||||
|
||||
|
||||
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
|
||||
for key, delta_value in delta.items():
|
||||
if key not in acc:
|
||||
acc[key] = delta_value
|
||||
continue
|
||||
|
||||
acc_value = acc[key]
|
||||
if acc_value is None:
|
||||
acc[key] = delta_value
|
||||
continue
|
||||
|
||||
# the `index` property is used in arrays of objects so it should
|
||||
# not be accumulated like other values e.g.
|
||||
# [{'foo': 'bar', 'index': 0}]
|
||||
#
|
||||
# the same applies to `type` properties as they're used for
|
||||
# discriminated unions
|
||||
if key == "index" or key == "type":
|
||||
acc[key] = delta_value
|
||||
continue
|
||||
|
||||
if isinstance(acc_value, str) and isinstance(delta_value, str):
|
||||
acc_value += delta_value
|
||||
elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
|
||||
acc_value += delta_value
|
||||
elif is_dict(acc_value) and is_dict(delta_value):
|
||||
acc_value = accumulate_delta(acc_value, delta_value)
|
||||
elif is_list(acc_value) and is_list(delta_value):
|
||||
# for lists of non-dictionary items we'll only ever get new entries
|
||||
# in the array, existing entries will never be changed
|
||||
if all(isinstance(x, (str, int, float)) for x in acc_value):
|
||||
acc_value.extend(delta_value)
|
||||
continue
|
||||
|
||||
for delta_entry in delta_value:
|
||||
if not is_dict(delta_entry):
|
||||
raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
|
||||
|
||||
try:
|
||||
index = delta_entry["index"]
|
||||
except KeyError as exc:
|
||||
raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
|
||||
|
||||
if not isinstance(index, int):
|
||||
raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
|
||||
|
||||
try:
|
||||
acc_entry = acc_value[index]
|
||||
except IndexError:
|
||||
acc_value.insert(index, delta_entry)
|
||||
else:
|
||||
if not is_dict(acc_entry):
|
||||
raise TypeError("not handled yet")
|
||||
|
||||
acc_value[index] = accumulate_delta(acc_entry, delta_entry)
|
||||
|
||||
acc[key] = acc_value
|
||||
|
||||
return acc
|
||||
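A small illustration of the accumulator above with hypothetical streaming deltas: string fields concatenate, numeric fields add, nested dicts recurse, and `index`/`type` keys are overwritten.

acc: dict[object, object] = {"content": "Hel", "usage": {"tokens": 1}}
delta: dict[object, object] = {"content": "lo", "usage": {"tokens": 2}, "index": 0}

accumulate_delta(acc, delta)
print(acc)  # -> {'content': 'Hello', 'usage': {'tokens': 3}, 'index': 0}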
@@ -0,0 +1,27 @@
|
||||
from ._types import (
|
||||
ParsedChoiceSnapshot as ParsedChoiceSnapshot,
|
||||
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
|
||||
ParsedChatCompletionMessageSnapshot as ParsedChatCompletionMessageSnapshot,
|
||||
)
|
||||
from ._events import (
|
||||
ChunkEvent as ChunkEvent,
|
||||
ContentDoneEvent as ContentDoneEvent,
|
||||
RefusalDoneEvent as RefusalDoneEvent,
|
||||
ContentDeltaEvent as ContentDeltaEvent,
|
||||
RefusalDeltaEvent as RefusalDeltaEvent,
|
||||
LogprobsContentDoneEvent as LogprobsContentDoneEvent,
|
||||
LogprobsRefusalDoneEvent as LogprobsRefusalDoneEvent,
|
||||
ChatCompletionStreamEvent as ChatCompletionStreamEvent,
|
||||
LogprobsContentDeltaEvent as LogprobsContentDeltaEvent,
|
||||
LogprobsRefusalDeltaEvent as LogprobsRefusalDeltaEvent,
|
||||
ParsedChatCompletionSnapshot as ParsedChatCompletionSnapshot,
|
||||
FunctionToolCallArgumentsDoneEvent as FunctionToolCallArgumentsDoneEvent,
|
||||
FunctionToolCallArgumentsDeltaEvent as FunctionToolCallArgumentsDeltaEvent,
|
||||
)
|
||||
from ._completions import (
|
||||
ChatCompletionStream as ChatCompletionStream,
|
||||
AsyncChatCompletionStream as AsyncChatCompletionStream,
|
||||
ChatCompletionStreamState as ChatCompletionStreamState,
|
||||
ChatCompletionStreamManager as ChatCompletionStreamManager,
|
||||
AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
|
||||
)
|
||||
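These exports back the `.stream()` helper shown later in this diff; a hedged end-to-end sketch follows (the model name and prompt are placeholders, and an OPENAI_API_KEY is assumed to be configured).

from openai import OpenAI

client = OpenAI()

with client.beta.chat.completions.stream(
    model="gpt-4o-mini",  # placeholder model name
    messages=[{"role": "user", "content": "Write a haiku about the sea"}],
) as stream:
    for event in stream:
        if event.type == "content.delta":
            print(event.delta, end="", flush=True)
        elif event.type == "content.done":
            print()
    completion = stream.get_final_completion()

print(completion.choices[0].message.content)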
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,768 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from types import TracebackType
|
||||
from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
|
||||
from typing_extensions import Self, Iterator, assert_never
|
||||
|
||||
from jiter import from_json
|
||||
|
||||
from ._types import ParsedChoiceSnapshot, ParsedChatCompletionSnapshot, ParsedChatCompletionMessageSnapshot
|
||||
from ._events import (
|
||||
ChunkEvent,
|
||||
ContentDoneEvent,
|
||||
RefusalDoneEvent,
|
||||
ContentDeltaEvent,
|
||||
RefusalDeltaEvent,
|
||||
LogprobsContentDoneEvent,
|
||||
LogprobsRefusalDoneEvent,
|
||||
ChatCompletionStreamEvent,
|
||||
LogprobsContentDeltaEvent,
|
||||
LogprobsRefusalDeltaEvent,
|
||||
FunctionToolCallArgumentsDoneEvent,
|
||||
FunctionToolCallArgumentsDeltaEvent,
|
||||
)
|
||||
from .._deltas import accumulate_delta
|
||||
from ...._types import NOT_GIVEN, IncEx, NotGiven
|
||||
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
|
||||
from ...._compat import model_dump
|
||||
from ...._models import build, construct_type
|
||||
from ..._parsing import (
|
||||
ResponseFormatT,
|
||||
has_parseable_input,
|
||||
maybe_parse_content,
|
||||
parse_chat_completion,
|
||||
get_input_tool_by_name,
|
||||
solve_response_format_t,
|
||||
parse_function_tool_arguments,
|
||||
)
|
||||
from ...._streaming import Stream, AsyncStream
|
||||
from ....types.chat import ChatCompletionChunk, ParsedChatCompletion, ChatCompletionToolParam
|
||||
from ...._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
|
||||
from ....types.chat.chat_completion import ChoiceLogprobs
|
||||
from ....types.chat.chat_completion_chunk import Choice as ChoiceChunk
|
||||
from ....types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
|
||||
|
||||
|
||||
class ChatCompletionStream(Generic[ResponseFormatT]):
|
||||
"""Wrapper over the Chat Completions streaming API that adds helpful
|
||||
events such as `content.done`, supports automatically parsing
|
||||
responses & tool calls and accumulates a `ChatCompletion` object
|
||||
from each individual chunk.
|
||||
|
||||
https://platform.openai.com/docs/api-reference/streaming
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
raw_stream: Stream[ChatCompletionChunk],
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self._raw_stream = raw_stream
|
||||
self._response = raw_stream.response
|
||||
self._iterator = self.__stream__()
|
||||
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
|
||||
|
||||
def __next__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
|
||||
return self._iterator.__next__()
|
||||
|
||||
def __iter__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
for item in self._iterator:
|
||||
yield item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
self._response.close()
|
||||
|
||||
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
||||
"""Waits until the stream has been read to completion and returns
|
||||
the accumulated `ParsedChatCompletion` object.
|
||||
|
||||
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
|
||||
property will be the content deserialised into that class, if there was any content returned
|
||||
by the API.
|
||||
"""
|
||||
self.until_done()
|
||||
return self._state.get_final_completion()
|
||||
|
||||
def until_done(self) -> Self:
|
||||
"""Blocks until the stream has been consumed."""
|
||||
consume_sync_iterator(self)
|
||||
return self
|
||||
|
||||
@property
|
||||
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
||||
return self._state.current_completion_snapshot
|
||||
|
||||
def __stream__(self) -> Iterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
for sse_event in self._raw_stream:
|
||||
if not _is_valid_chat_completion_chunk_weak(sse_event):
|
||||
continue
|
||||
events_to_fire = self._state.handle_chunk(sse_event)
|
||||
for event in events_to_fire:
|
||||
yield event
|
||||
|
||||
|
||||
class ChatCompletionStreamManager(Generic[ResponseFormatT]):
|
||||
"""Context manager over a `ChatCompletionStream` that is returned by `.stream()`.
|
||||
|
||||
This context manager ensures the response cannot be leaked if you don't read
|
||||
the stream to completion.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
with client.beta.chat.completions.stream(...) as stream:
|
||||
for event in stream:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_request: Callable[[], Stream[ChatCompletionChunk]],
|
||||
*,
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self.__stream: ChatCompletionStream[ResponseFormatT] | None = None
|
||||
self.__api_request = api_request
|
||||
self.__response_format = response_format
|
||||
self.__input_tools = input_tools
|
||||
|
||||
def __enter__(self) -> ChatCompletionStream[ResponseFormatT]:
|
||||
raw_stream = self.__api_request()
|
||||
|
||||
self.__stream = ChatCompletionStream(
|
||||
raw_stream=raw_stream,
|
||||
response_format=self.__response_format,
|
||||
input_tools=self.__input_tools,
|
||||
)
|
||||
|
||||
return self.__stream
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__stream is not None:
|
||||
self.__stream.close()
|
||||
|
||||
|
||||
class AsyncChatCompletionStream(Generic[ResponseFormatT]):
|
||||
"""Wrapper over the Chat Completions streaming API that adds helpful
|
||||
events such as `content.done`, supports automatically parsing
|
||||
responses & tool calls and accumulates a `ChatCompletion` object
|
||||
from each individual chunk.
|
||||
|
||||
https://platform.openai.com/docs/api-reference/streaming
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
raw_stream: AsyncStream[ChatCompletionChunk],
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self._raw_stream = raw_stream
|
||||
self._response = raw_stream.response
|
||||
self._iterator = self.__stream__()
|
||||
self._state = ChatCompletionStreamState(response_format=response_format, input_tools=input_tools)
|
||||
|
||||
async def __anext__(self) -> ChatCompletionStreamEvent[ResponseFormatT]:
|
||||
return await self._iterator.__anext__()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
async for item in self._iterator:
|
||||
yield item
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
await self._response.aclose()
|
||||
|
||||
async def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
||||
"""Waits until the stream has been read to completion and returns
|
||||
the accumulated `ParsedChatCompletion` object.
|
||||
|
||||
If you passed a class type to `.stream()`, the `completion.choices[0].message.parsed`
|
||||
property will be the content deserialised into that class, if there was any content returned
|
||||
by the API.
|
||||
"""
|
||||
await self.until_done()
|
||||
return self._state.get_final_completion()
|
||||
|
||||
async def until_done(self) -> Self:
|
||||
"""Blocks until the stream has been consumed."""
|
||||
await consume_async_iterator(self)
|
||||
return self
|
||||
|
||||
@property
|
||||
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
||||
return self._state.current_completion_snapshot
|
||||
|
||||
async def __stream__(self) -> AsyncIterator[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
async for sse_event in self._raw_stream:
|
||||
if not _is_valid_chat_completion_chunk_weak(sse_event):
|
||||
continue
|
||||
events_to_fire = self._state.handle_chunk(sse_event)
|
||||
for event in events_to_fire:
|
||||
yield event
|
||||
|
||||
|
||||
class AsyncChatCompletionStreamManager(Generic[ResponseFormatT]):
|
||||
"""Context manager over a `AsyncChatCompletionStream` that is returned by `.stream()`.
|
||||
|
||||
This context manager ensures the response cannot be leaked if you don't read
|
||||
the stream to completion.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
async with client.beta.chat.completions.stream(...) as stream:
|
||||
async for event in stream:
|
||||
...
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_request: Awaitable[AsyncStream[ChatCompletionChunk]],
|
||||
*,
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self.__stream: AsyncChatCompletionStream[ResponseFormatT] | None = None
|
||||
self.__api_request = api_request
|
||||
self.__response_format = response_format
|
||||
self.__input_tools = input_tools
|
||||
|
||||
async def __aenter__(self) -> AsyncChatCompletionStream[ResponseFormatT]:
|
||||
raw_stream = await self.__api_request
|
||||
|
||||
self.__stream = AsyncChatCompletionStream(
|
||||
raw_stream=raw_stream,
|
||||
response_format=self.__response_format,
|
||||
input_tools=self.__input_tools,
|
||||
)
|
||||
|
||||
return self.__stream
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__stream is not None:
|
||||
await self.__stream.close()
|
||||
|
||||
|
||||
class ChatCompletionStreamState(Generic[ResponseFormatT]):
|
||||
"""Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
|
||||
|
||||
This is useful in cases where you can't always use the `.stream()` method, e.g.
|
||||
|
||||
```py
|
||||
from openai.lib.streaming.chat import ChatCompletionStreamState
|
||||
|
||||
state = ChatCompletionStreamState()
|
||||
|
||||
stream = client.chat.completions.create(..., stream=True)
|
||||
for chunk in stream:
|
||||
state.handle_chunk(chunk)
|
||||
|
||||
# can also access the accumulated `ChatCompletion` mid-stream
|
||||
state.current_completion_snapshot
|
||||
|
||||
print(state.get_final_completion())
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
|
||||
) -> None:
|
||||
self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
|
||||
self.__choice_event_states: list[ChoiceEventState] = []
|
||||
|
||||
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
|
||||
self._response_format = response_format
|
||||
self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
|
||||
|
||||
def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
|
||||
"""Parse the final completion object.
|
||||
|
||||
Note this does not provide any guarantees that the stream has actually finished, you must
|
||||
only call this method when the stream is finished.
|
||||
"""
|
||||
return parse_chat_completion(
|
||||
chat_completion=self.current_completion_snapshot,
|
||||
response_format=self._rich_response_format,
|
||||
input_tools=self._input_tools,
|
||||
)
|
||||
|
||||
@property
|
||||
def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
|
||||
assert self.__current_completion_snapshot is not None
|
||||
return self.__current_completion_snapshot
|
||||
|
||||
def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
"""Accumulate a new chunk into the snapshot and returns an iterable of events to yield."""
|
||||
self.__current_completion_snapshot = self._accumulate_chunk(chunk)
|
||||
|
||||
return self._build_events(
|
||||
chunk=chunk,
|
||||
completion_snapshot=self.__current_completion_snapshot,
|
||||
)
|
||||
|
||||
def _get_choice_state(self, choice: ChoiceChunk) -> ChoiceEventState:
|
||||
try:
|
||||
return self.__choice_event_states[choice.index]
|
||||
except IndexError:
|
||||
choice_state = ChoiceEventState(input_tools=self._input_tools)
|
||||
self.__choice_event_states.append(choice_state)
|
||||
return choice_state
|
||||
|
||||
def _accumulate_chunk(self, chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
|
||||
completion_snapshot = self.__current_completion_snapshot
|
||||
|
||||
if completion_snapshot is None:
|
||||
return _convert_initial_chunk_into_snapshot(chunk)
|
||||
|
||||
for choice in chunk.choices:
|
||||
try:
|
||||
choice_snapshot = completion_snapshot.choices[choice.index]
|
||||
previous_tool_calls = choice_snapshot.message.tool_calls or []
|
||||
|
||||
choice_snapshot.message = cast(
|
||||
ParsedChatCompletionMessageSnapshot,
|
||||
construct_type(
|
||||
type_=ParsedChatCompletionMessageSnapshot,
|
||||
value=accumulate_delta(
|
||||
cast(
|
||||
"dict[object, object]",
|
||||
model_dump(
|
||||
choice_snapshot.message,
|
||||
# we don't want to serialise / deserialise our custom properties
|
||||
# as they won't appear in the delta and we don't want to have to
|
||||
# continuously reparse the content
|
||||
exclude=cast(
|
||||
# cast required as mypy isn't smart enough to infer `True` here to `Literal[True]`
|
||||
IncEx,
|
||||
{
|
||||
"parsed": True,
|
||||
"tool_calls": {
|
||||
idx: {"function": {"parsed_arguments": True}}
|
||||
for idx, _ in enumerate(choice_snapshot.message.tool_calls or [])
|
||||
},
|
||||
},
|
||||
),
|
||||
),
|
||||
),
|
||||
cast("dict[object, object]", choice.delta.to_dict()),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# ensure tools that have already been parsed are added back into the newly
|
||||
# constructed message snapshot
|
||||
for tool_index, prev_tool in enumerate(previous_tool_calls):
|
||||
new_tool = (choice_snapshot.message.tool_calls or [])[tool_index]
|
||||
|
||||
if prev_tool.type == "function":
|
||||
assert new_tool.type == "function"
|
||||
new_tool.function.parsed_arguments = prev_tool.function.parsed_arguments
|
||||
elif TYPE_CHECKING: # type: ignore[unreachable]
|
||||
assert_never(prev_tool)
|
||||
except IndexError:
|
||||
choice_snapshot = cast(
|
||||
ParsedChoiceSnapshot,
|
||||
construct_type(
|
||||
type_=ParsedChoiceSnapshot,
|
||||
value={
|
||||
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
|
||||
"message": choice.delta.to_dict(),
|
||||
},
|
||||
),
|
||||
)
|
||||
completion_snapshot.choices.append(choice_snapshot)
|
||||
|
||||
if choice.finish_reason:
|
||||
choice_snapshot.finish_reason = choice.finish_reason
|
||||
|
||||
if has_parseable_input(response_format=self._response_format, input_tools=self._input_tools):
|
||||
if choice.finish_reason == "length":
|
||||
# at the time of writing, `.usage` will always be `None` but
|
||||
# we include it here in case that is changed in the future
|
||||
raise LengthFinishReasonError(completion=completion_snapshot)
|
||||
|
||||
if choice.finish_reason == "content_filter":
|
||||
raise ContentFilterFinishReasonError()
|
||||
|
||||
if (
|
||||
choice_snapshot.message.content
|
||||
and not choice_snapshot.message.refusal
|
||||
and is_given(self._rich_response_format)
|
||||
):
|
||||
choice_snapshot.message.parsed = from_json(
|
||||
bytes(choice_snapshot.message.content, "utf-8"),
|
||||
partial_mode=True,
|
||||
)
|
||||
|
||||
for tool_call_chunk in choice.delta.tool_calls or []:
|
||||
tool_call_snapshot = (choice_snapshot.message.tool_calls or [])[tool_call_chunk.index]
|
||||
|
||||
if tool_call_snapshot.type == "function":
|
||||
input_tool = get_input_tool_by_name(
|
||||
input_tools=self._input_tools, name=tool_call_snapshot.function.name
|
||||
)
|
||||
|
||||
if (
|
||||
input_tool
|
||||
and input_tool.get("function", {}).get("strict")
|
||||
and tool_call_snapshot.function.arguments
|
||||
):
|
||||
tool_call_snapshot.function.parsed_arguments = from_json(
|
||||
bytes(tool_call_snapshot.function.arguments, "utf-8"),
|
||||
partial_mode=True,
|
||||
)
|
||||
elif TYPE_CHECKING: # type: ignore[unreachable]
|
||||
assert_never(tool_call_snapshot)
|
||||
|
||||
if choice.logprobs is not None:
|
||||
if choice_snapshot.logprobs is None:
|
||||
choice_snapshot.logprobs = build(
|
||||
ChoiceLogprobs,
|
||||
content=choice.logprobs.content,
|
||||
refusal=choice.logprobs.refusal,
|
||||
)
|
||||
else:
|
||||
if choice.logprobs.content:
|
||||
if choice_snapshot.logprobs.content is None:
|
||||
choice_snapshot.logprobs.content = []
|
||||
|
||||
choice_snapshot.logprobs.content.extend(choice.logprobs.content)
|
||||
|
||||
if choice.logprobs.refusal:
|
||||
if choice_snapshot.logprobs.refusal is None:
|
||||
choice_snapshot.logprobs.refusal = []
|
||||
|
||||
choice_snapshot.logprobs.refusal.extend(choice.logprobs.refusal)
|
||||
|
||||
completion_snapshot.usage = chunk.usage
|
||||
completion_snapshot.system_fingerprint = chunk.system_fingerprint
|
||||
|
||||
return completion_snapshot
|
||||
|
||||
def _build_events(
|
||||
self,
|
||||
*,
|
||||
chunk: ChatCompletionChunk,
|
||||
completion_snapshot: ParsedChatCompletionSnapshot,
|
||||
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
||||
|
||||
events_to_fire.append(
|
||||
build(ChunkEvent, type="chunk", chunk=chunk, snapshot=completion_snapshot),
|
||||
)
|
||||
|
||||
for choice in chunk.choices:
|
||||
choice_state = self._get_choice_state(choice)
|
||||
choice_snapshot = completion_snapshot.choices[choice.index]
|
||||
|
||||
if choice.delta.content is not None and choice_snapshot.message.content is not None:
|
||||
events_to_fire.append(
|
||||
build(
|
||||
ContentDeltaEvent,
|
||||
type="content.delta",
|
||||
delta=choice.delta.content,
|
||||
snapshot=choice_snapshot.message.content,
|
||||
parsed=choice_snapshot.message.parsed,
|
||||
)
|
||||
)
|
||||
|
||||
if choice.delta.refusal is not None and choice_snapshot.message.refusal is not None:
|
||||
events_to_fire.append(
|
||||
build(
|
||||
RefusalDeltaEvent,
|
||||
type="refusal.delta",
|
||||
delta=choice.delta.refusal,
|
||||
snapshot=choice_snapshot.message.refusal,
|
||||
)
|
||||
)
|
||||
|
||||
if choice.delta.tool_calls:
|
||||
tool_calls = choice_snapshot.message.tool_calls
|
||||
assert tool_calls is not None
|
||||
|
||||
for tool_call_delta in choice.delta.tool_calls:
|
||||
tool_call = tool_calls[tool_call_delta.index]
|
||||
|
||||
if tool_call.type == "function":
|
||||
assert tool_call_delta.function is not None
|
||||
events_to_fire.append(
|
||||
build(
|
||||
FunctionToolCallArgumentsDeltaEvent,
|
||||
type="tool_calls.function.arguments.delta",
|
||||
name=tool_call.function.name,
|
||||
index=tool_call_delta.index,
|
||||
arguments=tool_call.function.arguments,
|
||||
parsed_arguments=tool_call.function.parsed_arguments,
|
||||
arguments_delta=tool_call_delta.function.arguments or "",
|
||||
)
|
||||
)
|
||||
elif TYPE_CHECKING: # type: ignore[unreachable]
|
||||
assert_never(tool_call)
|
||||
|
||||
if choice.logprobs is not None and choice_snapshot.logprobs is not None:
|
||||
if choice.logprobs.content and choice_snapshot.logprobs.content:
|
||||
events_to_fire.append(
|
||||
build(
|
||||
LogprobsContentDeltaEvent,
|
||||
type="logprobs.content.delta",
|
||||
content=choice.logprobs.content,
|
||||
snapshot=choice_snapshot.logprobs.content,
|
||||
),
|
||||
)
|
||||
|
||||
if choice.logprobs.refusal and choice_snapshot.logprobs.refusal:
|
||||
events_to_fire.append(
|
||||
build(
|
||||
LogprobsRefusalDeltaEvent,
|
||||
type="logprobs.refusal.delta",
|
||||
refusal=choice.logprobs.refusal,
|
||||
snapshot=choice_snapshot.logprobs.refusal,
|
||||
),
|
||||
)
|
||||
|
||||
events_to_fire.extend(
|
||||
choice_state.get_done_events(
|
||||
choice_chunk=choice,
|
||||
choice_snapshot=choice_snapshot,
|
||||
response_format=self._response_format,
|
||||
)
|
||||
)
|
||||
|
||||
return events_to_fire
|
||||
|
||||
|
||||
class ChoiceEventState:
|
||||
def __init__(self, *, input_tools: list[ChatCompletionToolParam]) -> None:
|
||||
self._input_tools = input_tools
|
||||
|
||||
self._content_done = False
|
||||
self._refusal_done = False
|
||||
self._logprobs_content_done = False
|
||||
self._logprobs_refusal_done = False
|
||||
self._done_tool_calls: set[int] = set()
|
||||
self.__current_tool_call_index: int | None = None
|
||||
|
||||
def get_done_events(
|
||||
self,
|
||||
*,
|
||||
choice_chunk: ChoiceChunk,
|
||||
choice_snapshot: ParsedChoiceSnapshot,
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
||||
|
||||
if choice_snapshot.finish_reason:
|
||||
events_to_fire.extend(
|
||||
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
|
||||
)
|
||||
|
||||
if (
|
||||
self.__current_tool_call_index is not None
|
||||
and self.__current_tool_call_index not in self._done_tool_calls
|
||||
):
|
||||
self._add_tool_done_event(
|
||||
events_to_fire=events_to_fire,
|
||||
choice_snapshot=choice_snapshot,
|
||||
tool_index=self.__current_tool_call_index,
|
||||
)
|
||||
|
||||
for tool_call in choice_chunk.delta.tool_calls or []:
|
||||
if self.__current_tool_call_index != tool_call.index:
|
||||
events_to_fire.extend(
|
||||
self._content_done_events(choice_snapshot=choice_snapshot, response_format=response_format)
|
||||
)
|
||||
|
||||
if self.__current_tool_call_index is not None:
|
||||
self._add_tool_done_event(
|
||||
events_to_fire=events_to_fire,
|
||||
choice_snapshot=choice_snapshot,
|
||||
tool_index=self.__current_tool_call_index,
|
||||
)
|
||||
|
||||
self.__current_tool_call_index = tool_call.index
|
||||
|
||||
return events_to_fire
|
||||
|
||||
def _content_done_events(
|
||||
self,
|
||||
*,
|
||||
choice_snapshot: ParsedChoiceSnapshot,
|
||||
response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
|
||||
) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
|
||||
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]] = []
|
||||
|
||||
if choice_snapshot.message.content and not self._content_done:
|
||||
self._content_done = True
|
||||
|
||||
parsed = maybe_parse_content(
|
||||
response_format=response_format,
|
||||
message=choice_snapshot.message,
|
||||
)
|
||||
|
||||
# update the parsed content to now use the richer `response_format`
|
||||
# as opposed to the raw JSON-parsed object as the content is now
|
||||
# complete and can be fully validated.
|
||||
choice_snapshot.message.parsed = parsed
|
||||
|
||||
events_to_fire.append(
|
||||
build(
|
||||
# we do this dance so that when the `ContentDoneEvent` instance
|
||||
# is printed at runtime the class name will include the solved
|
||||
# type variable, e.g. `ContentDoneEvent[MyModelType]`
|
||||
cast( # pyright: ignore[reportUnnecessaryCast]
|
||||
"type[ContentDoneEvent[ResponseFormatT]]",
|
||||
cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
|
||||
),
|
||||
type="content.done",
|
||||
content=choice_snapshot.message.content,
|
||||
parsed=parsed,
|
||||
),
|
||||
)
|
||||
|
||||
if choice_snapshot.message.refusal is not None and not self._refusal_done:
|
||||
self._refusal_done = True
|
||||
events_to_fire.append(
|
||||
build(RefusalDoneEvent, type="refusal.done", refusal=choice_snapshot.message.refusal),
|
||||
)
|
||||
|
||||
if (
|
||||
choice_snapshot.logprobs is not None
|
||||
and choice_snapshot.logprobs.content is not None
|
||||
and not self._logprobs_content_done
|
||||
):
|
||||
self._logprobs_content_done = True
|
||||
events_to_fire.append(
|
||||
build(LogprobsContentDoneEvent, type="logprobs.content.done", content=choice_snapshot.logprobs.content),
|
||||
)
|
||||
|
||||
if (
|
||||
choice_snapshot.logprobs is not None
|
||||
and choice_snapshot.logprobs.refusal is not None
|
||||
and not self._logprobs_refusal_done
|
||||
):
|
||||
self._logprobs_refusal_done = True
|
||||
events_to_fire.append(
|
||||
build(LogprobsRefusalDoneEvent, type="logprobs.refusal.done", refusal=choice_snapshot.logprobs.refusal),
|
||||
)
|
||||
|
||||
return events_to_fire
|
||||
|
||||
def _add_tool_done_event(
|
||||
self,
|
||||
*,
|
||||
events_to_fire: list[ChatCompletionStreamEvent[ResponseFormatT]],
|
||||
choice_snapshot: ParsedChoiceSnapshot,
|
||||
tool_index: int,
|
||||
) -> None:
|
||||
if tool_index in self._done_tool_calls:
|
||||
return
|
||||
|
||||
self._done_tool_calls.add(tool_index)
|
||||
|
||||
assert choice_snapshot.message.tool_calls is not None
|
||||
tool_call_snapshot = choice_snapshot.message.tool_calls[tool_index]
|
||||
|
||||
if tool_call_snapshot.type == "function":
|
||||
parsed_arguments = parse_function_tool_arguments(
|
||||
input_tools=self._input_tools, function=tool_call_snapshot.function
|
||||
)
|
||||
|
||||
# update the parsed content to potentially use a richer type
|
||||
# as opposed to the raw JSON-parsed object as the content is now
|
||||
# complete and can be fully validated.
|
||||
tool_call_snapshot.function.parsed_arguments = parsed_arguments
|
||||
|
||||
events_to_fire.append(
|
||||
build(
|
||||
FunctionToolCallArgumentsDoneEvent,
|
||||
type="tool_calls.function.arguments.done",
|
||||
index=tool_index,
|
||||
name=tool_call_snapshot.function.name,
|
||||
arguments=tool_call_snapshot.function.arguments,
|
||||
parsed_arguments=parsed_arguments,
|
||||
)
|
||||
)
|
||||
elif TYPE_CHECKING: # type: ignore[unreachable]
|
||||
assert_never(tool_call_snapshot)
|
||||
|
||||
|
||||
def _convert_initial_chunk_into_snapshot(chunk: ChatCompletionChunk) -> ParsedChatCompletionSnapshot:
|
||||
data = chunk.to_dict()
|
||||
choices = cast("list[object]", data["choices"])
|
||||
|
||||
for choice in chunk.choices:
|
||||
choices[choice.index] = {
|
||||
**choice.model_dump(exclude_unset=True, exclude={"delta"}),
|
||||
"message": choice.delta.to_dict(),
|
||||
}
|
||||
|
||||
return cast(
|
||||
ParsedChatCompletionSnapshot,
|
||||
construct_type(
|
||||
type_=ParsedChatCompletionSnapshot,
|
||||
value={
|
||||
"system_fingerprint": None,
|
||||
**data,
|
||||
"object": "chat.completion",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _is_valid_chat_completion_chunk_weak(sse_event: ChatCompletionChunk) -> bool:
|
||||
# Although `_raw_stream` is only supposed to contain objects adhering to the ChatCompletionChunk schema,
|
||||
# Azure OpenAI breaks this assumption when its Asynchronous Filter is enabled.
|
||||
# An easy filter is to check for the "object" property:
|
||||
# - should be "chat.completion.chunk" for a ChatCompletionChunk;
|
||||
# - is an empty string for Asynchronous Filter events.
|
||||
return sse_event.object == "chat.completion.chunk" # type: ignore # pylance reports this as a useless check
|
||||
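For context, a minimal usage sketch (not part of this diff) of how the events assembled above reach user code through the SDK's chat streaming helper. The `Step` model, the model name, and the prompt are assumptions; the entry point is assumed to be `client.beta.chat.completions.stream(...)`.

from pydantic import BaseModel

from openai import OpenAI


class Step(BaseModel):
    explanation: str
    answer: str


client = OpenAI()

with client.beta.chat.completions.stream(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    response_format=Step,
) as stream:
    for event in stream:
        if event.type == "content.delta":
            # `delta` is the newly received text, `snapshot` the content accumulated so far
            print(event.delta, end="")
        elif event.type == "content.done":
            # once content is done, `parsed` is a fully validated `Step` instance
            print("\nparsed:", event.parsed)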
@@ -0,0 +1,123 @@
|
||||
from typing import List, Union, Generic, Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ._types import ParsedChatCompletionSnapshot
|
||||
from ...._models import BaseModel, GenericModel
|
||||
from ..._parsing import ResponseFormatT
|
||||
from ....types.chat import ChatCompletionChunk, ChatCompletionTokenLogprob
|
||||
|
||||
|
||||
class ChunkEvent(BaseModel):
|
||||
type: Literal["chunk"]
|
||||
|
||||
chunk: ChatCompletionChunk
|
||||
|
||||
snapshot: ParsedChatCompletionSnapshot
|
||||
|
||||
|
||||
class ContentDeltaEvent(BaseModel):
|
||||
"""This event is yielded for every chunk with `choice.delta.content` data."""
|
||||
|
||||
type: Literal["content.delta"]
|
||||
|
||||
delta: str
|
||||
|
||||
snapshot: str
|
||||
|
||||
parsed: Optional[object] = None
|
||||
|
||||
|
||||
class ContentDoneEvent(GenericModel, Generic[ResponseFormatT]):
|
||||
type: Literal["content.done"]
|
||||
|
||||
content: str
|
||||
|
||||
parsed: Optional[ResponseFormatT] = None
|
||||
|
||||
|
||||
class RefusalDeltaEvent(BaseModel):
|
||||
type: Literal["refusal.delta"]
|
||||
|
||||
delta: str
|
||||
|
||||
snapshot: str
|
||||
|
||||
|
||||
class RefusalDoneEvent(BaseModel):
|
||||
type: Literal["refusal.done"]
|
||||
|
||||
refusal: str
|
||||
|
||||
|
||||
class FunctionToolCallArgumentsDeltaEvent(BaseModel):
|
||||
type: Literal["tool_calls.function.arguments.delta"]
|
||||
|
||||
name: str
|
||||
|
||||
index: int
|
||||
|
||||
arguments: str
|
||||
"""Accumulated raw JSON string"""
|
||||
|
||||
parsed_arguments: object
|
||||
"""The parsed arguments so far"""
|
||||
|
||||
arguments_delta: str
|
||||
"""The JSON string delta"""
|
||||
|
||||
|
||||
class FunctionToolCallArgumentsDoneEvent(BaseModel):
|
||||
type: Literal["tool_calls.function.arguments.done"]
|
||||
|
||||
name: str
|
||||
|
||||
index: int
|
||||
|
||||
arguments: str
|
||||
"""Accumulated raw JSON string"""
|
||||
|
||||
parsed_arguments: object
|
||||
"""The parsed arguments"""
|
||||
|
||||
|
||||
class LogprobsContentDeltaEvent(BaseModel):
|
||||
type: Literal["logprobs.content.delta"]
|
||||
|
||||
content: List[ChatCompletionTokenLogprob]
|
||||
|
||||
snapshot: List[ChatCompletionTokenLogprob]
|
||||
|
||||
|
||||
class LogprobsContentDoneEvent(BaseModel):
|
||||
type: Literal["logprobs.content.done"]
|
||||
|
||||
content: List[ChatCompletionTokenLogprob]
|
||||
|
||||
|
||||
class LogprobsRefusalDeltaEvent(BaseModel):
|
||||
type: Literal["logprobs.refusal.delta"]
|
||||
|
||||
refusal: List[ChatCompletionTokenLogprob]
|
||||
|
||||
snapshot: List[ChatCompletionTokenLogprob]
|
||||
|
||||
|
||||
class LogprobsRefusalDoneEvent(BaseModel):
|
||||
type: Literal["logprobs.refusal.done"]
|
||||
|
||||
refusal: List[ChatCompletionTokenLogprob]
|
||||
|
||||
|
||||
ChatCompletionStreamEvent = Union[
|
||||
ChunkEvent,
|
||||
ContentDeltaEvent,
|
||||
ContentDoneEvent[ResponseFormatT],
|
||||
RefusalDeltaEvent,
|
||||
RefusalDoneEvent,
|
||||
FunctionToolCallArgumentsDeltaEvent,
|
||||
FunctionToolCallArgumentsDoneEvent,
|
||||
LogprobsContentDeltaEvent,
|
||||
LogprobsContentDoneEvent,
|
||||
LogprobsRefusalDeltaEvent,
|
||||
LogprobsRefusalDoneEvent,
|
||||
]
|
||||
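The union above is discriminated by each event's `type` literal; a hedged sketch of a dispatcher over the tool-call and logprobs events (the `handle` helper is hypothetical, not part of this diff):

def handle(event) -> None:
    # `event` is one of the ChatCompletionStreamEvent members defined above
    if event.type == "chunk":
        # raw ChatCompletionChunk plus the accumulated snapshot
        _ = event.snapshot
    elif event.type == "tool_calls.function.arguments.delta":
        print(event.name, event.arguments_delta, end="")
    elif event.type == "tool_calls.function.arguments.done":
        print(event.name, "->", event.parsed_arguments)
    elif event.type == "logprobs.content.delta":
        print(len(event.snapshot), "content token logprobs so far")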
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ....types.chat import ParsedChoice, ParsedChatCompletion, ParsedChatCompletionMessage
|
||||
|
||||
ParsedChatCompletionSnapshot: TypeAlias = ParsedChatCompletion[object]
|
||||
"""Snapshot type representing an in-progress accumulation of
|
||||
a `ParsedChatCompletion` object.
|
||||
"""
|
||||
|
||||
ParsedChatCompletionMessageSnapshot: TypeAlias = ParsedChatCompletionMessage[object]
|
||||
"""Snapshot type representing an in-progress accumulation of
|
||||
a `ParsedChatCompletionMessage` object.
|
||||
|
||||
If the content has been fully accumulated, the `.parsed` content will be
|
||||
the `response_format` instance, otherwise it'll be the raw JSON parsed version.
|
||||
"""
|
||||
|
||||
ParsedChoiceSnapshot: TypeAlias = ParsedChoice[object]
|
||||
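A brief hedged illustration of the `.parsed` behaviour described in the docstring above; the helper below is hypothetical and not part of this diff:

def is_fully_parsed(message: "ParsedChatCompletionMessageSnapshot", response_format: type) -> bool:
    # while streaming, `.parsed` holds the raw JSON-parsed value; once accumulation
    # finishes it is replaced with an instance of the given `response_format` type
    return isinstance(message.parsed, response_format)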
@@ -0,0 +1,13 @@
|
||||
from ._events import (
|
||||
ResponseTextDoneEvent as ResponseTextDoneEvent,
|
||||
ResponseTextDeltaEvent as ResponseTextDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent as ResponseFunctionCallArgumentsDeltaEvent,
|
||||
)
|
||||
from ._responses import (
|
||||
ResponseStream as ResponseStream,
|
||||
AsyncResponseStream as AsyncResponseStream,
|
||||
ResponseStreamEvent as ResponseStreamEvent,
|
||||
ResponseStreamState as ResponseStreamState,
|
||||
ResponseStreamManager as ResponseStreamManager,
|
||||
AsyncResponseStreamManager as AsyncResponseStreamManager,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,106 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
from typing_extensions import Union, Generic, TypeVar, Annotated, TypeAlias
|
||||
|
||||
from ...._utils import PropertyInfo
|
||||
from ...._compat import GenericModel
|
||||
from ....types.responses import (
|
||||
ParsedResponse,
|
||||
ResponseErrorEvent,
|
||||
ResponseFailedEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseTextDoneEvent as RawResponseTextDoneEvent,
|
||||
ResponseAudioDoneEvent,
|
||||
ResponseCompletedEvent as RawResponseCompletedEvent,
|
||||
ResponseTextDeltaEvent as RawResponseTextDeltaEvent,
|
||||
ResponseAudioDeltaEvent,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseRefusalDoneEvent,
|
||||
ResponseRefusalDeltaEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseAudioTranscriptDoneEvent,
|
||||
ResponseTextAnnotationDeltaEvent,
|
||||
ResponseAudioTranscriptDeltaEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
ResponseFileSearchCallCompletedEvent,
|
||||
ResponseFileSearchCallSearchingEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseFileSearchCallInProgressEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent as RawResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
)
|
||||
|
||||
TextFormatT = TypeVar(
|
||||
"TextFormatT",
|
||||
# if it isn't given then we don't do any parsing
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ResponseTextDeltaEvent(RawResponseTextDeltaEvent):
|
||||
snapshot: str
|
||||
|
||||
|
||||
class ResponseTextDoneEvent(RawResponseTextDoneEvent, GenericModel, Generic[TextFormatT]):
|
||||
parsed: Optional[TextFormatT] = None
|
||||
|
||||
|
||||
class ResponseFunctionCallArgumentsDeltaEvent(RawResponseFunctionCallArgumentsDeltaEvent):
|
||||
snapshot: str
|
||||
|
||||
|
||||
class ResponseCompletedEvent(RawResponseCompletedEvent, GenericModel, Generic[TextFormatT]):
|
||||
response: ParsedResponse[TextFormatT] # type: ignore[assignment]
|
||||
|
||||
|
||||
ResponseStreamEvent: TypeAlias = Annotated[
|
||||
Union[
|
||||
# wrappers with snapshots added on
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent[TextFormatT],
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseCompletedEvent[TextFormatT],
|
||||
# the same as the non-accumulated API
|
||||
ResponseAudioDeltaEvent,
|
||||
ResponseAudioDoneEvent,
|
||||
ResponseAudioTranscriptDeltaEvent,
|
||||
ResponseAudioTranscriptDoneEvent,
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseErrorEvent,
|
||||
ResponseFileSearchCallCompletedEvent,
|
||||
ResponseFileSearchCallInProgressEvent,
|
||||
ResponseFileSearchCallSearchingEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseFailedEvent,
|
||||
ResponseIncompleteEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseRefusalDeltaEvent,
|
||||
ResponseRefusalDoneEvent,
|
||||
ResponseTextAnnotationDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
],
|
||||
PropertyInfo(discriminator="type"),
|
||||
]
|
||||
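A hedged sketch of consuming these wrapper events; the extra `snapshot` and `parsed` fields are what distinguish them from the raw API events re-exported above (the `on_event` helper is hypothetical, not part of this diff):

def on_event(event) -> None:
    if event.type == "response.output_text.delta":
        print(event.delta, end="")  # incremental text
        _ = event.snapshot          # full text accumulated so far
    elif event.type == "response.output_text.done":
        print("\nparsed:", event.parsed)  # `text_format` instance, if one was given
    elif event.type == "response.completed":
        _ = event.response          # ParsedResponse[TextFormatT]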
@@ -0,0 +1,354 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Generic, Iterable, Awaitable, cast
|
||||
from typing_extensions import Self, Callable, Iterator, AsyncIterator
|
||||
|
||||
from ._types import ParsedResponseSnapshot
|
||||
from ._events import (
|
||||
ResponseStreamEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseCompletedEvent,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
)
|
||||
from ...._types import NOT_GIVEN, NotGiven
|
||||
from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
|
||||
from ...._models import build, construct_type_unchecked
|
||||
from ...._streaming import Stream, AsyncStream
|
||||
from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
|
||||
from ..._parsing._responses import TextFormatT, parse_text, parse_response
|
||||
from ....types.responses.tool_param import ToolParam
|
||||
from ....types.responses.parsed_response import (
|
||||
ParsedContent,
|
||||
ParsedResponseOutputMessage,
|
||||
ParsedResponseFunctionToolCall,
|
||||
)
|
||||
|
||||
|
||||
class ResponseStream(Generic[TextFormatT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
raw_stream: Stream[RawResponseStreamEvent],
|
||||
text_format: type[TextFormatT] | NotGiven,
|
||||
input_tools: Iterable[ToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self._raw_stream = raw_stream
|
||||
self._response = raw_stream.response
|
||||
self._iterator = self.__stream__()
|
||||
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
|
||||
|
||||
def __next__(self) -> ResponseStreamEvent[TextFormatT]:
|
||||
return self._iterator.__next__()
|
||||
|
||||
def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
|
||||
for item in self._iterator:
|
||||
yield item
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
|
||||
for sse_event in self._raw_stream:
|
||||
events_to_fire = self._state.handle_event(sse_event)
|
||||
for event in events_to_fire:
|
||||
yield event
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
self._response.close()
|
||||
|
||||
def get_final_response(self) -> ParsedResponse[TextFormatT]:
|
||||
"""Waits until the stream has been read to completion and returns
|
||||
the accumulated `ParsedResponse` object.
|
||||
"""
|
||||
self.until_done()
|
||||
response = self._state._completed_response
|
||||
if not response:
|
||||
raise RuntimeError("Didn't receive a `response.completed` event.")
|
||||
|
||||
return response
|
||||
|
||||
def until_done(self) -> Self:
|
||||
"""Blocks until the stream has been consumed."""
|
||||
consume_sync_iterator(self)
|
||||
return self
|
||||
|
||||
|
||||
class ResponseStreamManager(Generic[TextFormatT]):
|
||||
def __init__(
|
||||
self,
|
||||
api_request: Callable[[], Stream[RawResponseStreamEvent]],
|
||||
*,
|
||||
text_format: type[TextFormatT] | NotGiven,
|
||||
input_tools: Iterable[ToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self.__stream: ResponseStream[TextFormatT] | None = None
|
||||
self.__api_request = api_request
|
||||
self.__text_format = text_format
|
||||
self.__input_tools = input_tools
|
||||
|
||||
def __enter__(self) -> ResponseStream[TextFormatT]:
|
||||
raw_stream = self.__api_request()
|
||||
|
||||
self.__stream = ResponseStream(
|
||||
raw_stream=raw_stream,
|
||||
text_format=self.__text_format,
|
||||
input_tools=self.__input_tools,
|
||||
)
|
||||
|
||||
return self.__stream
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__stream is not None:
|
||||
self.__stream.close()
|
||||
|
||||
|
||||
class AsyncResponseStream(Generic[TextFormatT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
raw_stream: AsyncStream[RawResponseStreamEvent],
|
||||
text_format: type[TextFormatT] | NotGiven,
|
||||
input_tools: Iterable[ToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self._raw_stream = raw_stream
|
||||
self._response = raw_stream.response
|
||||
self._iterator = self.__stream__()
|
||||
self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
|
||||
|
||||
async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
|
||||
return await self._iterator.__anext__()
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
|
||||
async for item in self._iterator:
|
||||
yield item
|
||||
|
||||
async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
|
||||
async for sse_event in self._raw_stream:
|
||||
events_to_fire = self._state.handle_event(sse_event)
|
||||
for event in events_to_fire:
|
||||
yield event
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close the response and release the connection.
|
||||
|
||||
Automatically called if the response body is read to completion.
|
||||
"""
|
||||
await self._response.aclose()
|
||||
|
||||
async def get_final_response(self) -> ParsedResponse[TextFormatT]:
|
||||
"""Waits until the stream has been read to completion and returns
|
||||
the accumulated `ParsedResponse` object.
|
||||
"""
|
||||
await self.until_done()
|
||||
response = self._state._completed_response
|
||||
if not response:
|
||||
raise RuntimeError("Didn't receive a `response.completed` event.")
|
||||
|
||||
return response
|
||||
|
||||
async def until_done(self) -> Self:
|
||||
"""Blocks until the stream has been consumed."""
|
||||
await consume_async_iterator(self)
|
||||
return self
|
||||
|
||||
|
||||
class AsyncResponseStreamManager(Generic[TextFormatT]):
|
||||
def __init__(
|
||||
self,
|
||||
api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
|
||||
*,
|
||||
text_format: type[TextFormatT] | NotGiven,
|
||||
input_tools: Iterable[ToolParam] | NotGiven,
|
||||
) -> None:
|
||||
self.__stream: AsyncResponseStream[TextFormatT] | None = None
|
||||
self.__api_request = api_request
|
||||
self.__text_format = text_format
|
||||
self.__input_tools = input_tools
|
||||
|
||||
async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
|
||||
raw_stream = await self.__api_request
|
||||
|
||||
self.__stream = AsyncResponseStream(
|
||||
raw_stream=raw_stream,
|
||||
text_format=self.__text_format,
|
||||
input_tools=self.__input_tools,
|
||||
)
|
||||
|
||||
return self.__stream
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
if self.__stream is not None:
|
||||
await self.__stream.close()
|
||||
|
||||
|
||||
class ResponseStreamState(Generic[TextFormatT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_tools: Iterable[ToolParam] | NotGiven,
|
||||
text_format: type[TextFormatT] | NotGiven,
|
||||
) -> None:
|
||||
self.__current_snapshot: ParsedResponseSnapshot | None = None
|
||||
self._completed_response: ParsedResponse[TextFormatT] | None = None
|
||||
self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
|
||||
self._text_format = text_format
|
||||
self._rich_text_format: type | NotGiven = text_format if inspect.isclass(text_format) else NOT_GIVEN
|
||||
|
||||
def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
|
||||
self.__current_snapshot = snapshot = self.accumulate_event(event)
|
||||
|
||||
events: List[ResponseStreamEvent[TextFormatT]] = []
|
||||
|
||||
if event.type == "response.output_text.delta":
|
||||
output = snapshot.output[event.output_index]
|
||||
assert output.type == "message"
|
||||
|
||||
content = output.content[event.content_index]
|
||||
assert content.type == "output_text"
|
||||
|
||||
events.append(
|
||||
build(
|
||||
ResponseTextDeltaEvent,
|
||||
content_index=event.content_index,
|
||||
delta=event.delta,
|
||||
item_id=event.item_id,
|
||||
output_index=event.output_index,
|
||||
type="response.output_text.delta",
|
||||
snapshot=content.text,
|
||||
)
|
||||
)
|
||||
elif event.type == "response.output_text.done":
|
||||
output = snapshot.output[event.output_index]
|
||||
assert output.type == "message"
|
||||
|
||||
content = output.content[event.content_index]
|
||||
assert content.type == "output_text"
|
||||
|
||||
events.append(
|
||||
build(
|
||||
ResponseTextDoneEvent[TextFormatT],
|
||||
content_index=event.content_index,
|
||||
item_id=event.item_id,
|
||||
output_index=event.output_index,
|
||||
type="response.output_text.done",
|
||||
text=event.text,
|
||||
parsed=parse_text(event.text, text_format=self._text_format),
|
||||
)
|
||||
)
|
||||
elif event.type == "response.function_call_arguments.delta":
|
||||
output = snapshot.output[event.output_index]
|
||||
assert output.type == "function_call"
|
||||
|
||||
events.append(
|
||||
build(
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
delta=event.delta,
|
||||
item_id=event.item_id,
|
||||
output_index=event.output_index,
|
||||
type="response.function_call_arguments.delta",
|
||||
snapshot=output.arguments,
|
||||
)
|
||||
)
|
||||
|
||||
elif event.type == "response.completed":
|
||||
response = self._completed_response
|
||||
assert response is not None
|
||||
|
||||
events.append(
|
||||
build(
|
||||
ResponseCompletedEvent,
|
||||
type="response.completed",
|
||||
response=response,
|
||||
)
|
||||
)
|
||||
else:
|
||||
events.append(event)
|
||||
|
||||
return events
|
||||
|
||||
def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
|
||||
snapshot = self.__current_snapshot
|
||||
if snapshot is None:
|
||||
return self._create_initial_response(event)
|
||||
|
||||
if event.type == "response.output_item.added":
|
||||
if event.item.type == "function_call":
|
||||
snapshot.output.append(
|
||||
construct_type_unchecked(
|
||||
type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
|
||||
)
|
||||
)
|
||||
elif event.item.type == "message":
|
||||
snapshot.output.append(
|
||||
construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
|
||||
)
|
||||
else:
|
||||
snapshot.output.append(event.item)
|
||||
elif event.type == "response.content_part.added":
|
||||
output = snapshot.output[event.output_index]
|
||||
if output.type == "message":
|
||||
output.content.append(
|
||||
construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
|
||||
)
|
||||
elif event.type == "response.output_text.delta":
|
||||
output = snapshot.output[event.output_index]
|
||||
if output.type == "message":
|
||||
content = output.content[event.content_index]
|
||||
assert content.type == "output_text"
|
||||
content.text += event.delta
|
||||
elif event.type == "response.function_call_arguments.delta":
|
||||
output = snapshot.output[event.output_index]
|
||||
if output.type == "function_call":
|
||||
output.arguments += event.delta
|
||||
elif event.type == "response.completed":
|
||||
self._completed_response = parse_response(
|
||||
text_format=self._text_format,
|
||||
response=event.response,
|
||||
input_tools=self._input_tools,
|
||||
)
|
||||
|
||||
return snapshot
|
||||
|
||||
def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
|
||||
if event.type != "response.created":
|
||||
raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
|
||||
|
||||
return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
|
||||
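For context, a minimal usage sketch (not part of this diff) of the stream manager defined above; it assumes the SDK exposes it via `client.responses.stream(...)` and uses an assumed `CalendarEvent` model for `text_format`.

from pydantic import BaseModel

from openai import OpenAI


class CalendarEvent(BaseModel):
    name: str
    date: str


client = OpenAI()

with client.responses.stream(
    model="gpt-4o-2024-08-06",
    input="Alice and Bob meet on Friday.",
    text_format=CalendarEvent,
) as stream:
    for event in stream:
        if event.type == "response.output_text.delta":
            print(event.delta, end="")

    # the stream is consumed at this point, so the completed response is available
    final = stream.get_final_response()
    print("\n", final.output_parsed)  # assumed convenience accessor on ParsedResponse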
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ....types.responses import ParsedResponse
|
||||
|
||||
ParsedResponseSnapshot: TypeAlias = ParsedResponse[object]
|
||||
"""Snapshot type representing an in-progress accumulation of
|
||||
a `ParsedResponse` object.
|
||||
"""
|
||||