From 32252def150da94f12d1f3c07f977af6d8931402 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Sun, 14 Jan 2024 15:04:37 +0100 Subject: Change doctypes style to Google Fix typo in latest_version Fix Phind Provider Add unittest worklow and main tests --- .github/workflows/unittest.yml | 19 ++++ etc/unittest/main.py | 73 +++++++++++++ g4f/Provider/Phind.py | 8 +- g4f/Provider/base_provider.py | 71 +++++++++++++ g4f/Provider/bing/create_images.py | 2 +- g4f/Provider/create_images.py | 61 ++++++++++- g4f/gui/client/js/chat.v1.js | 6 +- g4f/gui/server/backend.py | 211 ++++++++++++++++++++++++++----------- g4f/version.py | 56 +++++++--- g4f/webdriver.py | 75 +++++++++---- 10 files changed, 478 insertions(+), 104 deletions(-) create mode 100644 .github/workflows/unittest.yml create mode 100644 etc/unittest/main.py diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml new file mode 100644 index 00000000..e895e969 --- /dev/null +++ b/.github/workflows/unittest.yml @@ -0,0 +1,19 @@ +name: Unittest + +on: [push] + +jobs: + build: + name: Build unittest + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.x" + cache: 'pip' + - name: Install requirements + - run: pip install -r requirements.txt + - name: Run tests + run: python -m etc.unittest.main \ No newline at end of file diff --git a/etc/unittest/main.py b/etc/unittest/main.py new file mode 100644 index 00000000..61f4ffda --- /dev/null +++ b/etc/unittest/main.py @@ -0,0 +1,73 @@ +import sys +import pathlib +import unittest +from unittest.mock import MagicMock + +sys.path.append(str(pathlib.Path(__file__).parent.parent.parent)) + +import g4f +from g4f import ChatCompletion, get_last_provider +from g4f.gui.server.backend import Backend_Api, get_error_message +from g4f.base_provider import BaseProvider + +g4f.debug.logging = False + +class MockProvider(BaseProvider): + working = True + + def create_completion( + model, messages, stream, **kwargs + ): + yield "Mock" + + async def create_async( + model, messages, **kwargs + ): + return "Mock" + +class TestBackendApi(unittest.TestCase): + + def setUp(self): + self.app = MagicMock() + self.api = Backend_Api(self.app) + + def test_version(self): + response = self.api.get_version() + self.assertIn("version", response) + self.assertIn("latest_version", response) + +class TestChatCompletion(unittest.TestCase): + + def test_create(self): + messages = [{'role': 'user', 'content': 'Hello'}] + result = ChatCompletion.create(g4f.models.default, messages) + self.assertTrue("Hello" in result or "Good" in result) + + def test_get_last_provider(self): + messages = [{'role': 'user', 'content': 'Hello'}] + ChatCompletion.create(g4f.models.default, messages, MockProvider) + self.assertEqual(get_last_provider(), MockProvider) + + def test_bing_provider(self): + messages = [{'role': 'user', 'content': 'Hello'}] + provider = g4f.Provider.Bing + result = ChatCompletion.create(g4f.models.default, messages, provider) + self.assertTrue("Bing" in result) + +class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): + + async def test_async(self): + messages = [{'role': 'user', 'content': 'Hello'}] + result = await ChatCompletion.create_async(g4f.models.default, messages, MockProvider) + self.assertTrue("Mock" in result) + +class TestUtilityFunctions(unittest.TestCase): + + def test_get_error_message(self): + g4f.debug.last_provider = g4f.Provider.Bing + exception = Exception("Message") + result = get_error_message(exception) + self.assertEqual("Bing: Exception: Message", result) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py index bb216989..9e80baa9 100644 --- a/g4f/Provider/Phind.py +++ b/g4f/Provider/Phind.py @@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider): "rewrittenQuestion": prompt, "challenge": 0.21132115912208504 } - async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response: + async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response: new_line = False async for line in response.iter_lines(): if line.startswith(b"data: "): chunk = line[6:] - if chunk.startswith(b"") or chunk.startswith(b""): + if chunk.startswith(b''): + break + if chunk.startswith(b'') or chunk.startswith(b''): + pass + elif chunk.startswith(b"") or chunk.startswith(b""): pass elif chunk: yield chunk.decode() diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 3c083bda..fd92d17a 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -36,6 +36,17 @@ class AbstractProvider(BaseProvider): ) -> str: """ Asynchronously creates a result based on the given model and messages. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. """ loop = loop or get_event_loop() @@ -52,6 +63,12 @@ class AbstractProvider(BaseProvider): def params(cls) -> str: """ Returns the parameters supported by the provider. + + Args: + cls (type): The class on which this property is called. + + Returns: + str: A string listing the supported parameters. """ sig = signature( cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else @@ -90,6 +107,17 @@ class AsyncProvider(AbstractProvider): ) -> CreateResult: """ Creates a completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to False. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the completion creation. """ loop = loop or get_event_loop() coro = cls.create_async(model, messages, **kwargs) @@ -104,6 +132,17 @@ class AsyncProvider(AbstractProvider): ) -> str: """ Abstract method for creating asynchronous results. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + str: The created result as a string. """ raise NotImplementedError() @@ -126,6 +165,17 @@ class AsyncGeneratorProvider(AsyncProvider): ) -> CreateResult: """ Creates a streaming completion result synchronously. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + CreateResult: The result of the streaming completion creation. """ loop = loop or get_event_loop() generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) @@ -146,6 +196,15 @@ class AsyncGeneratorProvider(AsyncProvider): ) -> str: """ Asynchronously creates a result from a generator. + + Args: + cls (type): The class on which this method is called. + model (str): The model to use for creation. + messages (Messages): The messages to process. + **kwargs: Additional keyword arguments. + + Returns: + str: The created result as a string. """ return "".join([ chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) @@ -162,5 +221,17 @@ class AsyncGeneratorProvider(AsyncProvider): ) -> AsyncResult: """ Abstract method for creating an asynchronous generator. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process. + stream (bool): Indicates whether to stream the results. Defaults to True. + **kwargs: Additional keyword arguments. + + Raises: + NotImplementedError: If this method is not overridden in derived classes. + + Returns: + AsyncResult: An asynchronous generator yielding results. """ raise NotImplementedError() \ No newline at end of file diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py index 29daccbd..060cd184 100644 --- a/g4f/Provider/bing/create_images.py +++ b/g4f/Provider/bing/create_images.py @@ -198,7 +198,7 @@ class CreateImagesBing: _cookies: Dict[str, str] = {} @classmethod - def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]: + def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]: """ Generator for creating imagecompletion based on a prompt. diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py index f8a0442d..b8bcbde3 100644 --- a/g4f/Provider/create_images.py +++ b/g4f/Provider/create_images.py @@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType system_message = """ You can generate custom images with the DALL-E 3 image generator. -To generate a image with a prompt, do this: +To generate an image with a prompt, do this: Don't use images with data uri. It is important to use a prompt instead. """ class CreateImagesProvider(BaseProvider): + """ + Provider class for creating images based on text prompts. + + This provider handles image creation requests embedded within message content, + using provided image creation functions. + + Attributes: + provider (ProviderType): The underlying provider to handle non-image related tasks. + create_images (callable): A function to create images synchronously. + create_images_async (callable): A function to create images asynchronously. + system_message (str): A message that explains the image creation capability. + include_placeholder (bool): Flag to determine whether to include the image placeholder in the output. + __name__ (str): Name of the provider. + url (str): URL of the provider. + working (bool): Indicates if the provider is operational. + supports_stream (bool): Indicates if the provider supports streaming. + """ + def __init__( self, provider: ProviderType, @@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider): system_message: str = system_message, include_placeholder: bool = True ) -> None: + """ + Initializes the CreateImagesProvider. + + Args: + provider (ProviderType): The underlying provider. + create_images (callable): Function to create images synchronously. + create_async (callable): Function to create images asynchronously. + system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message. + include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True. + """ self.provider = provider self.create_images = create_images self.create_images_async = create_async @@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: + """ + Creates a completion result, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + stream (bool, optional): Indicates whether to stream the results. Defaults to False. + **kwargs: Additional keywordarguments for the provider. + + Yields: + CreateResult: Yields chunks of the processed messages, including image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the synchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) buffer = "" for chunk in self.provider.create_completion(model, messages, stream, **kwargs): @@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider): messages: Messages, **kwargs ) -> str: + """ + Asynchronously creates a response, processing any image creation prompts found within the messages. + + Args: + model (str): The model to use for creation. + messages (Messages): The messages to process, which may contain image prompts. + **kwargs: Additional keyword arguments for the provider. + + Returns: + str: The processed response string, including asynchronously generated image data if applicable. + + Note: + This method processes messages to detect image creation prompts. When such a prompt is found, + it calls the asynchronous image creation function and includes the resulting image in the output. + """ messages.insert(0, {"role": "system", "content": self.system_message}) response = await self.provider.create_async(model, messages, **kwargs) matches = re.findall(r'()', response) diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js index 7ed9f183..8b9bc181 100644 --- a/g4f/gui/client/js/chat.v1.js +++ b/g4f/gui/client/js/chat.v1.js @@ -652,9 +652,9 @@ observer.observe(message_input, { attributes: true }); document.title = 'g4f - gui - ' + versions["version"]; text = "version ~ " - if (versions["version"] != versions["lastet_version"]) { - release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"]; - text += '' + versions["version"] + ' 🆕'; + if (versions["version"] != versions["latest_version"]) { + release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"]; + text += '' + versions["version"] + ' 🆕'; } else { text += versions["version"]; } diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index 9d12bea5..4a5cafa8 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -1,6 +1,7 @@ import logging import json from flask import request, Flask +from typing import Generator from g4f import debug, version, models from g4f import _all_models, get_last_provider, ChatCompletion from g4f.image import is_allowed_extension, to_image @@ -11,60 +12,123 @@ from .internet import get_search_message debug.logging = True class Backend_Api: + """ + Handles various endpoints in a Flask application for backend operations. + + This class provides methods to interact with models, providers, and to handle + various functionalities like conversations, error handling, and version management. + + Attributes: + app (Flask): A Flask application instance. + routes (dict): A dictionary mapping API endpoints to their respective handlers. + """ def __init__(self, app: Flask) -> None: + """ + Initialize the backend API with the given Flask application. + + Args: + app (Flask): Flask application instance to attach routes to. + """ self.app: Flask = app self.routes = { '/backend-api/v2/models': { - 'function': self.models, - 'methods' : ['GET'] + 'function': self.get_models, + 'methods': ['GET'] }, '/backend-api/v2/providers': { - 'function': self.providers, - 'methods' : ['GET'] + 'function': self.get_providers, + 'methods': ['GET'] }, '/backend-api/v2/version': { - 'function': self.version, - 'methods' : ['GET'] + 'function': self.get_version, + 'methods': ['GET'] }, '/backend-api/v2/conversation': { - 'function': self._conversation, + 'function': self.handle_conversation, 'methods': ['POST'] }, '/backend-api/v2/gen.set.summarize:title': { - 'function': self._gen_title, + 'function': self.generate_title, 'methods': ['POST'] }, '/backend-api/v2/error': { - 'function': self.error, + 'function': self.handle_error, 'methods': ['POST'] } } - def error(self): + def handle_error(self): + """ + Initialize the backend API with the given Flask application. + + Args: + app (Flask): Flask application instance to attach routes to. + """ print(request.json) - return 'ok', 200 - def models(self): + def get_models(self): + """ + Return a list of all models. + + Fetches and returns a list of all available models in the system. + + Returns: + List[str]: A list of model names. + """ return _all_models - def providers(self): - return [ - provider.__name__ for provider in __providers__ if provider.working - ] + def get_providers(self): + """ + Return a list of all working providers. + """ + return [provider.__name__ for provider in __providers__ if provider.working] - def version(self): + def get_version(self): + """ + Returns the current and latest version of the application. + + Returns: + dict: A dictionary containing the current and latest version. + """ return { "version": version.utils.current_version, - "lastet_version": version.get_latest_version(), + "latest_version": version.get_latest_version(), } - def _gen_title(self): - return { - 'title': '' - } + def generate_title(self): + """ + Generates and returns a title based on the request data. + + Returns: + dict: A dictionary with the generated title. + """ + return {'title': ''} - def _conversation(self): + def handle_conversation(self): + """ + Handles conversation requests and streams responses back. + + Returns: + Response: A Flask response object for streaming. + """ + kwargs = self._prepare_conversation_kwargs() + + return self.app.response_class( + self._create_response_stream(kwargs), + mimetype='text/event-stream' + ) + + def _prepare_conversation_kwargs(self): + """ + Prepares arguments for chat completion based on the request data. + + Reads the request and prepares the necessary arguments for handling + a chat completion request. + + Returns: + dict: Arguments prepared for chat completion. + """ kwargs = {} if 'image' in request.files: file = request.files['image'] @@ -87,47 +151,70 @@ class Backend_Api: messages[-1]["content"] = get_search_message(messages[-1]["content"]) model = json_data.get('model') model = model if model else models.default - provider = json_data.get('provider', '').replace('g4f.Provider.', '') - provider = provider if provider and provider != "Auto" else None patch = patch_provider if json_data.get('patch_provider') else None - def try_response(): - try: - first = True - for chunk in ChatCompletion.create( - model=model, - provider=provider, - messages=messages, - stream=True, - ignore_stream_and_auth=True, - patch_provider=patch, - **kwargs - ): - if first: - first = False - yield json.dumps({ - 'type' : 'provider', - 'provider': get_last_provider(True) - }) + "\n" - if isinstance(chunk, Exception): - logging.exception(chunk) - yield json.dumps({ - 'type' : 'message', - 'message': get_error_message(chunk), - }) + "\n" - else: - yield json.dumps({ - 'type' : 'content', - 'content': str(chunk), - }) + "\n" - except Exception as e: - logging.exception(e) - yield json.dumps({ - 'type' : 'error', - 'error': get_error_message(e) - }) - - return self.app.response_class(try_response(), mimetype='text/event-stream') + return { + "model": model, + "provider": provider, + "messages": messages, + "stream": True, + "ignore_stream_and_auth": True, + "patch_provider": patch, + **kwargs + } + + def _create_response_stream(self, kwargs) -> Generator[str, None, None]: + """ + Creates and returns a streaming response for the conversation. + + Args: + kwargs (dict): Arguments for creating the chat completion. + + Yields: + str: JSON formatted response chunks for the stream. + + Raises: + Exception: If an error occurs during the streaming process. + """ + try: + first = True + for chunk in ChatCompletion.create(**kwargs): + if first: + first = False + yield self._format_json('provider', get_last_provider(True)) + if isinstance(chunk, Exception): + logging.exception(chunk) + yield self._format_json('message', get_error_message(chunk)) + else: + yield self._format_json('content', str(chunk)) + except Exception as e: + logging.exception(e) + yield self._format_json('error', get_error_message(e)) + + def _format_json(self, response_type: str, content) -> str: + """ + Formats and returns a JSON response. + + Args: + response_type (str): The type of the response. + content: The content to be included in the response. + + Returns: + str: A JSON formatted string. + """ + return json.dumps({ + 'type': response_type, + response_type: content + }) + "\n" def get_error_message(exception: Exception) -> str: + """ + Generates a formatted error message from an exception. + + Args: + exception (Exception): The exception to format. + + Returns: + str: A formatted error message string. + """ return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}" \ No newline at end of file diff --git a/g4f/version.py b/g4f/version.py index c976c8fd..9201c75c 100644 --- a/g4f/version.py +++ b/g4f/version.py @@ -7,10 +7,16 @@ from .errors import VersionNotFoundError def get_pypi_version(package_name: str) -> str: """ - Get the latest version of a package from PyPI. + Retrieves the latest version of a package from PyPI. - :param package_name: The name of the package. - :return: The latest version of the package as a string. + Args: + package_name (str): The name of the package for which to retrieve the version. + + Returns: + str: The latest version of the specified package from PyPI. + + Raises: + VersionNotFoundError: If there is an error in fetching the version from PyPI. """ try: response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json() @@ -20,10 +26,16 @@ def get_pypi_version(package_name: str) -> str: def get_github_version(repo: str) -> str: """ - Get the latest release version from a GitHub repository. + Retrieves the latest release version from a GitHub repository. + + Args: + repo (str): The name of the GitHub repository. + + Returns: + str: The latest release version from the specified GitHub repository. - :param repo: The name of the GitHub repository. - :return: The latest release version as a string. + Raises: + VersionNotFoundError: If there is an error in fetching the version from GitHub. """ try: response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json() @@ -31,11 +43,16 @@ def get_github_version(repo: str) -> str: except requests.RequestException as e: raise VersionNotFoundError(f"Failed to get GitHub release version: {e}") -def get_latest_version(): +def get_latest_version() -> str: """ - Get the latest release version from PyPI or the GitHub repository. + Retrieves the latest release version of the 'g4f' package from PyPI or GitHub. - :return: The latest release version as a string. + Returns: + str: The latest release version of 'g4f'. + + Note: + The function first tries to fetch the version from PyPI. If the package is not found, + it retrieves the version from the GitHub repository. """ try: # Is installed via package manager? @@ -47,14 +64,19 @@ def get_latest_version(): class VersionUtils: """ - Utility class for managing and comparing package versions. + Utility class for managing and comparing package versions of 'g4f'. """ @cached_property def current_version(self) -> str: """ - Get the current version of the g4f package. + Retrieves the current version of the 'g4f' package. + + Returns: + str: The current version of 'g4f'. - :return: The current version as a string. + Raises: + VersionNotFoundError: If the version cannot be determined from the package manager, + Docker environment, or git repository. """ # Read from package manager try: @@ -79,15 +101,19 @@ class VersionUtils: @cached_property def latest_version(self) -> str: """ - Get the latest version of the g4f package. + Retrieves the latest version of the 'g4f' package. - :return: The latest version as a string. + Returns: + str: The latest version of 'g4f'. """ return get_latest_version() def check_version(self) -> None: """ - Check if the current version is up to date with the latest version. + Checks if the current version of 'g4f' is up to date with the latest version. + + Note: + If a newer version is available, it prints a message with the new version and update instructions. """ try: if self.current_version != self.latest_version: diff --git a/g4f/webdriver.py b/g4f/webdriver.py index e5ecd8bf..9a83215f 100644 --- a/g4f/webdriver.py +++ b/g4f/webdriver.py @@ -21,13 +21,16 @@ def get_browser( options: ChromeOptions = None ) -> WebDriver: """ - Creates and returns a Chrome WebDriver with the specified options. + Creates and returns a Chrome WebDriver with specified options. - :param user_data_dir: Directory for user data. If None, uses default directory. - :param headless: Boolean indicating whether to run the browser in headless mode. - :param proxy: Proxy settings for the browser. - :param options: ChromeOptions object with specific browser options. - :return: An instance of WebDriver. + Args: + user_data_dir (str, optional): Directory for user data. If None, uses default directory. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. + proxy (str, optional): Proxy settings for the browser. Defaults to None. + options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None. + + Returns: + WebDriver: An instance of WebDriver configured with the specified options. """ if user_data_dir is None: user_data_dir = user_config_dir("g4f") @@ -49,10 +52,13 @@ def get_browser( def get_driver_cookies(driver: WebDriver) -> dict: """ - Retrieves cookies from the given WebDriver. + Retrieves cookies from the specified WebDriver. + + Args: + driver (WebDriver): The WebDriver instance from which to retrieve cookies. - :param driver: WebDriver from which to retrieve cookies. - :return: A dictionary of cookies. + Returns: + dict: A dictionary containing cookies with their names as keys and values as cookie values. """ return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()} @@ -60,9 +66,13 @@ def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None: """ Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver. - :param driver: The WebDriver to use. - :param url: URL to access. - :param timeout: Time in seconds to wait for the page to load. + Args: + driver (WebDriver): The WebDriver to use for accessing the URL. + url (str): The URL to access. + timeout (int): Time in seconds to wait for the page to load. + + Raises: + Exception: If there is an error while bypassing Cloudflare or loading the page. """ driver.get(url) if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js": @@ -86,6 +96,7 @@ class WebDriverSession: """ Manages a Selenium WebDriver session, including handling of virtual displays and proxies. """ + def __init__( self, webdriver: WebDriver = None, @@ -95,6 +106,17 @@ class WebDriverSession: proxy: str = None, options: ChromeOptions = None ): + """ + Initializes a new instance of the WebDriverSession. + + Args: + webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None. + user_data_dir (str, optional): Directory for user data. Defaults to None. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to False. + virtual_display (bool, optional): Whether to use a virtual display. Defaults to False. + proxy (str, optional): Proxy settings for the browser. Defaults to None. + options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None. + """ self.webdriver = webdriver self.user_data_dir = user_data_dir self.headless = headless @@ -110,14 +132,17 @@ class WebDriverSession: virtual_display: bool = False ) -> WebDriver: """ - Reopens the WebDriver session with the specified parameters. + Reopens the WebDriver session with new settings. + + Args: + user_data_dir (str, optional): Directory for user data. Defaults to current value. + headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value. + virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value. - :param user_data_dir: Directory for user data. - :param headless: Boolean indicating whether to run the browser in headless mode. - :param virtual_display: Boolean indicating whether to use a virtual display. - :return: An instance of WebDriver. + Returns: + WebDriver: The reopened WebDriver instance. """ - user_data_dir = user_data_dir or self.user_data_dir + user_data_dir = user_data_data_dir or self.user_data_dir if self.default_driver: self.default_driver.quit() if not virtual_display and self.virtual_display: @@ -128,8 +153,10 @@ class WebDriverSession: def __enter__(self) -> WebDriver: """ - Context management method for entering a session. - :return: An instance of WebDriver. + Context management method for entering a session. Initializes and returns a WebDriver instance. + + Returns: + WebDriver: An instance of WebDriver for this session. """ if self.webdriver: return self.webdriver @@ -141,6 +168,14 @@ class WebDriverSession: def __exit__(self, exc_type, exc_val, exc_tb): """ Context management method for exiting a session. Closes and quits the WebDriver. + + Args: + exc_type: Exception type. + exc_val: Exception value. + exc_tb: Exception traceback. + + Note: + Closes the WebDriver and stops the virtual display if used. """ if self.default_driver: try: -- cgit v1.2.3