from __future__ import annotations import json from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason from ...typing import AsyncResult, Messages from ...requests.raise_for_status import raise_for_status from ...requests import StreamSession from ...errors import MissingAuthError class Openai(AsyncGeneratorProvider, ProviderModelMixin): url = "https://openai.com" working = True needs_auth = True supports_message_history = True supports_system_message = True @classmethod async def create_async_generator( cls, model: str, messages: Messages, proxy: str = None, timeout: int = 120, api_key: str = None, api_base: str = "https://api.openai.com/v1", temperature: float = None, max_tokens: int = None, top_p: float = None, stop: str = None, stream: bool = False, **kwargs ) -> AsyncResult: if api_key is None: raise MissingAuthError('Add a "api_key"') async with StreamSession( proxies={"all": proxy}, headers=cls.get_headers(api_key), timeout=timeout ) as session: data = { "messages": messages, "model": cls.get_model(model), "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "stop": stop, "stream": stream, } async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response: await raise_for_status(response) async for line in response.iter_lines(): if line.startswith(b"data: ") or not stream: async for chunk in cls.read_line(line[6:] if stream else line, stream): yield chunk @staticmethod async def read_line(line: str, stream: bool): if line == b"[DONE]": return choice = json.loads(line)["choices"][0] if stream and "content" in choice["delta"] and choice["delta"]["content"]: yield choice["delta"]["content"] elif not stream and "content" in choice["message"]: yield choice["message"]["content"] if "finish_reason" in choice and choice["finish_reason"] is not None: yield FinishReason(choice["finish_reason"]) @staticmethod def get_headers(api_key: str) -> dict: return { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }