diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-01-14 15:32:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-14 15:32:51 +0100 |
commit | 1ca80ed48b55d6462b4bd445e66d4f7de7442c2b (patch) | |
tree | 05a94b53b83461b8249de965e093b4fd3722e2d1 | |
parent | Merge pull request #1466 from hlohaus/upp (diff) | |
parent | Change doctypes style to Google (diff) | |
download | gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.gz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.bz2 gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.lz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.xz gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.tar.zst gpt4free-1ca80ed48b55d6462b4bd445e66d4f7de7442c2b.zip |
Diffstat (limited to '')
-rw-r--r-- | .github/workflows/unittest.yml | 19 | ||||
-rw-r--r-- | etc/unittest/main.py | 73 | ||||
-rw-r--r-- | g4f/Provider/Bing.py | 230 | ||||
-rw-r--r-- | g4f/Provider/FreeChatgpt.py | 15 | ||||
-rw-r--r-- | g4f/Provider/Phind.py | 8 | ||||
-rw-r--r-- | g4f/Provider/base_provider.py | 198 | ||||
-rw-r--r-- | g4f/Provider/bing/conversation.py | 44 | ||||
-rw-r--r-- | g4f/Provider/bing/create_images.py | 224 | ||||
-rw-r--r-- | g4f/Provider/bing/upload_image.py | 188 | ||||
-rw-r--r-- | g4f/Provider/create_images.py | 61 | ||||
-rw-r--r-- | g4f/Provider/helper.py | 143 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 339 | ||||
-rw-r--r-- | g4f/Provider/retry_provider.py | 58 | ||||
-rw-r--r-- | g4f/__init__.py | 91 | ||||
-rw-r--r-- | g4f/base_provider.py | 81 | ||||
-rw-r--r-- | g4f/gui/client/css/style.css | 13 | ||||
-rw-r--r-- | g4f/gui/client/html/index.html | 24 | ||||
-rw-r--r-- | g4f/gui/client/js/chat.v1.js | 80 | ||||
-rw-r--r-- | g4f/gui/server/backend.py | 211 | ||||
-rw-r--r-- | g4f/image.py | 105 | ||||
-rw-r--r-- | g4f/models.py | 15 | ||||
-rw-r--r-- | g4f/requests.py | 48 | ||||
-rw-r--r-- | g4f/version.py | 93 | ||||
-rw-r--r-- | g4f/webdriver.py | 111 |
24 files changed, 1841 insertions, 631 deletions
diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml new file mode 100644 index 00000000..e895e969 --- /dev/null +++ b/.github/workflows/unittest.yml @@ -0,0 +1,19 @@ +name: Unittest + +on: [push] + +jobs: + build: + name: Build unittest + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + cache: 'pip' + - name: Install requirements + - run: pip install -r requirements.txt + - name: Run tests + run: python -m etc.unittest.main
\ No newline at end of file diff --git a/etc/unittest/main.py b/etc/unittest/main.py new file mode 100644 index 00000000..61f4ffda --- /dev/null +++ b/etc/unittest/main.py @@ -0,0 +1,73 @@ +import sys +import pathlib +import unittest +from unittest.mock import MagicMock + +sys.path.append(str(pathlib.Path(__file__).parent.parent.parent)) + +import g4f +from g4f import ChatCompletion, get_last_provider +from g4f.gui.server.backend import Backend_Api, get_error_message +from g4f.base_provider import BaseProvider + +g4f.debug.logging = False + +class MockProvider(BaseProvider): + working = True + + def create_completion( + model, messages, stream, **kwargs + ): + yield "Mock" + + async def create_async( + model, messages, **kwargs + ): + return "Mock" + +class TestBackendApi(unittest.TestCase): + + def setUp(self): + self.app = MagicMock() + self.api = Backend_Api(self.app) + + def test_version(self): + response = self.api.get_version() + self.assertIn("version", response) + self.assertIn("latest_version", response) + +class TestChatCompletion(unittest.TestCase): + + def test_create(self): + messages = [{'role': 'user', 'content': 'Hello'}] + result = ChatCompletion.create(g4f.models.default, messages) + self.assertTrue("Hello" in result or "Good" in result) + + def test_get_last_provider(self): + messages = [{'role': 'user', 'content': 'Hello'}] + ChatCompletion.create(g4f.models.default, messages, MockProvider) + self.assertEqual(get_last_provider(), MockProvider) + + def test_bing_provider(self): + messages = [{'role': 'user', 'content': 'Hello'}] + provider = g4f.Provider.Bing + result = ChatCompletion.create(g4f.models.default, messages, provider) + self.assertTrue("Bing" in result) + +class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): + + async def test_async(self): + messages = [{'role': 'user', 'content': 'Hello'}] + result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider) + self.assertTrue("Mock" in result) + +class TestUtilityFunctions(unittest.TestCase): + + def test_get_error_message(self): + g4f.debug.last_provider = g4f.Provider.Bing + exception = Exception("Message") + result = get_error_message(exception) + self.assertEqual("Bing: Exception: Message", result) + +if __name__ == '__main__': + unittest.main()
\ No newline at end of file diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 50e29d23..34687866 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -15,12 +15,18 @@ from .bing.upload_image import upload_image from .bing.create_images import create_images from .bing.conversation import Conversation, create_conversation, delete_conversation -class Tones(): +class Tones: + """ + Defines the different tone options for the Bing provider. + """ creative = "Creative" balanced = "Balanced" precise = "Precise" class Bing(AsyncGeneratorProvider): + """ + Bing provider for generating responses using the Bing API. + """ url = "https://bing.com/chat" working = True supports_message_history = True @@ -38,6 +44,19 @@ class Bing(AsyncGeneratorProvider): web_search: bool = False, **kwargs ) -> AsyncResult: + """ + Creates an asynchronous generator for producing responses from Bing. + + :param model: The model to use. + :param messages: Messages to process. + :param proxy: Proxy to use for requests. + :param timeout: Timeout for requests. + :param cookies: Cookies for the session. + :param tone: The tone of the response. + :param image: The image type to be used. + :param web_search: Flag to enable or disable web search. + :return: An asynchronous result object. + """ if len(messages) < 2: prompt = messages[0]["content"] context = None @@ -56,65 +75,48 @@ class Bing(AsyncGeneratorProvider): return stream_generate(prompt, tone, image, context, proxy, cookies, web_search, gpt4_turbo, timeout) -def create_context(messages: Messages): +def create_context(messages: Messages) -> str: + """ + Creates a context string from a list of messages. + + :param messages: A list of message dictionaries. + :return: A string representing the context created from the messages. + """ return "".join( - f"[{message['role']}]" + ("(#message)" if message['role']!="system" else "(#additional_instructions)") + f"\n{message['content']}\n\n" + f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}\n\n" for message in messages ) class Defaults: + """ + Default settings and configurations for the Bing provider. + """ delimiter = "\x1e" ip_address = f"13.{random.randint(104, 107)}.{random.randint(0, 255)}.{random.randint(0, 255)}" + # List of allowed message types for Bing responses allowedMessageTypes = [ - "ActionRequest", - "Chat", - "Context", - # "Disengaged", unwanted - "Progress", - # "AdsQuery", unwanted - "SemanticSerp", - "GenerateContentQuery", - "SearchQuery", - # The following message types should not be added so that it does not flood with - # useless messages (such as "Analyzing images" or "Searching the web") while it's retrieving the AI response - # "InternalSearchQuery", - # "InternalSearchResult", - "RenderCardRequest", - # "RenderContentRequest" + "ActionRequest", "Chat", "Context", "Progress", "SemanticSerp", + "GenerateContentQuery", "SearchQuery", "RenderCardRequest" ] sliceIds = [ - 'abv2', - 'srdicton', - 'convcssclick', - 'stylewv2', - 'contctxp2tf', - '802fluxv1pc_a', - '806log2sphs0', - '727savemem', - '277teditgnds0', - '207hlthgrds0', + 'abv2', 'srdicton', 'convcssclick', 'stylewv2', 'contctxp2tf', + '802fluxv1pc_a', '806log2sphs0', '727savemem', '277teditgnds0', '207hlthgrds0' ] + # Default location settings location = { - "locale": "en-US", - "market": "en-US", - "region": "US", - "locationHints": [ - { - "country": "United States", - "state": "California", - "city": "Los Angeles", - "timezoneoffset": 8, - "countryConfidence": 8, - "Center": {"Latitude": 34.0536909, "Longitude": -118.242766}, - "RegionType": 2, - "SourceType": 1, - } - ], + "locale": "en-US", "market": "en-US", "region": "US", + "locationHints": [{ + "country": "United States", "state": "California", "city": "Los Angeles", + "timezoneoffset": 8, "countryConfidence": 8, + "Center": {"Latitude": 34.0536909, "Longitude": -118.242766}, + "RegionType": 2, "SourceType": 1 + }], } + # Default headers for requests headers = { 'accept': '*/*', 'accept-language': 'en-US,en;q=0.9', @@ -139,23 +141,13 @@ class Defaults: } optionsSets = [ - 'nlu_direct_response_filter', - 'deepleo', - 'disable_emoji_spoken_text', - 'responsible_ai_policy_235', - 'enablemm', - 'iyxapbing', - 'iycapbing', - 'gencontentv3', - 'fluxsrtrunc', - 'fluxtrunc', - 'fluxv1', - 'rai278', - 'replaceurl', - 'eredirecturl', - 'nojbfedge' + 'nlu_direct_response_filter', 'deepleo', 'disable_emoji_spoken_text', + 'responsible_ai_policy_235', 'enablemm', 'iyxapbing', 'iycapbing', + 'gencontentv3', 'fluxsrtrunc', 'fluxtrunc', 'fluxv1', 'rai278', + 'replaceurl', 'eredirecturl', 'nojbfedge' ] + # Default cookies cookies = { 'SRCHD' : 'AF=NOFORM', 'PPLState' : '1', @@ -166,6 +158,12 @@ class Defaults: } def format_message(msg: dict) -> str: + """ + Formats a message dictionary into a JSON string with a delimiter. + + :param msg: The message dictionary to format. + :return: A formatted string representation of the message. + """ return json.dumps(msg, ensure_ascii=False) + Defaults.delimiter def create_message( @@ -177,7 +175,20 @@ def create_message( web_search: bool = False, gpt4_turbo: bool = False ) -> str: + """ + Creates a message for the Bing API with specified parameters. + + :param conversation: The current conversation object. + :param prompt: The user's input prompt. + :param tone: The desired tone for the response. + :param context: Additional context for the prompt. + :param image_response: The response if an image is involved. + :param web_search: Flag to enable web search. + :param gpt4_turbo: Flag to enable GPT-4 Turbo. + :return: A formatted string message for the Bing API. + """ options_sets = Defaults.optionsSets + # Append tone-specific options if tone == Tones.creative: options_sets.append("h3imaginative") elif tone == Tones.precise: @@ -186,54 +197,49 @@ def create_message( options_sets.append("galileo") else: options_sets.append("harmonyv3") - + + # Additional configurations based on parameters if not web_search: options_sets.append("nosearchall") - if gpt4_turbo: options_sets.append("dlgpt4t") - + request_id = str(uuid.uuid4()) struct = { - 'arguments': [ - { - 'source': 'cib', - 'optionsSets': options_sets, - 'allowedMessageTypes': Defaults.allowedMessageTypes, - 'sliceIds': Defaults.sliceIds, - 'traceId': os.urandom(16).hex(), - 'isStartOfSession': True, + 'arguments': [{ + 'source': 'cib', 'optionsSets': options_sets, + 'allowedMessageTypes': Defaults.allowedMessageTypes, + 'sliceIds': Defaults.sliceIds, + 'traceId': os.urandom(16).hex(), 'isStartOfSession': True, + 'requestId': request_id, + 'message': { + **Defaults.location, + 'author': 'user', + 'inputMethod': 'Keyboard', + 'text': prompt, + 'messageType': 'Chat', 'requestId': request_id, - 'message': {**Defaults.location, **{ - 'author': 'user', - 'inputMethod': 'Keyboard', - 'text': prompt, - 'messageType': 'Chat', - 'requestId': request_id, - 'messageId': request_id, - }}, - "verbosity": "verbose", - "scenario": "SERP", - "plugins":[ - {"id":"c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1} - ] if web_search else [], - 'tone': tone, - 'spokenTextMode': 'None', - 'conversationId': conversation.conversationId, - 'participant': { - 'id': conversation.clientId - }, - } - ], + 'messageId': request_id + }, + "verbosity": "verbose", + "scenario": "SERP", + "plugins": [{"id": "c310c353-b9f0-4d76-ab0d-1dd5e979cf68", "category": 1}] if web_search else [], + 'tone': tone, + 'spokenTextMode': 'None', + 'conversationId': conversation.conversationId, + 'participant': {'id': conversation.clientId}, + }], 'invocationId': '1', 'target': 'chat', 'type': 4 } - if image_response.get('imageUrl') and image_response.get('originalImageUrl'): + + if image_response and image_response.get('imageUrl') and image_response.get('originalImageUrl'): struct['arguments'][0]['message']['originalImageUrl'] = image_response.get('originalImageUrl') struct['arguments'][0]['message']['imageUrl'] = image_response.get('imageUrl') struct['arguments'][0]['experienceType'] = None struct['arguments'][0]['attachedFileInfo'] = {"fileName": None, "fileType": None} + if context: struct['arguments'][0]['previousMessages'] = [{ "author": "user", @@ -242,30 +248,46 @@ def create_message( "messageType": "Context", "messageId": "discover-web--page-ping-mriduna-----" }] + return format_message(struct) async def stream_generate( - prompt: str, - tone: str, - image: ImageType = None, - context: str = None, - proxy: str = None, - cookies: dict = None, - web_search: bool = False, - gpt4_turbo: bool = False, - timeout: int = 900 - ): + prompt: str, + tone: str, + image: ImageType = None, + context: str = None, + proxy: str = None, + cookies: dict = None, + web_search: bool = False, + gpt4_turbo: bool = False, + timeout: int = 900 +): + """ + Asynchronously streams generated responses from the Bing API. + + :param prompt: The user's input prompt. + :param tone: The desired tone for the response. + :param image: The image type involved in the response. + :param context: Additional context for the prompt. + :param proxy: Proxy settings for the request. + :param cookies: Cookies for the session. + :param web_search: Flag to enable web search. + :param gpt4_turbo: Flag to enable GPT-4 Turbo. + :param timeout: Timeout for the request. + :return: An asynchronous generator yielding responses. + """ headers = Defaults.headers if cookies: headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items()) + async with ClientSession( - timeout=ClientTimeout(total=timeout), - headers=headers + timeout=ClientTimeout(total=timeout), headers=headers ) as session: conversation = await create_conversation(session, proxy) image_response = await upload_image(session, image, tone, proxy) if image else None if image_response: yield image_response + try: async with session.ws_connect( 'wss://sydney.bing.com/sydney/ChatHub', @@ -289,7 +311,7 @@ async def stream_generate( if obj is None or not obj: continue response = json.loads(obj) - if response.get('type') == 1 and response['arguments'][0].get('messages'): + if response and response.get('type') == 1 and response['arguments'][0].get('messages'): message = response['arguments'][0]['messages'][0] image_response = None if (message['contentOrigin'] != 'Apology'): diff --git a/g4f/Provider/FreeChatgpt.py b/g4f/Provider/FreeChatgpt.py index 75514118..0f993690 100644 --- a/g4f/Provider/FreeChatgpt.py +++ b/g4f/Provider/FreeChatgpt.py @@ -1,16 +1,20 @@ from __future__ import annotations -import json +import json, random from aiohttp import ClientSession from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider - models = { - "claude-v2": "claude-2.0", - "gemini-pro": "google-gemini-pro" + "claude-v2": "claude-2.0", + "claude-v2.1":"claude-2.1", + "gemini-pro": "google-gemini-pro" } +urls = [ + "https://free.chatgpt.org.uk", + "https://ai.chatgpt.org.uk" +] class FreeChatgpt(AsyncGeneratorProvider): url = "https://free.chatgpt.org.uk" @@ -31,6 +35,7 @@ class FreeChatgpt(AsyncGeneratorProvider): model = models[model] elif not model: model = "gpt-3.5-turbo" + url = random.choice(urls) headers = { "Accept": "application/json, text/event-stream", "Content-Type":"application/json", @@ -55,7 +60,7 @@ class FreeChatgpt(AsyncGeneratorProvider): "top_p":1, **kwargs } - async with session.post(f'{cls.url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response: + async with session.post(f'{url}/api/openai/v1/chat/completions', json=data, proxy=proxy) as response: response.raise_for_status() started = False async for line in response.content: diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py index bb216989..9e80baa9 100644 --- a/g4f/Provider/Phind.py +++ b/g4f/Provider/Phind.py @@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider): "rewrittenQuestion": prompt, "challenge": 0.21132115912208504 } - async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response: + async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response: new_line = False async for line in response.iter_lines(): if line.startswith(b"data: "): chunk = line[6:] - if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"): + if chunk.startswith(b'<PHIND_DONE/>'): + break + if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'): + pass + elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"): pass elif chunk: yield chunk.decode() diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index e7e88841..fd92d17a 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,28 +1,29 @@ from __future__ import annotations - import sys import asyncio -from asyncio import AbstractEventLoop +from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor -from abc import abstractmethod -from inspect import signature, Parameter -from .helper import get_event_loop, get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages -from ..base_provider import BaseProvider +from abc import abstractmethod +from inspect import signature, Parameter +from .helper import get_event_loop, get_cookies, format_prompt +from ..typing import CreateResult, AsyncResult, Messages +from ..base_provider import BaseProvider if sys.version_info < (3, 10): NoneType = type(None) else: from types import NoneType -# Change event loop policy on windows for curl_cffi +# Set Windows event loop policy for better compatibility with asyncio and curl_cffi if sys.platform == 'win32': - if isinstance( - asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy - ): + if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) class AbstractProvider(BaseProvider): + """ + Abstract class for providing asynchronous functionality to derived classes. + """ + @classmethod async def create_async( cls, @@ -33,62 +34,67 @@ class AbstractProvider(BaseProvider): executor: ThreadPoolExecutor = None, **kwargs ) -> str: - if not loop: - loop = get_event_loop() + """ + Asynchronously creates a result based on the given model and messages. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ + loop = loop or get_event_loop() def create_func() -> str: - return "".join(cls.create_completion( - model, - messages, - False, - **kwargs - )) + return "".join(cls.create_completion(model, messages, False, **kwargs)) return await asyncio.wait_for( - loop.run_in_executor( - executor, - create_func - ), + loop.run_in_executor(executor, create_func), timeout=kwargs.get("timeout", 0) ) @classmethod @property def params(cls) -> str: - if issubclass(cls, AsyncGeneratorProvider): - sig = signature(cls.create_async_generator) - elif issubclass(cls, AsyncProvider): - sig = signature(cls.create_async) - else: - sig = signature(cls.create_completion) + """ + Returns the parameters supported by the provider. + + Args: + cls (type): The class on which this property is called. + + Returns: + str: A string listing the supported parameters. + """ + sig = signature( + cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else + cls.create_async if issubclass(cls, AsyncProvider) else + cls.create_completion + ) def get_type_name(annotation: type) -> str: - if hasattr(annotation, "__name__"): - annotation = annotation.__name__ - elif isinstance(annotation, NoneType): - annotation = "None" - return str(annotation) - + return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation) + args = "" for name, param in sig.parameters.items(): - if name in ("self", "kwargs"): - continue - if name == "stream" and not cls.supports_stream: + if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream): continue - if args: - args += ", " - args += "\n " + name - if name != "model" and param.annotation is not Parameter.empty: - args += f": {get_type_name(param.annotation)}" - if param.default == "": - args += ' = ""' - elif param.default is not Parameter.empty: - args += f" = {param.default}" + args += f"\n {name}" + args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else "" + args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else "" return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" class AsyncProvider(AbstractProvider): + """ + Provides asynchronous functionality for creating completions. + """ + @classmethod def create_completion( cls, @@ -99,8 +105,21 @@ class AsyncProvider(AbstractProvider): loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - if not loop: - loop = get_event_loop() + """ + Creates a completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to False. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the completion creation. + """ + loop = loop or get_event_loop() coro = cls.create_async(model, messages, **kwargs) yield loop.run_until_complete(coro) @@ -111,10 +130,27 @@ class AsyncProvider(AbstractProvider): messages: Messages, **kwargs ) -> str: + """ + Abstract method for creating asynchronous results. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + str: The created result as a string. + """ raise NotImplementedError() class AsyncGeneratorProvider(AsyncProvider): + """ + Provides asynchronous generator functionality for streaming results. + """ supports_stream = True @classmethod @@ -127,15 +163,24 @@ class AsyncGeneratorProvider(AsyncProvider): loop: AbstractEventLoop = None, **kwargs ) -> CreateResult: - if not loop: - loop = get_event_loop() - generator = cls.create_async_generator( - model, - messages, - stream=stream, - **kwargs - ) + """ + Creates a streaming completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the streaming completion creation. + """ + loop = loop or get_event_loop() + generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() + while True: try: yield loop.run_until_complete(gen.__anext__()) @@ -149,21 +194,44 @@ class AsyncGeneratorProvider(AsyncProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously creates a result from a generator. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. + """ return "".join([ - chunk async for chunk in cls.create_async_generator( - model, - messages, - stream=False, - **kwargs - ) if not isinstance(chunk, Exception) + chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) + if not isinstance(chunk, Exception) ]) @staticmethod @abstractmethod - def create_async_generator( + async def create_async_generator( model: str, messages: Messages, stream: bool = True, **kwargs ) -> AsyncResult: - raise NotImplementedError() + """ + Abstract method for creating an asynchronous generator. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + AsyncResult: An asynchronous generator yielding results. + """ + raise NotImplementedError()
\ No newline at end of file diff --git a/g4f/Provider/bing/conversation.py b/g4f/Provider/bing/conversation.py index 9e011c26..36ada3b0 100644 --- a/g4f/Provider/bing/conversation.py +++ b/g4f/Provider/bing/conversation.py @@ -1,13 +1,33 @@ from aiohttp import ClientSession - -class Conversation(): +class Conversation: + """ + Represents a conversation with specific attributes. + """ def __init__(self, conversationId: str, clientId: str, conversationSignature: str) -> None: + """ + Initialize a new conversation instance. + + Args: + conversationId (str): Unique identifier for the conversation. + clientId (str): Client identifier. + conversationSignature (str): Signature for the conversation. + """ self.conversationId = conversationId self.clientId = clientId self.conversationSignature = conversationSignature async def create_conversation(session: ClientSession, proxy: str = None) -> Conversation: + """ + Create a new conversation asynchronously. + + Args: + session (ClientSession): An instance of aiohttp's ClientSession. + proxy (str, optional): Proxy URL. Defaults to None. + + Returns: + Conversation: An instance representing the created conversation. + """ url = 'https://www.bing.com/turing/conversation/create?bundleVersion=1.1199.4' async with session.get(url, proxy=proxy) as response: try: @@ -24,12 +44,32 @@ async def create_conversation(session: ClientSession, proxy: str = None) -> Conv return Conversation(conversationId, clientId, conversationSignature) async def list_conversations(session: ClientSession) -> list: + """ + List all conversations asynchronously. + + Args: + session (ClientSession): An instance of aiohttp's ClientSession. + + Returns: + list: A list of conversations. + """ url = "https://www.bing.com/turing/conversation/chats" async with session.get(url) as response: response = await response.json() return response["chats"] async def delete_conversation(session: ClientSession, conversation: Conversation, proxy: str = None) -> bool: + """ + Delete a conversation asynchronously. + + Args: + session (ClientSession): An instance of aiohttp's ClientSession. + conversation (Conversation): The conversation to delete. + proxy (str, optional): Proxy URL. Defaults to None. + + Returns: + bool: True if deletion was successful, False otherwise. + """ url = "https://sydney.bing.com/sydney/DeleteSingleConversation" json = { "conversationId": conversation.conversationId, diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py index a1ecace3..060cd184 100644 --- a/g4f/Provider/bing/create_images.py +++ b/g4f/Provider/bing/create_images.py @@ -1,9 +1,16 @@ +""" +This module provides functionalities for creating and managing images using Bing's service. +It includes functions for user login, session creation, image creation, and processing. +""" + import asyncio -import time, json, os +import time +import json +import os from aiohttp import ClientSession from bs4 import BeautifulSoup from urllib.parse import quote -from typing import Generator +from typing import Generator, List, Dict from ..create_images import CreateImagesProvider from ..helper import get_cookies, get_event_loop @@ -12,23 +19,47 @@ from ...base_provider import ProviderType from ...image import format_images_markdown BING_URL = "https://www.bing.com" +TIMEOUT_LOGIN = 1200 +TIMEOUT_IMAGE_CREATION = 300 +ERRORS = [ + "this prompt is being reviewed", + "this prompt has been blocked", + "we're working hard to offer image creator in more languages", + "we can't create your images right now" +] +BAD_IMAGES = [ + "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png", + "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg", +] + +def wait_for_login(driver: WebDriver, timeout: int = TIMEOUT_LOGIN) -> None: + """ + Waits for the user to log in within a given timeout period. -def wait_for_login(driver: WebDriver, timeout: int = 1200) -> None: + Args: + driver (WebDriver): Webdriver for browser automation. + timeout (int): Maximum waiting time in seconds. + + Raises: + RuntimeError: If the login process exceeds the timeout. + """ driver.get(f"{BING_URL}/") - value = driver.get_cookie("_U") - if value: - return start_time = time.time() - while True: + while not driver.get_cookie("_U"): if time.time() - start_time > timeout: raise RuntimeError("Timeout error") - value = driver.get_cookie("_U") - if value: - time.sleep(1) - return time.sleep(0.5) -def create_session(cookies: dict) -> ClientSession: +def create_session(cookies: Dict[str, str]) -> ClientSession: + """ + Creates a new client session with specified cookies and headers. + + Args: + cookies (Dict[str, str]): Cookies to be used for the session. + + Returns: + ClientSession: The created client session. + """ headers = { "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7", "accept-encoding": "gzip, deflate, br", @@ -47,28 +78,32 @@ def create_session(cookies: dict) -> ClientSession: "upgrade-insecure-requests": "1", } if cookies: - headers["cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items()) + headers["Cookie"] = "; ".join(f"{k}={v}" for k, v in cookies.items()) return ClientSession(headers=headers) -async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = 300) -> list: - url_encoded_prompt = quote(prompt) +async def create_images(session: ClientSession, prompt: str, proxy: str = None, timeout: int = TIMEOUT_IMAGE_CREATION) -> List[str]: + """ + Creates images based on a given prompt using Bing's service. + + Args: + session (ClientSession): Active client session. + prompt (str): Prompt to generate images. + proxy (str, optional): Proxy configuration. + timeout (int): Timeout for the request. + + Returns: + List[str]: A list of URLs to the created images. + + Raises: + RuntimeError: If image creation fails or times out. + """ + url_encoded_prompt = quote(prompt) payload = f"q={url_encoded_prompt}&rt=4&FORM=GENCRE" url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE" - async with session.post( - url, - allow_redirects=False, - data=payload, - timeout=timeout, - ) as response: + async with session.post(url, allow_redirects=False, data=payload, timeout=timeout) as response: response.raise_for_status() - errors = [ - "this prompt is being reviewed", - "this prompt has been blocked", - "we're working hard to offer image creator in more languages", - "we can't create your images right now" - ] text = (await response.text()).lower() - for error in errors: + for error in ERRORS: if error in text: raise RuntimeError(f"Create images failed: {error}") if response.status != 302: @@ -107,54 +142,109 @@ async def create_images(session: ClientSession, prompt: str, proxy: str = None, raise RuntimeError(error) return read_images(text) -def read_images(text: str) -> list: - html_soup = BeautifulSoup(text, "html.parser") - tags = html_soup.find_all("img") - image_links = [img["src"] for img in tags if "mimg" in img["class"]] - images = [link.split("?w=")[0] for link in image_links] - bad_images = [ - "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png", - "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg", - ] - if any(im in bad_images for im in images): +def read_images(html_content: str) -> List[str]: + """ + Extracts image URLs from the HTML content. + + Args: + html_content (str): HTML content containing image URLs. + + Returns: + List[str]: A list of image URLs. + """ + soup = BeautifulSoup(html_content, "html.parser") + tags = soup.find_all("img", class_="mimg") + images = [img["src"].split("?w=")[0] for img in tags] + if any(im in BAD_IMAGES for im in images): raise RuntimeError("Bad images found") if not images: raise RuntimeError("No images found") return images -async def create_images_markdown(cookies: dict, prompt: str, proxy: str = None) -> str: - session = create_session(cookies) - try: +async def create_images_markdown(cookies: Dict[str, str], prompt: str, proxy: str = None) -> str: + """ + Creates markdown formatted string with images based on the prompt. + + Args: + cookies (Dict[str, str]): Cookies to be used for the session. + prompt (str): Prompt to generate images. + proxy (str, optional): Proxy configuration. + + Returns: + str: Markdown formatted string with images. + """ + async with create_session(cookies) as session: images = await create_images(session, prompt, proxy) return format_images_markdown(images, prompt) - finally: - await session.close() -def get_cookies_from_browser(proxy: str = None) -> dict: - driver = get_browser(proxy=proxy) - try: +def get_cookies_from_browser(proxy: str = None) -> Dict[str, str]: + """ + Retrieves cookies from the browser using webdriver. + + Args: + proxy (str, optional): Proxy configuration. + + Returns: + Dict[str, str]: Retrieved cookies. + """ + with get_browser(proxy=proxy) as driver: wait_for_login(driver) + time.sleep(1) return get_driver_cookies(driver) - finally: - driver.quit() - -def create_completion(prompt: str, cookies: dict = None, proxy: str = None) -> Generator: - loop = get_event_loop() - if not cookies: - cookies = get_cookies(".bing.com") - if "_U" not in cookies: - login_url = os.environ.get("G4F_LOGIN_URL") - if login_url: - yield f"Please login: [Bing]({login_url})\n\n" - cookies = get_cookies_from_browser(proxy) - yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy)) - -async def create_async(prompt: str, cookies: dict = None, proxy: str = None) -> str: - if not cookies: - cookies = get_cookies(".bing.com") - if "_U" not in cookies: - cookies = get_cookies_from_browser(proxy) - return await create_images_markdown(cookies, prompt, proxy) + +class CreateImagesBing: + """A class for creating images using Bing.""" + + _cookies: Dict[str, str] = {} + + @classmethod + def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]: + """ + Generator for creating imagecompletion based on a prompt. + + Args: + prompt (str): Prompt to generate images. + cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically. + proxy (str, optional): Proxy configuration. + + Yields: + Generator[str, None, None]: The final output as markdown formatted string with images. + """ + loop = get_event_loop() + cookies = cookies or cls._cookies or get_cookies(".bing.com") + if "_U" not in cookies: + login_url = os.environ.get("G4F_LOGIN_URL") + if login_url: + yield f"Please login: [Bing]({login_url})\n\n" + cls._cookies = cookies = get_cookies_from_browser(proxy) + yield loop.run_until_complete(create_images_markdown(cookies, prompt, proxy)) + + @classmethod + async def create_async(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> str: + """ + Asynchronously creates a markdown formatted string with images based on the prompt. + + Args: + prompt (str): Prompt to generate images. + cookies (Dict[str, str], optional): Cookies for the session. If None, cookies are retrieved automatically. + proxy (str, optional): Proxy configuration. + + Returns: + str: Markdown formatted string with images. + """ + cookies = cookies or cls._cookies or get_cookies(".bing.com") + if "_U" not in cookies: + cls._cookies = cookies = get_cookies_from_browser(proxy) + return await create_images_markdown(cookies, prompt, proxy) def patch_provider(provider: ProviderType) -> CreateImagesProvider: - return CreateImagesProvider(provider, create_completion, create_async)
\ No newline at end of file + """ + Patches a provider to include image creation capabilities. + + Args: + provider (ProviderType): The provider to be patched. + + Returns: + CreateImagesProvider: The patched provider with image creation capabilities. + """ + return CreateImagesProvider(provider, CreateImagesBing.create_completion, CreateImagesBing.create_async)
\ No newline at end of file diff --git a/g4f/Provider/bing/upload_image.py b/g4f/Provider/bing/upload_image.py index 1af902ef..4d70659f 100644 --- a/g4f/Provider/bing/upload_image.py +++ b/g4f/Provider/bing/upload_image.py @@ -1,64 +1,107 @@ -from __future__ import annotations +""" +Module to handle image uploading and processing for Bing AI integrations. +""" +from __future__ import annotations import string import random import json import math -from ...typing import ImageType from aiohttp import ClientSession +from PIL import Image + +from ...typing import ImageType, Tuple from ...image import to_image, process_image, to_base64, ImageResponse -image_config = { +IMAGE_CONFIG = { "maxImagePixels": 360000, "imageCompressionRate": 0.7, - "enableFaceBlurDebug": 0, + "enableFaceBlurDebug": False, } async def upload_image( - session: ClientSession, - image: ImageType, - tone: str, + session: ClientSession, + image_data: ImageType, + tone: str, proxy: str = None ) -> ImageResponse: - image = to_image(image) - width, height = image.size - max_image_pixels = image_config['maxImagePixels'] - if max_image_pixels / (width * height) < 1: - new_width = int(width * math.sqrt(max_image_pixels / (width * height))) - new_height = int(height * math.sqrt(max_image_pixels / (width * height))) - else: - new_width = width - new_height = height - new_img = process_image(image, new_width, new_height) - new_img_binary_data = to_base64(new_img, image_config['imageCompressionRate']) - data, boundary = build_image_upload_api_payload(new_img_binary_data, tone) - headers = session.headers.copy() - headers["content-type"] = f'multipart/form-data; boundary={boundary}' - headers["referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx' - headers["origin"] = 'https://www.bing.com' + """ + Uploads an image to Bing's AI service and returns the image response. + + Args: + session (ClientSession): The active session. + image_data (bytes): The image data to be uploaded. + tone (str): The tone of the conversation. + proxy (str, optional): Proxy if any. Defaults to None. + + Raises: + RuntimeError: If the image upload fails. + + Returns: + ImageResponse: The response from the image upload. + """ + image = to_image(image_data) + new_width, new_height = calculate_new_dimensions(image) + processed_img = process_image(image, new_width, new_height) + img_binary_data = to_base64(processed_img, IMAGE_CONFIG['imageCompressionRate']) + + data, boundary = build_image_upload_payload(img_binary_data, tone) + headers = prepare_headers(session, boundary) + async with session.post("https://www.bing.com/images/kblob", data=data, headers=headers, proxy=proxy) as response: if response.status != 200: raise RuntimeError("Failed to upload image.") - image_info = await response.json() - if not image_info.get('blobId'): - raise RuntimeError("Failed to parse image info.") - result = {'bcid': image_info.get('blobId', "")} - result['blurredBcid'] = image_info.get('processedBlobId', "") - if result['blurredBcid'] != "": - result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['blurredBcid'] - elif result['bcid'] != "": - result["imageUrl"] = "https://www.bing.com/images/blob?bcid=" + result['bcid'] - result['originalImageUrl'] = ( - "https://www.bing.com/images/blob?bcid=" - + result['blurredBcid'] - if image_config["enableFaceBlurDebug"] - else "https://www.bing.com/images/blob?bcid=" - + result['bcid'] - ) - return ImageResponse(result["imageUrl"], "", result) - -def build_image_upload_api_payload(image_bin: str, tone: str): - payload = { + return parse_image_response(await response.json()) + +def calculate_new_dimensions(image: Image.Image) -> Tuple[int, int]: + """ + Calculates the new dimensions for the image based on the maximum allowed pixels. + + Args: + image (Image): The PIL Image object. + + Returns: + Tuple[int, int]: The new width and height for the image. + """ + width, height = image.size + max_image_pixels = IMAGE_CONFIG['maxImagePixels'] + if max_image_pixels / (width * height) < 1: + scale_factor = math.sqrt(max_image_pixels / (width * height)) + return int(width * scale_factor), int(height * scale_factor) + return width, height + +def build_image_upload_payload(image_bin: str, tone: str) -> Tuple[str, str]: + """ + Builds the payload for image uploading. + + Args: + image_bin (str): Base64 encoded image binary data. + tone (str): The tone of the conversation. + + Returns: + Tuple[str, str]: The data and boundary for the payload. + """ + boundary = "----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16)) + data = f"--{boundary}\r\n" \ + f"Content-Disposition: form-data; name=\"knowledgeRequest\"\r\n\r\n" \ + f"{json.dumps(build_knowledge_request(tone), ensure_ascii=False)}\r\n" \ + f"--{boundary}\r\n" \ + f"Content-Disposition: form-data; name=\"imageBase64\"\r\n\r\n" \ + f"{image_bin}\r\n" \ + f"--{boundary}--\r\n" + return data, boundary + +def build_knowledge_request(tone: str) -> dict: + """ + Builds the knowledge request payload. + + Args: + tone (str): The tone of the conversation. + + Returns: + dict: The knowledge request payload. + """ + return { 'invokedSkills': ["ImageById"], 'subscriptionId': "Bing.Chat.Multimodal", 'invokedSkillsRequestData': { @@ -69,21 +112,46 @@ def build_image_upload_api_payload(image_bin: str, tone: str): 'convotone': tone } } - knowledge_request = { - 'imageInfo': {}, - 'knowledgeRequest': payload - } - boundary="----WebKitFormBoundary" + ''.join(random.choices(string.ascii_letters + string.digits, k=16)) - data = ( - f'--{boundary}' - + '\r\nContent-Disposition: form-data; name="knowledgeRequest"\r\n\r\n' - + json.dumps(knowledge_request, ensure_ascii=False) - + "\r\n--" - + boundary - + '\r\nContent-Disposition: form-data; name="imageBase64"\r\n\r\n' - + image_bin - + "\r\n--" - + boundary - + "--\r\n" + +def prepare_headers(session: ClientSession, boundary: str) -> dict: + """ + Prepares the headers for the image upload request. + + Args: + session (ClientSession): The active session. + boundary (str): The boundary string for the multipart/form-data. + + Returns: + dict: The headers for the request. + """ + headers = session.headers.copy() + headers["Content-Type"] = f'multipart/form-data; boundary={boundary}' + headers["Referer"] = 'https://www.bing.com/search?q=Bing+AI&showconv=1&FORM=hpcodx' + headers["Origin"] = 'https://www.bing.com' + return headers + +def parse_image_response(response: dict) -> ImageResponse: + """ + Parses the response from the image upload. + + Args: + response (dict): The response dictionary. + + Raises: + RuntimeError: If parsing the image info fails. + + Returns: + ImageResponse: The parsed image response. + """ + if not response.get('blobId'): + raise RuntimeError("Failed to parse image info.") + + result = {'bcid': response.get('blobId', ""), 'blurredBcid': response.get('processedBlobId', "")} + result["imageUrl"] = f"https://www.bing.com/images/blob?bcid={result['blurredBcid'] or result['bcid']}" + + result['originalImageUrl'] = ( + f"https://www.bing.com/images/blob?bcid={result['blurredBcid']}" + if IMAGE_CONFIG["enableFaceBlurDebug"] else + f"https://www.bing.com/images/blob?bcid={result['bcid']}" ) - return data, boundary
\ No newline at end of file + return ImageResponse(result["imageUrl"], "", result)
\ No newline at end of file diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py index f8a0442d..b8bcbde3 100644 --- a/g4f/Provider/create_images.py +++ b/g4f/Provider/create_images.py @@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType system_message = """ You can generate custom images with the DALL-E 3 image generator. -To generate a image with a prompt, do this: +To generate an image with a prompt, do this: <img data-prompt=\"keywords for the image\"> Don't use images with data uri. It is important to use a prompt instead. <img data-prompt=\"image caption\"> """ class CreateImagesProvider(BaseProvider): + """ + Provider class for creating images based on text prompts. + + This provider handles image creation requests embedded within message content, + using provided image creation functions. + + Attributes: + provider (ProviderType): The underlying provider to handle non-image related tasks. + create_images (callable): A function to create images synchronously. + create_images_async (callable): A function to create images asynchronously. + system_message (str): A message that explains the image creation capability. + include_placeholder (bool): Flag to determine whether to include the image placeholder in the output. + __name__ (str): Name of the provider. + url (str): URL of the provider. + working (bool): Indicates if the provider is operational. + supports_stream (bool): Indicates if the provider supports streaming. + """ + def __init__( self, provider: ProviderType, @@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider): system_message: str = system_message, include_placeholder: bool = True ) -> None: + """ + Initializes the CreateImagesProvider. + + Args: + provider (ProviderType): The underlying provider. + create_images (callable): Function to create images synchronously. + create_async (callable): Function to create images asynchronously. + system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message. + include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True. + """ self.provider = provider self.create_images = create_images self.create_images_async = create_async @@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: + """ + Creates a completion result, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + stream (bool, optional): Indicates whether to stream the results. Defaults to False. + **kwargs: Additional keywordarguments for the provider. + + Yields: + CreateResult: Yields chunks of the processed messages, including image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the synchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) buffer = "" for chunk in self.provider.create_completion(model, messages, stream, **kwargs): @@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously creates a response, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + **kwargs: Additional keyword arguments for the provider. + + Returns: + str: The processed response string, including asynchronously generated image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the asynchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) response = await self.provider.create_async(model, messages, **kwargs) matches = re.findall(r'(<img data-prompt="(.*?)">)', response) diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index 81f417dd..fce1ee6f 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -1,36 +1,31 @@ from __future__ import annotations import asyncio -import webbrowser +import os import random -import string import secrets -import os -from os import path +import string from asyncio import AbstractEventLoop, BaseEventLoop from platformdirs import user_config_dir from browser_cookie3 import ( - chrome, - chromium, - opera, - opera_gx, - brave, - edge, - vivaldi, - firefox, - _LinuxPasswordManager + chrome, chromium, opera, opera_gx, + brave, edge, vivaldi, firefox, + _LinuxPasswordManager, BrowserCookieError ) - from ..typing import Dict, Messages from .. import debug -# Local Cookie Storage +# Global variable to store cookies _cookies: Dict[str, Dict[str, str]] = {} -# If loop closed or not set, create new event loop. -# If event loop is already running, handle nested event loops. -# If "nest_asyncio" is installed, patch the event loop. def get_event_loop() -> AbstractEventLoop: + """ + Get the current asyncio event loop. If the loop is closed or not set, create a new event loop. + If a loop is running, handle nested event loops. Patch the loop if 'nest_asyncio' is installed. + + Returns: + AbstractEventLoop: The current or new event loop. + """ try: loop = asyncio.get_event_loop() if isinstance(loop, BaseEventLoop): @@ -39,61 +34,50 @@ def get_event_loop() -> AbstractEventLoop: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - # Is running event loop asyncio.get_running_loop() if not hasattr(loop.__class__, "_nest_patched"): import nest_asyncio nest_asyncio.apply(loop) except RuntimeError: - # No running event loop pass except ImportError: raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' + 'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.' ) return loop -def init_cookies(): - urls = [ - 'https://chat-gpt.org', - 'https://www.aitianhu.com', - 'https://chatgptfree.ai', - 'https://gptchatly.com', - 'https://bard.google.com', - 'https://huggingface.co/chat', - 'https://open-assistant.io/chat' - ] - - browsers = ['google-chrome', 'chrome', 'firefox', 'safari'] - - def open_urls_in_browser(browser): - b = webbrowser.get(browser) - for url in urls: - b.open(url, new=0, autoraise=True) - - for browser in browsers: - try: - open_urls_in_browser(browser) - break - except webbrowser.Error: - continue - -# Check for broken dbus address in docker image if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null": _LinuxPasswordManager.get_password = lambda a, b: b"secret" - -# Load cookies for a domain from all supported browsers. -# Cache the results in the "_cookies" variable. -def get_cookies(domain_name=''): + +def get_cookies(domain_name: str = '') -> Dict[str, str]: + """ + Load cookies for a given domain from all supported browsers and cache the results. + + Args: + domain_name (str): The domain for which to load cookies. + + Returns: + Dict[str, str]: A dictionary of cookie names and values. + """ if domain_name in _cookies: return _cookies[domain_name] - def g4f(domain_name): - user_data_dir = user_config_dir("g4f") - cookie_file = path.join(user_data_dir, "Default", "Cookies") - return [] if not path.exists(cookie_file) else chrome(cookie_file, domain_name) + + cookies = _load_cookies_from_browsers(domain_name) + _cookies[domain_name] = cookies + return cookies + +def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]: + """ + Helper function to load cookies from various browsers. + + Args: + domain_name (str): The domain for which to load cookies. + Returns: + Dict[str, str]: A dictionary of cookie names and values. + """ cookies = {} - for cookie_fn in [g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]: + for cookie_fn in [_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox]: try: cookie_jar = cookie_fn(domain_name=domain_name) if len(cookie_jar) and debug.logging: @@ -101,13 +85,38 @@ def get_cookies(domain_name=''): for cookie in cookie_jar: if cookie.name not in cookies: cookies[cookie.name] = cookie.value - except: + except BrowserCookieError: pass - _cookies[domain_name] = cookies - return _cookies[domain_name] + except Exception as e: + if debug.logging: + print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}") + return cookies + +def _g4f(domain_name: str) -> list: + """ + Load cookies from the 'g4f' browser (if exists). + + Args: + domain_name (str): The domain for which to load cookies. + Returns: + list: List of cookies. + """ + user_data_dir = user_config_dir("g4f") + cookie_file = os.path.join(user_data_dir, "Default", "Cookies") + return [] if not os.path.exists(cookie_file) else chrome(cookie_file, domain_name) def format_prompt(messages: Messages, add_special_tokens=False) -> str: + """ + Format a series of messages into a single string, optionally adding special tokens. + + Args: + messages (Messages): A list of message dictionaries, each containing 'role' and 'content'. + add_special_tokens (bool): Whether to add special formatting tokens. + + Returns: + str: A formatted string containing all messages. + """ if not add_special_tokens and len(messages) <= 1: return messages[0]["content"] formatted = "\n".join([ @@ -116,12 +125,26 @@ def format_prompt(messages: Messages, add_special_tokens=False) -> str: ]) return f"{formatted}\nAssistant:" - def get_random_string(length: int = 10) -> str: + """ + Generate a random string of specified length, containing lowercase letters and digits. + + Args: + length (int, optional): Length of the random string to generate. Defaults to 10. + + Returns: + str: A random string of the specified length. + """ return ''.join( random.choice(string.ascii_lowercase + string.digits) for _ in range(length) ) def get_random_hex() -> str: + """ + Generate a random hexadecimal string of a fixed length. + + Returns: + str: A random hexadecimal string of 32 characters (16 bytes). + """ return secrets.token_hex(16).zfill(32)
\ No newline at end of file diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index a790f0de..7d352a46 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -1,6 +1,9 @@ from __future__ import annotations +import asyncio +import uuid +import json +import os -import uuid, json, asyncio, os from py_arkose_generator.arkose import get_values_for_request from async_property import async_cached_property from selenium.webdriver.common.by import By @@ -14,7 +17,8 @@ from ...typing import AsyncResult, Messages from ...requests import StreamSession from ...image import to_image, to_bytes, ImageType, ImageResponse -models = { +# Aliases for model names +MODELS = { "gpt-3.5": "text-davinci-002-render-sha", "gpt-3.5-turbo": "text-davinci-002-render-sha", "gpt-4": "gpt-4", @@ -22,13 +26,15 @@ models = { } class OpenaiChat(AsyncGeneratorProvider): - url = "https://chat.openai.com" - working = True - needs_auth = True + """A class for creating and managing conversations with OpenAI chat service""" + + url = "https://chat.openai.com" + working = True + needs_auth = True supports_gpt_35_turbo = True - supports_gpt_4 = True - _cookies: dict = {} - _default_model: str = None + supports_gpt_4 = True + _cookies: dict = {} + _default_model: str = None @classmethod async def create( @@ -43,6 +49,23 @@ class OpenaiChat(AsyncGeneratorProvider): image: ImageType = None, **kwargs ) -> Response: + """Create a new conversation or continue an existing one + + Args: + prompt: The user input to start or continue the conversation + model: The name of the model to use for generating responses + messages: The list of previous messages in the conversation + history_disabled: A flag indicating if the history and training should be disabled + action: The type of action to perform, either "next", "continue", or "variant" + conversation_id: The ID of the existing conversation, if any + parent_id: The ID of the parent message, if any + image: The image to include in the user input, if any + **kwargs: Additional keyword arguments to pass to the generator + + Returns: + A Response object that contains the generator, action, messages, and options + """ + # Add the user input to the messages list if prompt: messages.append({ "role": "user", @@ -67,20 +90,33 @@ class OpenaiChat(AsyncGeneratorProvider): ) @classmethod - async def upload_image( + async def _upload_image( cls, session: StreamSession, headers: dict, image: ImageType ) -> ImageResponse: + """Upload an image to the service and get the download URL + + Args: + session: The StreamSession object to use for requests + headers: The headers to include in the requests + image: The image to upload, either a PIL Image object or a bytes object + + Returns: + An ImageResponse object that contains the download URL, file name, and other data + """ + # Convert the image to a PIL Image object and get the extension image = to_image(image) extension = image.format.lower() + # Convert the image to a bytes object and get the size data_bytes = to_bytes(image) data = { "file_name": f"{image.width}x{image.height}.{extension}", "file_size": len(data_bytes), "use_case": "multimodal" } + # Post the image data to the service and get the image data async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response: response.raise_for_status() image_data = { @@ -91,6 +127,7 @@ class OpenaiChat(AsyncGeneratorProvider): "height": image.height, "width": image.width } + # Put the image bytes to the upload URL and check the status async with session.put( image_data["upload_url"], data=data_bytes, @@ -100,6 +137,7 @@ class OpenaiChat(AsyncGeneratorProvider): } ) as response: response.raise_for_status() + # Post the file ID to the service and get the download URL async with session.post( f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded", json={}, @@ -110,24 +148,45 @@ class OpenaiChat(AsyncGeneratorProvider): return ImageResponse(download_url, image_data["file_name"], image_data) @classmethod - async def get_default_model(cls, session: StreamSession, headers: dict): + async def _get_default_model(cls, session: StreamSession, headers: dict): + """Get the default model name from the service + + Args: + session: The StreamSession object to use for requests + headers: The headers to include in the requests + + Returns: + The default model name as a string + """ + # Check the cache for the default model if cls._default_model: - model = cls._default_model - else: - async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: - data = await response.json() - if "categories" in data: - model = data["categories"][-1]["default_model"] - else: - RuntimeError(f"Response: {data}") - cls._default_model = model - return model + return cls._default_model + # Get the models data from the service + async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response: + data = await response.json() + if "categories" in data: + cls._default_model = data["categories"][-1]["default_model"] + else: + raise RuntimeError(f"Response: {data}") + return cls._default_model @classmethod - def create_messages(cls, prompt: str, image_response: ImageResponse = None): + def _create_messages(cls, prompt: str, image_response: ImageResponse = None): + """Create a list of messages for the user input + + Args: + prompt: The user input as a string + image_response: The image response object, if any + + Returns: + A list of messages with the user input and the image, if any + """ + # Check if there is an image response if not image_response: + # Create a content object with the text type and the prompt content = {"content_type": "text", "parts": [prompt]} else: + # Create a content object with the multimodal text type and the image and the prompt content = { "content_type": "multimodal_text", "parts": [{ @@ -137,12 +196,15 @@ class OpenaiChat(AsyncGeneratorProvider): "width": image_response.get("width"), }, prompt] } + # Create a message object with the user role and the content messages = [{ "id": str(uuid.uuid4()), "author": {"role": "user"}, "content": content, }] + # Check if there is an image response if image_response: + # Add the metadata object with the attachments messages[0]["metadata"] = { "attachments": [{ "height": image_response.get("height"), @@ -156,19 +218,38 @@ class OpenaiChat(AsyncGeneratorProvider): return messages @classmethod - async def get_image_response(cls, session: StreamSession, headers: dict, line: dict): - if "parts" in line["message"]["content"]: - part = line["message"]["content"]["parts"][0] - if "asset_pointer" in part and part["metadata"]: - file_id = part["asset_pointer"].split("file-service://", 1)[1] - prompt = part["metadata"]["dalle"]["prompt"] - async with session.get( - f"{cls.url}/backend-api/files/{file_id}/download", - headers=headers - ) as response: - response.raise_for_status() - download_url = (await response.json())["download_url"] - return ImageResponse(download_url, prompt) + async def _get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse: + """ + Retrieves the image response based on the message content. + + :param session: The StreamSession object. + :param headers: HTTP headers for the request. + :param line: The line of response containing image information. + :return: An ImageResponse object with the image details. + """ + if "parts" not in line["message"]["content"]: + return + first_part = line["message"]["content"]["parts"][0] + if "asset_pointer" not in first_part or "metadata" not in first_part: + return + file_id = first_part["asset_pointer"].split("file-service://", 1)[1] + prompt = first_part["metadata"]["dalle"]["prompt"] + try: + async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response: + response.raise_for_status() + download_url = (await response.json())["download_url"] + return ImageResponse(download_url, prompt) + except Exception as e: + raise RuntimeError(f"Error in downloading image: {e}") + + @classmethod + async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str): + async with session.patch( + f"{cls.url}/backend-api/conversation/{conversation_id}", + json={"is_visible": False}, + headers=headers + ) as response: + response.raise_for_status() @classmethod async def create_async_generator( @@ -188,26 +269,47 @@ class OpenaiChat(AsyncGeneratorProvider): response_fields: bool = False, **kwargs ) -> AsyncResult: - if model in models: - model = models[model] + """ + Create an asynchronous generator for the conversation. + + Args: + model (str): The model name. + messages (Messages): The list of previous messages. + proxy (str): Proxy to use for requests. + timeout (int): Timeout for requests. + access_token (str): Access token for authentication. + cookies (dict): Cookies to use for authentication. + auto_continue (bool): Flag to automatically continue the conversation. + history_disabled (bool): Flag to disable history and training. + action (str): Type of action ('next', 'continue', 'variant'). + conversation_id (str): ID of the conversation. + parent_id (str): ID of the parent message. + image (ImageType): Image to include in the conversation. + response_fields (bool): Flag to include response fields in the output. + **kwargs: Additional keyword arguments. + + Yields: + AsyncResult: Asynchronous results from the generator. + + Raises: + RuntimeError: If an error occurs during processing. + """ + model = MODELS.get(model, model) if not parent_id: parent_id = str(uuid.uuid4()) if not cookies: - cookies = cls._cookies - if not access_token: - if not cookies: - cls._cookies = cookies = get_cookies("chat.openai.com") - if "access_token" in cookies: - access_token = cookies["access_token"] + cookies = cls._cookies or get_cookies("chat.openai.com") + if not access_token and "access_token" in cookies: + access_token = cookies["access_token"] if not access_token: login_url = os.environ.get("G4F_LOGIN_URL") if login_url: yield f"Please login: [ChatGPT]({login_url})\n\n" - access_token, cookies = cls.browse_access_token(proxy) + access_token, cookies = cls._browse_access_token(proxy) cls._cookies = cookies - headers = { - "Authorization": f"Bearer {access_token}", - } + + headers = {"Authorization": f"Bearer {access_token}"} + async with StreamSession( proxies={"https": proxy}, impersonate="chrome110", @@ -215,11 +317,11 @@ class OpenaiChat(AsyncGeneratorProvider): cookies=dict([(name, value) for name, value in cookies.items() if name == "_puid"]) ) as session: if not model: - model = await cls.get_default_model(session, headers) + model = await cls._get_default_model(session, headers) try: image_response = None if image: - image_response = await cls.upload_image(session, headers, image) + image_response = await cls._upload_image(session, headers, image) yield image_response except Exception as e: yield e @@ -227,7 +329,7 @@ class OpenaiChat(AsyncGeneratorProvider): while not end_turn.is_end: data = { "action": action, - "arkose_token": await cls.get_arkose_token(session), + "arkose_token": await cls._get_arkose_token(session), "conversation_id": conversation_id, "parent_message_id": parent_id, "model": model, @@ -235,7 +337,7 @@ class OpenaiChat(AsyncGeneratorProvider): } if action != "continue": prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"] - data["messages"] = cls.create_messages(prompt, image_response) + data["messages"] = cls._create_messages(prompt, image_response) async with session.post( f"{cls.url}/backend-api/conversation", json=data, @@ -261,62 +363,80 @@ class OpenaiChat(AsyncGeneratorProvider): if "message_type" not in line["message"]["metadata"]: continue try: - image_response = await cls.get_image_response(session, headers, line) + image_response = await cls._get_generated_image(session, headers, line) if image_response: yield image_response except Exception as e: yield e if line["message"]["author"]["role"] != "assistant": continue - if line["message"]["metadata"]["message_type"] in ("next", "continue", "variant"): - conversation_id = line["conversation_id"] - parent_id = line["message"]["id"] - if response_fields: - response_fields = False - yield ResponseFields(conversation_id, parent_id, end_turn) - if "parts" in line["message"]["content"]: - new_message = line["message"]["content"]["parts"][0] - if len(new_message) > last_message: - yield new_message[last_message:] - last_message = len(new_message) + if line["message"]["content"]["content_type"] != "text": + continue + if line["message"]["metadata"]["message_type"] not in ("next", "continue", "variant"): + continue + conversation_id = line["conversation_id"] + parent_id = line["message"]["id"] + if response_fields: + response_fields = False + yield ResponseFields(conversation_id, parent_id, end_turn) + if "parts" in line["message"]["content"]: + new_message = line["message"]["content"]["parts"][0] + if len(new_message) > last_message: + yield new_message[last_message:] + last_message = len(new_message) if "finish_details" in line["message"]["metadata"]: if line["message"]["metadata"]["finish_details"]["type"] == "stop": end_turn.end() - break except Exception as e: - yield e + raise e if not auto_continue: break action = "continue" await asyncio.sleep(5) - if history_disabled: - async with session.patch( - f"{cls.url}/backend-api/conversation/{conversation_id}", - json={"is_visible": False}, - headers=headers - ) as response: - response.raise_for_status() + if history_disabled and auto_continue: + await cls._delete_conversation(session, headers, conversation_id) @classmethod - def browse_access_token(cls, proxy: str = None) -> tuple[str, dict]: + def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]: + """ + Browse to obtain an access token. + + Args: + proxy (str): Proxy to use for browsing. + + Returns: + tuple[str, dict]: A tuple containing the access token and cookies. + """ driver = get_browser(proxy=proxy) try: driver.get(f"{cls.url}/") - WebDriverWait(driver, 1200).until( - EC.presence_of_element_located((By.ID, "prompt-textarea")) + WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea"))) + access_token = driver.execute_script( + "let session = await fetch('/api/auth/session');" + "let data = await session.json();" + "let accessToken = data['accessToken'];" + "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7);" + "document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';" + "return accessToken;" ) - javascript = """ -access_token = (await (await fetch('/api/auth/session')).json())['accessToken']; -expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 24 * 7); // One week -document.cookie = 'access_token=' + access_token + ';expires=' + expires.toUTCString() + ';path=/'; -return access_token; -""" - return driver.execute_script(javascript), get_driver_cookies(driver) + return access_token, get_driver_cookies(driver) finally: driver.quit() - @classmethod - async def get_arkose_token(cls, session: StreamSession) -> str: + @classmethod + async def _get_arkose_token(cls, session: StreamSession) -> str: + """ + Obtain an Arkose token for the session. + + Args: + session (StreamSession): The session object. + + Returns: + str: The Arkose token. + + Raises: + RuntimeError: If unable to retrieve the token. + """ config = { "pkey": "3D86FBBA-9D22-402A-B512-3420086BA6CC", "surl": "https://tcr9i.chat.openai.com", @@ -332,26 +452,30 @@ return access_token; if "token" in decoded_json: return decoded_json["token"] raise RuntimeError(f"Response: {decoded_json}") - -class EndTurn(): + +class EndTurn: + """ + Class to represent the end of a conversation turn. + """ def __init__(self): self.is_end = False def end(self): self.is_end = True -class ResponseFields(): - def __init__( - self, - conversation_id: str, - message_id: str, - end_turn: EndTurn - ): +class ResponseFields: + """ + Class to encapsulate response fields. + """ + def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn): self.conversation_id = conversation_id self.message_id = message_id self._end_turn = end_turn class Response(): + """ + Class to encapsulate a response from the chat service. + """ def __init__( self, generator: AsyncResult, @@ -360,13 +484,13 @@ class Response(): options: dict ): self._generator = generator - self.action: str = action - self.is_end: bool = False + self.action = action + self.is_end = False self._message = None self._messages = messages self._options = options self._fields = None - + async def generator(self): if self._generator: self._generator = None @@ -384,19 +508,16 @@ class Response(): def __aiter__(self): return self.generator() - + @async_cached_property async def message(self) -> str: - [_ async for _ in self.generator()] + await self.generator() return self._message - + async def get_fields(self): - [_ async for _ in self.generator()] - return { - "conversation_id": self._fields.conversation_id, - "parent_id": self._fields.message_id, - } - + await self.generator() + return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id} + async def next(self, prompt: str, **kwargs) -> Response: return await OpenaiChat.create( **self._options, @@ -406,7 +527,7 @@ class Response(): **await self.get_fields(), **kwargs ) - + async def do_continue(self, **kwargs) -> Response: fields = await self.get_fields() if self.is_end: @@ -418,7 +539,7 @@ class Response(): **fields, **kwargs ) - + async def variant(self, **kwargs) -> Response: if self.action != "next": raise RuntimeError("Can't create variant from continue or variant request.") @@ -429,11 +550,9 @@ class Response(): **await self.get_fields(), **kwargs ) - + @async_cached_property async def messages(self): messages = self._messages - messages.append({ - "role": "assistant", "content": await self.message - }) + messages.append({"role": "assistant", "content": await self.message}) return messages
\ No newline at end of file diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py index 4d3e77ac..9cc026fc 100644 --- a/g4f/Provider/retry_provider.py +++ b/g4f/Provider/retry_provider.py @@ -7,8 +7,17 @@ from ..base_provider import BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError - class RetryProvider(BaseRetryProvider): + """ + A provider class to handle retries for creating completions with different providers. + + Attributes: + providers (list): A list of provider instances. + shuffle (bool): A flag indicating whether to shuffle providers before use. + exceptions (dict): A dictionary to store exceptions encountered during retries. + last_provider (BaseProvider): The last provider that was used. + """ + def create_completion( self, model: str, @@ -16,10 +25,21 @@ class RetryProvider(BaseRetryProvider): stream: bool = False, **kwargs ) -> CreateResult: - if stream: - providers = [provider for provider in self.providers if provider.supports_stream] - else: - providers = self.providers + """ + Create a completion using available providers, with an option to stream the response. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False. + + Yields: + CreateResult: Tokens or results from the completion. + + Raises: + Exception: Any exception encountered during the completion process. + """ + providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers if self.shuffle: random.shuffle(providers) @@ -50,10 +70,23 @@ class RetryProvider(BaseRetryProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously create a completion using available providers. + + Args: + model (str): The model to be used for completion. + messages (Messages): The messages to be used for generating completion. + + Returns: + str: The result of the asynchronous completion. + + Raises: + Exception: Any exception encountered during the asynchronous completion process. + """ providers = self.providers if self.shuffle: random.shuffle(providers) - + self.exceptions = {} for provider in providers: self.last_provider = provider @@ -66,13 +99,20 @@ class RetryProvider(BaseRetryProvider): self.exceptions[provider.__name__] = e if debug.logging: print(f"{provider.__name__}: {e.__class__.__name__}: {e}") - + self.raise_exceptions() - + def raise_exceptions(self) -> None: + """ + Raise a combined exception if any occurred during retries. + + Raises: + RetryProviderError: If any provider encountered an exception. + RetryNoProviderError: If no provider is found. + """ if self.exceptions: raise RetryProviderError("RetryProvider failed:\n" + "\n".join([ f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items() ])) - + raise RetryNoProviderError("No provider found")
\ No newline at end of file diff --git a/g4f/__init__.py b/g4f/__init__.py index 68f9ccf6..2b0e5b46 100644 --- a/g4f/__init__.py +++ b/g4f/__init__.py @@ -15,6 +15,26 @@ def get_model_and_provider(model : Union[Model, str], ignored : list[str] = None, ignore_working: bool = False, ignore_stream: bool = False) -> tuple[str, ProviderType]: + """ + Retrieves the model and provider based on input parameters. + + Args: + model (Union[Model, str]): The model to use, either as an object or a string identifier. + provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None. + stream (bool): Indicates if the operation should be performed as a stream. + ignored (list[str], optional): List of provider names to be ignored. + ignore_working (bool, optional): If True, ignores the working status of the provider. + ignore_stream (bool, optional): If True, ignores the streaming capability of the provider. + + Returns: + tuple[str, ProviderType]: A tuple containing the model name and the provider type. + + Raises: + ProviderNotFoundError: If the provider is not found. + ModelNotFoundError: If the model is not found. + ProviderNotWorkingError: If the provider is not working. + StreamNotSupportedError: If streaming is not supported by the provider. + """ if debug.version_check: debug.version_check = False version.utils.check_version() @@ -70,7 +90,30 @@ class ChatCompletion: ignore_stream_and_auth: bool = False, patch_provider: callable = None, **kwargs) -> Union[CreateResult, str]: - + """ + Creates a chat completion using the specified model, provider, and messages. + + Args: + model (Union[Model, str]): The model to use, either as an object or a string identifier. + messages (Messages): The messages for which the completion is to be created. + provider (Union[ProviderType, str, None], optional): The provider to use, either as an object, a string identifier, or None. + stream (bool, optional): Indicates if the operation should be performed as a stream. + auth (Union[str, None], optional): Authentication token or credentials, if required. + ignored (list[str], optional): List of provider names to be ignored. + ignore_working (bool, optional): If True, ignores the working status of the provider. + ignore_stream_and_auth (bool, optional): If True, ignores the stream and authentication requirement checks. + patch_provider (callable, optional): Function to modify the provider. + **kwargs: Additional keyword arguments. + + Returns: + Union[CreateResult, str]: The result of the chat completion operation. + + Raises: + AuthenticationRequiredError: If authentication is required but not provided. + ProviderNotFoundError, ModelNotFoundError: If the specified provider or model is not found. + ProviderNotWorkingError: If the provider is not operational. + StreamNotSupportedError: If streaming is requested but not supported by the provider. + """ model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth) if not ignore_stream_and_auth and provider.needs_auth and not auth: @@ -98,7 +141,24 @@ class ChatCompletion: ignored : list[str] = None, patch_provider: callable = None, **kwargs) -> Union[AsyncResult, str]: - + """ + Asynchronously creates a completion using the specified model and provider. + + Args: + model (Union[Model, str]): The model to use, either as an object or a string identifier. + messages (Messages): Messages to be processed. + provider (Union[ProviderType, str, None]): The provider to use, either as an object, a string identifier, or None. + stream (bool): Indicates if the operation should be performed as a stream. + ignored (list[str], optional): List of provider names to be ignored. + patch_provider (callable, optional): Function to modify the provider. + **kwargs: Additional keyword arguments. + + Returns: + Union[AsyncResult, str]: The result of the asynchronous chat completion operation. + + Raises: + StreamNotSupportedError: If streaming is requested but not supported by the provider. + """ model, provider = get_model_and_provider(model, provider, False, ignored) if stream: @@ -118,7 +178,23 @@ class Completion: provider : Union[ProviderType, None] = None, stream : bool = False, ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]: - + """ + Creates a completion based on the provided model, prompt, and provider. + + Args: + model (Union[Model, str]): The model to use, either as an object or a string identifier. + prompt (str): The prompt text for which the completion is to be created. + provider (Union[ProviderType, None], optional): The provider to use, either as an object or None. + stream (bool, optional): Indicates if the operation should be performed as a stream. + ignored (list[str], optional): List of provider names to be ignored. + **kwargs: Additional keyword arguments. + + Returns: + Union[CreateResult, str]: The result of the completion operation. + + Raises: + ModelNotAllowedError: If the specified model is not allowed for use with this method. + """ allowed_models = [ 'code-davinci-002', 'text-ada-001', @@ -137,6 +213,15 @@ class Completion: return result if stream else ''.join(result) def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, str]]: + """ + Retrieves the last used provider. + + Args: + as_dict (bool, optional): If True, returns the provider information as a dictionary. + + Returns: + Union[ProviderType, dict[str, str]]: The last used provider, either as an object or a dictionary. + """ last = debug.last_provider if isinstance(last, BaseRetryProvider): last = last.last_provider diff --git a/g4f/base_provider.py b/g4f/base_provider.py index 1863f6bc..03ae64d6 100644 --- a/g4f/base_provider.py +++ b/g4f/base_provider.py @@ -1,7 +1,22 @@ from abc import ABC, abstractmethod -from .typing import Messages, CreateResult, Union - +from typing import Union, List, Dict, Type +from .typing import Messages, CreateResult + class BaseProvider(ABC): + """ + Abstract base class for a provider. + + Attributes: + url (str): URL of the provider. + working (bool): Indicates if the provider is currently working. + needs_auth (bool): Indicates if the provider needs authentication. + supports_stream (bool): Indicates if the provider supports streaming. + supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo. + supports_gpt_4 (bool): Indicates if the provider supports GPT-4. + supports_message_history (bool): Indicates if the provider supports message history. + params (str): List parameters for the provider. + """ + url: str = None working: bool = False needs_auth: bool = False @@ -20,6 +35,18 @@ class BaseProvider(ABC): stream: bool, **kwargs ) -> CreateResult: + """ + Create a completion with the given parameters. + + Args: + model (str): The model to use. + messages (Messages): The messages to process. + stream (bool): Whether to use streaming. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the creation process. + """ raise NotImplementedError() @classmethod @@ -30,25 +57,59 @@ class BaseProvider(ABC): messages: Messages, **kwargs ) -> str: + """ + Asynchronously create a completion with the given parameters. + + Args: + model (str): The model to use. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The result of the creation process. + """ raise NotImplementedError() @classmethod - def get_dict(cls): + def get_dict(cls) -> Dict[str, str]: + """ + Get a dictionary representation of the provider. + + Returns: + Dict[str, str]: A dictionary with provider's details. + """ return {'name': cls.__name__, 'url': cls.url} class BaseRetryProvider(BaseProvider): + """ + Base class for a provider that implements retry logic. + + Attributes: + providers (List[Type[BaseProvider]]): List of providers to use for retries. + shuffle (bool): Whether to shuffle the providers list. + exceptions (Dict[str, Exception]): Dictionary of exceptions encountered. + last_provider (Type[BaseProvider]): The last provider used. + """ + __name__: str = "RetryProvider" supports_stream: bool = True def __init__( self, - providers: list[type[BaseProvider]], + providers: List[Type[BaseProvider]], shuffle: bool = True ) -> None: - self.providers: list[type[BaseProvider]] = providers - self.shuffle: bool = shuffle - self.working: bool = True - self.exceptions: dict[str, Exception] = {} - self.last_provider: type[BaseProvider] = None + """ + Initialize the BaseRetryProvider. + + Args: + providers (List[Type[BaseProvider]]): List of providers to use. + shuffle (bool): Whether to shuffle the providers list. + """ + self.providers = providers + self.shuffle = shuffle + self.working = True + self.exceptions: Dict[str, Exception] = {} + self.last_provider: Type[BaseProvider] = None -ProviderType = Union[type[BaseProvider], BaseRetryProvider]
\ No newline at end of file +ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
\ No newline at end of file diff --git a/g4f/gui/client/css/style.css b/g4f/gui/client/css/style.css index 2d4c9857..e77410ab 100644 --- a/g4f/gui/client/css/style.css +++ b/g4f/gui/client/css/style.css @@ -404,7 +404,7 @@ body { display: none; } -#image { +#image, #file { display: none; } @@ -412,13 +412,22 @@ label[for="image"]:has(> input:valid){ color: var(--accent); } -label[for="image"] { +label[for="file"]:has(> input:valid){ + color: var(--accent); +} + +label[for="image"], label[for="file"] { cursor: pointer; position: absolute; top: 10px; left: 10px; } +label[for="file"] { + top: 32px; + left: 10px; +} + .buttons input[type="checkbox"] { height: 0; width: 0; diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html index 3f2bb0c0..95489ba4 100644 --- a/g4f/gui/client/html/index.html +++ b/g4f/gui/client/html/index.html @@ -118,6 +118,10 @@ <input type="file" id="image" name="image" accept="image/png, image/gif, image/jpeg" required/> <i class="fa-regular fa-image"></i> </label> + <label for="file"> + <input type="file" id="file" name="file" accept="text/plain, text/html, text/xml, application/json, text/javascript, .sh, .py, .php, .css, .yaml, .sql, .svg, .log, .csv, .twig, .md" required/> + <i class="fa-solid fa-paperclip"></i> + </label> <div id="send-button"> <i class="fa-solid fa-paper-plane-top"></i> </div> @@ -125,7 +129,14 @@ </div> <div class="buttons"> <div class="field"> - <select name="model" id="model"></select> + <select name="model" id="model"> + <option value="">Model: Default</option> + <option value="gpt-4">gpt-4</option> + <option value="gpt-3.5-turbo">gpt-3.5-turbo</option> + <option value="llama2-70b">llama2-70b</option> + <option value="gemini-pro">gemini-pro</option> + <option value="">----</option> + </select> </div> <div class="field"> <select name="jailbreak" id="jailbreak" style="display: none;"> @@ -138,7 +149,16 @@ <option value="gpt-evil-1.0">evil 1.0</option> </select> <div class="field"> - <select name="provider" id="provider"></select> + <select name="provider" id="provider"> + <option value="">Provider: Auto</option> + <option value="Bing">Bing</option> + <option value="OpenaiChat">OpenaiChat</option> + <option value="HuggingChat">HuggingChat</option> + <option value="Bard">Bard</option> + <option value="Liaobots">Liaobots</option> + <option value="Phind">Phind</option> + <option value="">----</option> + </select> </div> </div> <div class="field"> diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js index ccc9461b..8b9bc181 100644 --- a/g4f/gui/client/js/chat.v1.js +++ b/g4f/gui/client/js/chat.v1.js @@ -7,7 +7,9 @@ const spinner = box_conversations.querySelector(".spinner"); const stop_generating = document.querySelector(`.stop_generating`); const regenerate = document.querySelector(`.regenerate`); const send_button = document.querySelector(`#send-button`); -const imageInput = document.querySelector('#image') ; +const imageInput = document.querySelector('#image'); +const fileInput = document.querySelector('#file'); + let prompt_lock = false; hljs.addPlugin(new CopyButtonPlugin()); @@ -42,6 +44,11 @@ const handle_ask = async () => { if (message.length > 0) { message_input.value = ''; await add_conversation(window.conversation_id, message); + if ("text" in fileInput.dataset) { + message += '\n```' + fileInput.dataset.type + '\n'; + message += fileInput.dataset.text; + message += '\n```' + } await add_message(window.conversation_id, "user", message); window.token = message_id(); message_box.innerHTML += ` @@ -55,6 +62,9 @@ const handle_ask = async () => { </div> </div> `; + document.querySelectorAll('code:not(.hljs').forEach((el) => { + hljs.highlightElement(el); + }); await ask_gpt(); } }; @@ -171,17 +181,30 @@ const ask_gpt = async () => { content_inner.innerHTML += "<p>An error occured, please try again, if the problem persists, please use a other model or provider.</p>"; } else { html = markdown_render(text); - html = html.substring(0, html.lastIndexOf('</p>')) + '<span id="cursor"></span></p>'; + let lastElement, lastIndex = null; + for (element of ['</p>', '</code></pre>', '</li>\n</ol>']) { + const index = html.lastIndexOf(element) + if (index > lastIndex) { + lastElement = element; + lastIndex = index; + } + } + if (lastIndex) { + html = html.substring(0, lastIndex) + '<span id="cursor"></span>' + lastElement; + } content_inner.innerHTML = html; - document.querySelectorAll('code').forEach((el) => { + document.querySelectorAll('code:not(.hljs').forEach((el) => { hljs.highlightElement(el); }); } window.scrollTo(0, 0); - message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" }); + if (message_box.scrollTop >= message_box.scrollHeight - message_box.clientHeight - 100) { + message_box.scrollTo({ top: message_box.scrollHeight, behavior: "auto" }); + } } if (!error && imageInput) imageInput.value = ""; + if (!error && fileInput) fileInput.value = ""; } catch (e) { console.error(e); @@ -305,7 +328,7 @@ const load_conversation = async (conversation_id) => { `; } - document.querySelectorAll(`code`).forEach((el) => { + document.querySelectorAll('code:not(.hljs').forEach((el) => { hljs.highlightElement(el); }); @@ -400,7 +423,7 @@ const load_conversations = async (limit, offset, loader) => { `; } - document.querySelectorAll(`code`).forEach((el) => { + document.querySelectorAll('code:not(.hljs').forEach((el) => { hljs.highlightElement(el); }); }; @@ -602,14 +625,7 @@ observer.observe(message_input, { attributes: true }); (async () => { response = await fetch('/backend-api/v2/models') models = await response.json() - let select = document.getElementById('model'); - select.textContent = ''; - - let auto = document.createElement('option'); - auto.value = ''; - auto.text = 'Model: Default'; - select.appendChild(auto); for (model of models) { let option = document.createElement('option'); @@ -619,14 +635,7 @@ observer.observe(message_input, { attributes: true }); response = await fetch('/backend-api/v2/providers') providers = await response.json() - select = document.getElementById('provider'); - select.textContent = ''; - - auto = document.createElement('option'); - auto.value = ''; - auto.text = 'Provider: Auto'; - select.appendChild(auto); for (provider of providers) { let option = document.createElement('option'); @@ -643,11 +652,34 @@ observer.observe(message_input, { attributes: true }); document.title = 'g4f - gui - ' + versions["version"]; text = "version ~ " - if (versions["version"] != versions["lastet_version"]) { - release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"]; - text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>'; + if (versions["version"] != versions["latest_version"]) { + release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"]; + text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>'; } else { text += versions["version"]; } document.getElementById("version_text").innerHTML = text -})()
\ No newline at end of file +})() + +fileInput.addEventListener('change', async (event) => { + if (fileInput.files.length) { + type = fileInput.files[0].type; + if (type && type.indexOf('/')) { + type = type.split('/').pop().replace('x-', '') + type = type.replace('plain', 'plaintext') + .replace('shellscript', 'sh') + .replace('svg+xml', 'svg') + .replace('vnd.trolltech.linguist', 'ts') + } else { + type = fileInput.files[0].name.split('.').pop() + } + fileInput.dataset.type = type + const reader = new FileReader(); + reader.addEventListener('load', (event) => { + fileInput.dataset.text = event.target.result; + }); + reader.readAsText(fileInput.files[0]); + } else { + delete fileInput.dataset.text; + } +});
\ No newline at end of file diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index 9d12bea5..4a5cafa8 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -1,6 +1,7 @@ import logging import json from flask import request, Flask +from typing import Generator from g4f import debug, version, models from g4f import _all_models, get_last_provider, ChatCompletion from g4f.image import is_allowed_extension, to_image @@ -11,60 +12,123 @@ from .internet import get_search_message debug.logging = True class Backend_Api: + """ + Handles various endpoints in a Flask application for backend operations. + + This class provides methods to interact with models, providers, and to handle + various functionalities like conversations, error handling, and version management. + + Attributes: + app (Flask): A Flask application instance. + routes (dict): A dictionary mapping API endpoints to their respective handlers. + """ def __init__(self, app: Flask) -> None: + """ + Initialize the backend API with the given Flask application. + + Args: + app (Flask): Flask application instance to attach routes to. + """ self.app: Flask = app self.routes = { '/backend-api/v2/models': { - 'function': self.models, - 'methods' : ['GET'] + 'function': self.get_models, + 'methods': ['GET'] }, '/backend-api/v2/providers': { - 'function': self.providers, - 'methods' : ['GET'] + 'function': self.get_providers, + 'methods': ['GET'] }, '/backend-api/v2/version': { - 'function': self.version, - 'methods' : ['GET'] + 'function': self.get_version, + 'methods': ['GET'] }, '/backend-api/v2/conversation': { - 'function': self._conversation, + 'function': self.handle_conversation, 'methods': ['POST'] }, '/backend-api/v2/gen.set.summarize:title': { - 'function': self._gen_title, + 'function': self.generate_title, 'methods': ['POST'] }, '/backend-api/v2/error': { - 'function': self.error, + 'function': self.handle_error, 'methods': ['POST'] } } - def error(self): + def handle_error(self): + """ + Initialize the backend API with the given Flask application. + + Args: + app (Flask): Flask application instance to attach routes to. + """ print(request.json) - return 'ok', 200 - def models(self): + def get_models(self): + """ + Return a list of all models. + + Fetches and returns a list of all available models in the system. + + Returns: + List[str]: A list of model names. + """ return _all_models - def providers(self): - return [ - provider.__name__ for provider in __providers__ if provider.working - ] + def get_providers(self): + """ + Return a list of all working providers. + """ + return [provider.__name__ for provider in __providers__ if provider.working] - def version(self): + def get_version(self): + """ + Returns the current and latest version of the application. + + Returns: + dict: A dictionary containing the current and latest version. + """ return { "version": version.utils.current_version, - "lastet_version": version.get_latest_version(), + "latest_version": version.get_latest_version(), } - def _gen_title(self): - return { - 'title': '' - } + def generate_title(self): + """ + Generates and returns a title based on the request data. + + Returns: + dict: A dictionary with the generated title. + """ + return {'title': ''} - def _conversation(self): + def handle_conversation(self): + """ + Handles conversation requests and streams responses back. + + Returns: + Response: A Flask response object for streaming. + """ + kwargs = self._prepare_conversation_kwargs() + + return self.app.response_class( + self._create_response_stream(kwargs), + mimetype='text/event-stream' + ) + + def _prepare_conversation_kwargs(self): + """ + Prepares arguments for chat completion based on the request data. + + Reads the request and prepares the necessary arguments for handling + a chat completion request. + + Returns: + dict: Arguments prepared for chat completion. + """ kwargs = {} if 'image' in request.files: file = request.files['image'] @@ -87,47 +151,70 @@ class Backend_Api: messages[-1]["content"] = get_search_message(messages[-1]["content"]) model = json_data.get('model') model = model if model else models.default - provider = json_data.get('provider', '').replace('g4f.Provider.', '') - provider = provider if provider and provider != "Auto" else None patch = patch_provider if json_data.get('patch_provider') else None - def try_response(): - try: - first = True - for chunk in ChatCompletion.create( - model=model, - provider=provider, - messages=messages, - stream=True, - ignore_stream_and_auth=True, - patch_provider=patch, - **kwargs - ): - if first: - first = False - yield json.dumps({ - 'type' : 'provider', - 'provider': get_last_provider(True) - }) + "\n" - if isinstance(chunk, Exception): - logging.exception(chunk) - yield json.dumps({ - 'type' : 'message', - 'message': get_error_message(chunk), - }) + "\n" - else: - yield json.dumps({ - 'type' : 'content', - 'content': str(chunk), - }) + "\n" - except Exception as e: - logging.exception(e) - yield json.dumps({ - 'type' : 'error', - 'error': get_error_message(e) - }) - - return self.app.response_class(try_response(), mimetype='text/event-stream') + return { + "model": model, + "provider": provider, + "messages": messages, + "stream": True, + "ignore_stream_and_auth": True, + "patch_provider": patch, + **kwargs + } + + def _create_response_stream(self, kwargs) -> Generator[str, None, None]: + """ + Creates and returns a streaming response for the conversation. + + Args: + kwargs (dict): Arguments for creating the chat completion. + + Yields: + str: JSON formatted response chunks for the stream. + + Raises: + Exception: If an error occurs during the streaming process. + """ + try: + first = True + for chunk in ChatCompletion.create(**kwargs): + if first: + first = False + yield self._format_json('provider', get_last_provider(True)) + if isinstance(chunk, Exception): + logging.exception(chunk) + yield self._format_json('message', get_error_message(chunk)) + else: + yield self._format_json('content', str(chunk)) + except Exception as e: + logging.exception(e) + yield self._format_json('error', get_error_message(e)) + + def _format_json(self, response_type: str, content) -> str: + """ + Formats and returns a JSON response. + + Args: + response_type (str): The type of the response. + content: The content to be included in the response. + + Returns: + str: A JSON formatted string. + """ + return json.dumps({ + 'type': response_type, + response_type: content + }) + "\n" def get_error_message(exception: Exception) -> str: + """ + Generates a formatted error message from an exception. + + Args: + exception (Exception): The exception to format. + + Returns: + str: A formatted error message string. + """ return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}"
\ No newline at end of file diff --git a/g4f/image.py b/g4f/image.py index 01664f4e..cfa22ab1 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -4,9 +4,18 @@ import base64 from .typing import ImageType, Union from PIL import Image -ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} +ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'} def to_image(image: ImageType) -> Image.Image: + """ + Converts the input image to a PIL Image object. + + Args: + image (Union[str, bytes, Image.Image]): The input image. + + Returns: + Image.Image: The converted PIL Image object. + """ if isinstance(image, str): is_data_uri_an_image(image) image = extract_data_uri(image) @@ -20,21 +29,48 @@ def to_image(image: ImageType) -> Image.Image: image = copy return image -def is_allowed_extension(filename) -> bool: +def is_allowed_extension(filename: str) -> bool: + """ + Checks if the given filename has an allowed extension. + + Args: + filename (str): The filename to check. + + Returns: + bool: True if the extension is allowed, False otherwise. + """ return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS def is_data_uri_an_image(data_uri: str) -> bool: + """ + Checks if the given data URI represents an image. + + Args: + data_uri (str): The data URI to check. + + Raises: + ValueError: If the data URI is invalid or the image format is not allowed. + """ # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif) if not re.match(r'data:image/(\w+);base64,', data_uri): raise ValueError("Invalid data URI image.") - # Extract the image format from the data URI + # Extract the image format from the data URI image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1) # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif) if image_format.lower() not in ALLOWED_EXTENSIONS: raise ValueError("Invalid image format (from mime file type).") def is_accepted_format(binary_data: bytes) -> bool: + """ + Checks if the given binary data represents an image with an accepted format. + + Args: + binary_data (bytes): The binary data to check. + + Raises: + ValueError: If the image format is not allowed. + """ if binary_data.startswith(b'\xFF\xD8\xFF'): pass # It's a JPEG image elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'): @@ -49,13 +85,31 @@ def is_accepted_format(binary_data: bytes) -> bool: pass # It's a WebP image else: raise ValueError("Invalid image format (from magic code).") - + def extract_data_uri(data_uri: str) -> bytes: + """ + Extracts the binary data from the given data URI. + + Args: + data_uri (str): The data URI. + + Returns: + bytes: The extracted binary data. + """ data = data_uri.split(",")[1] data = base64.b64decode(data) return data def get_orientation(image: Image.Image) -> int: + """ + Gets the orientation of the given image. + + Args: + image (Image.Image): The image. + + Returns: + int: The orientation value. + """ exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif() if exif_data is not None: orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF @@ -63,6 +117,17 @@ def get_orientation(image: Image.Image) -> int: return orientation def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Image: + """ + Processes the given image by adjusting its orientation and resizing it. + + Args: + img (Image.Image): The image to process. + new_width (int): The new width of the image. + new_height (int): The new height of the image. + + Returns: + Image.Image: The processed image. + """ orientation = get_orientation(img) if orientation: if orientation > 4: @@ -75,13 +140,34 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im img = img.transpose(Image.ROTATE_90) img.thumbnail((new_width, new_height)) return img - + def to_base64(image: Image.Image, compression_rate: float) -> str: + """ + Converts the given image to a base64-encoded string. + + Args: + image (Image.Image): The image to convert. + compression_rate (float): The compression rate (0.0 to 1.0). + + Returns: + str: The base64-encoded image. + """ output_buffer = BytesIO() image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100)) return base64.b64encode(output_buffer.getvalue()).decode() def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=200") -> str: + """ + Formats the given images as a markdown string. + + Args: + images: The images to format. + prompt (str): The prompt for the images. + preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200". + + Returns: + str: The formatted markdown string. + """ if isinstance(images, list): images = [f"[![#{idx+1} {prompt}]({preview.replace('{image}', image)})]({image})" for idx, image in enumerate(images)] images = "\n".join(images) @@ -92,6 +178,15 @@ def format_images_markdown(images, prompt: str, preview: str="{image}?w=200&h=20 return f"\n{start_flag}{images}\n{end_flag}\n" def to_bytes(image: Image.Image) -> bytes: + """ + Converts the given image to bytes. + + Args: + image (Image.Image): The image to convert. + + Returns: + bytes: The image as bytes. + """ bytes_io = BytesIO() image.save(bytes_io, image.format) image.seek(0) diff --git a/g4f/models.py b/g4f/models.py index 03deebf8..dd6e0a2c 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -31,12 +31,21 @@ from .Provider import ( @dataclass(unsafe_hash=True) class Model: + """ + Represents a machine learning model configuration. + + Attributes: + name (str): Name of the model. + base_provider (str): Default provider for the model. + best_provider (ProviderType): The preferred provider for the model, typically with retry logic. + """ name: str base_provider: str best_provider: ProviderType = None @staticmethod def __all__() -> list[str]: + """Returns a list of all model names.""" return _all_models default = Model( @@ -298,6 +307,12 @@ pi = Model( ) class ModelUtils: + """ + Utility class for mapping string identifiers to Model instances. + + Attributes: + convert (dict[str, Model]): Dictionary mapping model string identifiers to Model instances. + """ convert: dict[str, Model] = { # gpt-3.5 'gpt-3.5-turbo' : gpt_35_turbo, diff --git a/g4f/requests.py b/g4f/requests.py index 1a13dec9..466d5a2a 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -from contextlib import asynccontextmanager from functools import partialmethod from typing import AsyncGenerator from urllib.parse import urlparse @@ -9,27 +8,41 @@ from curl_cffi.requests import AsyncSession, Session, Response from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_driver_cookies class StreamResponse: + """ + A wrapper class for handling asynchronous streaming responses. + + Attributes: + inner (Response): The original Response object. + """ + def __init__(self, inner: Response) -> None: + """Initialize the StreamResponse with the provided Response object.""" self.inner: Response = inner async def text(self) -> str: + """Asynchronously get the response text.""" return await self.inner.atext() def raise_for_status(self) -> None: + """Raise an HTTPError if one occurred.""" self.inner.raise_for_status() async def json(self, **kwargs) -> dict: + """Asynchronously parse the JSON response content.""" return json.loads(await self.inner.acontent(), **kwargs) async def iter_lines(self) -> AsyncGenerator[bytes, None]: + """Asynchronously iterate over the lines of the response.""" async for line in self.inner.aiter_lines(): yield line async def iter_content(self) -> AsyncGenerator[bytes, None]: + """Asynchronously iterate over the response content.""" async for chunk in self.inner.aiter_content(): yield chunk - + async def __aenter__(self): + """Asynchronously enter the runtime context for the response object.""" inner: Response = await self.inner self.inner = inner self.request = inner.request @@ -39,24 +52,47 @@ class StreamResponse: self.headers = inner.headers self.cookies = inner.cookies return self - + async def __aexit__(self, *args): + """Asynchronously exit the runtime context for the response object.""" await self.inner.aclose() + class StreamSession(AsyncSession): + """ + An asynchronous session class for handling HTTP requests with streaming. + + Inherits from AsyncSession. + """ + def request( self, method: str, url: str, **kwargs ) -> StreamResponse: + """Create and return a StreamResponse object for the given HTTP request.""" return StreamResponse(super().request(method, url, stream=True, **kwargs)) + # Defining HTTP methods as partial methods of the request method. head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") post = partialmethod(request, "POST") put = partialmethod(request, "PUT") patch = partialmethod(request, "PATCH") delete = partialmethod(request, "DELETE") - -def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120): + + +def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = None, timeout: int = 120) -> Session: + """ + Create a Session object using a WebDriver to handle cookies and headers. + + Args: + url (str): The URL to navigate to using the WebDriver. + webdriver (WebDriver, optional): The WebDriver instance to use. + proxy (str, optional): Proxy server to use for the Session. + timeout (int, optional): Timeout in seconds for the WebDriver. + + Returns: + Session: A Session object configured with cookies and headers from the WebDriver. + """ with WebDriverSession(webdriver, "", proxy=proxy, virtual_display=True) as driver: bypass_cloudflare(driver, url, timeout) cookies = get_driver_cookies(driver) @@ -78,4 +114,4 @@ def get_session_from_browser(url: str, webdriver: WebDriver = None, proxy: str = proxies={"https": proxy, "http": proxy}, timeout=timeout, impersonate="chrome110" - ) + )
\ No newline at end of file diff --git a/g4f/version.py b/g4f/version.py index bb4b7f17..9201c75c 100644 --- a/g4f/version.py +++ b/g4f/version.py @@ -5,45 +5,120 @@ from importlib.metadata import version as get_package_version, PackageNotFoundEr from subprocess import check_output, CalledProcessError, PIPE from .errors import VersionNotFoundError +def get_pypi_version(package_name: str) -> str: + """ + Retrieves the latest version of a package from PyPI. + + Args: + package_name (str): The name of the package for which to retrieve the version. + + Returns: + str: The latest version of the specified package from PyPI. + + Raises: + VersionNotFoundError: If there is an error in fetching the version from PyPI. + """ + try: + response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json() + return response["info"]["version"] + except requests.RequestException as e: + raise VersionNotFoundError(f"Failed to get PyPI version: {e}") + +def get_github_version(repo: str) -> str: + """ + Retrieves the latest release version from a GitHub repository. + + Args: + repo (str): The name of the GitHub repository. + + Returns: + str: The latest release version from the specified GitHub repository. + + Raises: + VersionNotFoundError: If there is an error in fetching the version from GitHub. + """ + try: + response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json() + return response["tag_name"] + except requests.RequestException as e: + raise VersionNotFoundError(f"Failed to get GitHub release version: {e}") + def get_latest_version() -> str: + """ + Retrieves the latest release version of the 'g4f' package from PyPI or GitHub. + + Returns: + str: The latest release version of 'g4f'. + + Note: + The function first tries to fetch the version from PyPI. If the package is not found, + it retrieves the version from the GitHub repository. + """ try: + # Is installed via package manager? get_package_version("g4f") - response = requests.get("https://pypi.org/pypi/g4f/json").json() - return response["info"]["version"] + return get_pypi_version("g4f") except PackageNotFoundError: - url = "https://api.github.com/repos/xtekky/gpt4free/releases/latest" - response = requests.get(url).json() - return response["tag_name"] + # Else use Github version: + return get_github_version("xtekky/gpt4free") -class VersionUtils(): +class VersionUtils: + """ + Utility class for managing and comparing package versions of 'g4f'. + """ @cached_property def current_version(self) -> str: + """ + Retrieves the current version of the 'g4f' package. + + Returns: + str: The current version of 'g4f'. + + Raises: + VersionNotFoundError: If the version cannot be determined from the package manager, + Docker environment, or git repository. + """ # Read from package manager try: return get_package_version("g4f") except PackageNotFoundError: pass + # Read from docker environment version = environ.get("G4F_VERSION") if version: return version + # Read from git repository try: command = ["git", "describe", "--tags", "--abbrev=0"] return check_output(command, text=True, stderr=PIPE).strip() except CalledProcessError: pass + raise VersionNotFoundError("Version not found") - + @cached_property def latest_version(self) -> str: + """ + Retrieves the latest version of the 'g4f' package. + + Returns: + str: The latest version of 'g4f'. + """ return get_latest_version() - + def check_version(self) -> None: + """ + Checks if the current version of 'g4f' is up to date with the latest version. + + Note: + If a newer version is available, it prints a message with the new version and update instructions. + """ try: if self.current_version != self.latest_version: print(f'New g4f version: {self.latest_version} (current: {self.current_version}) | pip install -U g4f') except Exception as e: print(f'Failed to check g4f version: {e}') - + utils = VersionUtils()
\ No newline at end of file diff --git a/g4f/webdriver.py b/g4f/webdriver.py index da283409..9a83215f 100644 --- a/g4f/webdriver.py +++ b/g4f/webdriver.py @@ -1,5 +1,4 @@ from __future__ import annotations - from platformdirs import user_config_dir from selenium.webdriver.remote.webdriver import WebDriver from undetected_chromedriver import Chrome, ChromeOptions @@ -21,7 +20,19 @@ def get_browser( proxy: str = None, options: ChromeOptions = None ) -> WebDriver: - if user_data_dir == None: + """ + Creates and returns a Chrome WebDriver with specified options. + + Args: + user_data_dir (str, optional): Directory for user data. If None, uses default directory. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. + proxy (str, optional): Proxy settings for the browser. Defaults to None. + options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None. + + Returns: + WebDriver: An instance of WebDriver configured with the specified options. + """ + if user_data_dir is None: user_data_dir = user_config_dir("g4f") if user_data_dir and debug.logging: print("Open browser with config dir:", user_data_dir) @@ -39,36 +50,53 @@ def get_browser( headless=headless ) -def get_driver_cookies(driver: WebDriver): - return dict([(cookie["name"], cookie["value"]) for cookie in driver.get_cookies()]) +def get_driver_cookies(driver: WebDriver) -> dict: + """ + Retrieves cookies from the specified WebDriver. + + Args: + driver (WebDriver): The WebDriver instance from which to retrieve cookies. + + Returns: + dict: A dictionary containing cookies with their names as keys and values as cookie values. + """ + return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()} def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None: - # Open website + """ + Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver. + + Args: + driver (WebDriver): The WebDriver to use for accessing the URL. + url (str): The URL to access. + timeout (int): Time in seconds to wait for the page to load. + + Raises: + Exception: If there is an error while bypassing Cloudflare or loading the page. + """ driver.get(url) - # Is cloudflare protection if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js": if debug.logging: print("Cloudflare protection detected:", url) try: - # Click button in iframe - WebDriverWait(driver, 5).until( - EC.presence_of_element_located((By.CSS_SELECTOR, "#turnstile-wrapper iframe")) - ) driver.switch_to.frame(driver.find_element(By.CSS_SELECTOR, "#turnstile-wrapper iframe")) WebDriverWait(driver, 5).until( EC.presence_of_element_located((By.CSS_SELECTOR, "#challenge-stage input")) - ) - driver.find_element(By.CSS_SELECTOR, "#challenge-stage input").click() - except: - pass + ).click() + except Exception as e: + if debug.logging: + print(f"Error bypassing Cloudflare: {e}") finally: driver.switch_to.default_content() - # No cloudflare protection WebDriverWait(driver, timeout).until( EC.presence_of_element_located((By.CSS_SELECTOR, "body:not(.no-js)")) ) -class WebDriverSession(): +class WebDriverSession: + """ + Manages a Selenium WebDriver session, including handling of virtual displays and proxies. + """ + def __init__( self, webdriver: WebDriver = None, @@ -78,12 +106,21 @@ class WebDriverSession(): proxy: str = None, options: ChromeOptions = None ): + """ + Initializes a new instance of the WebDriverSession. + + Args: + webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None. + user_data_dir (str, optional): Directory for user data. Defaults to None. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. + virtual_display (bool, optional): Whether to use a virtual display. Defaults to False. + proxy (str, optional): Proxy settings for the browser. Defaults to None. + options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None. + """ self.webdriver = webdriver self.user_data_dir = user_data_dir self.headless = headless - self.virtual_display = None - if has_pyvirtualdisplay and virtual_display: - self.virtual_display = Display(size=(1920, 1080)) + self.virtual_display = Display(size=(1920, 1080)) if has_pyvirtualdisplay and virtual_display else None self.proxy = proxy self.options = options self.default_driver = None @@ -94,8 +131,18 @@ class WebDriverSession(): headless: bool = False, virtual_display: bool = False ) -> WebDriver: - if user_data_dir == None: - user_data_dir = self.user_data_dir + """ + Reopens the WebDriver session with new settings. + + Args: + user_data_dir (str, optional): Directory for user data. Defaults to current value. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value. + virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value. + + Returns: + WebDriver: The reopened WebDriver instance. + """ + user_data_dir = user_data_data_dir or self.user_data_dir if self.default_driver: self.default_driver.quit() if not virtual_display and self.virtual_display: @@ -105,6 +152,12 @@ class WebDriverSession(): return self.default_driver def __enter__(self) -> WebDriver: + """ + Context management method for entering a session. Initializes and returns a WebDriver instance. + + Returns: + WebDriver: An instance of WebDriver for this session. + """ if self.webdriver: return self.webdriver if self.virtual_display: @@ -113,11 +166,23 @@ class WebDriverSession(): return self.default_driver def __exit__(self, exc_type, exc_val, exc_tb): + """ + Context management method for exiting a session. Closes and quits the WebDriver. + + Args: + exc_type: Exception type. + exc_val: Exception value. + exc_tb: Exception traceback. + + Note: + Closes the WebDriver and stops the virtual display if used. + """ if self.default_driver: try: self.default_driver.close() - except: - pass + except Exception as e: + if debug.logging: + print(f"Error closing WebDriver: {e}") self.default_driver.quit() if self.virtual_display: self.virtual_display.stop()
\ No newline at end of file |