summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/Phind.py8
-rw-r--r--g4f/Provider/base_provider.py71
-rw-r--r--g4f/Provider/bing/create_images.py2
-rw-r--r--g4f/Provider/create_images.py61
-rw-r--r--g4f/gui/client/js/chat.v1.js6
-rw-r--r--g4f/gui/server/backend.py211
-rw-r--r--g4f/version.py56
-rw-r--r--g4f/webdriver.py75
8 files changed, 386 insertions, 104 deletions
diff --git a/g4f/Provider/Phind.py b/g4f/Provider/Phind.py
index bb216989..9e80baa9 100644
--- a/g4f/Provider/Phind.py
+++ b/g4f/Provider/Phind.py
@@ -59,12 +59,16 @@ class Phind(AsyncGeneratorProvider):
"rewrittenQuestion": prompt,
"challenge": 0.21132115912208504
}
- async with session.post(f"{cls.url}/api/infer/followup/answer", headers=headers, json=data) as response:
+ async with session.post(f"https://https.api.phind.com/infer/", headers=headers, json=data) as response:
new_line = False
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
- if chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
+ if chunk.startswith(b'<PHIND_DONE/>'):
+ break
+ if chunk.startswith(b'<PHIND_WEBRESULTS>') or chunk.startswith(b'<PHIND_FOLLOWUP>'):
+ pass
+ elif chunk.startswith(b"<PHIND_METADATA>") or chunk.startswith(b"<PHIND_INDICATOR>"):
pass
elif chunk:
yield chunk.decode()
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 3c083bda..fd92d17a 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -36,6 +36,17 @@ class AbstractProvider(BaseProvider):
) -> str:
"""
Asynchronously creates a result based on the given model and messages.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
"""
loop = loop or get_event_loop()
@@ -52,6 +63,12 @@ class AbstractProvider(BaseProvider):
def params(cls) -> str:
"""
Returns the parameters supported by the provider.
+
+ Args:
+ cls (type): The class on which this property is called.
+
+ Returns:
+ str: A string listing the supported parameters.
"""
sig = signature(
cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
@@ -90,6 +107,17 @@ class AsyncProvider(AbstractProvider):
) -> CreateResult:
"""
Creates a completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to False.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the completion creation.
"""
loop = loop or get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
@@ -104,6 +132,17 @@ class AsyncProvider(AbstractProvider):
) -> str:
"""
Abstract method for creating asynchronous results.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ str: The created result as a string.
"""
raise NotImplementedError()
@@ -126,6 +165,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> CreateResult:
"""
Creates a streaming completion result synchronously.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ CreateResult: The result of the streaming completion creation.
"""
loop = loop or get_event_loop()
generator = cls.create_async_generator(model, messages, stream=stream, **kwargs)
@@ -146,6 +196,15 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> str:
"""
Asynchronously creates a result from a generator.
+
+ Args:
+ cls (type): The class on which this method is called.
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ str: The created result as a string.
"""
return "".join([
chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
@@ -162,5 +221,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncResult:
"""
Abstract method for creating an asynchronous generator.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process.
+ stream (bool): Indicates whether to stream the results. Defaults to True.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: If this method is not overridden in derived classes.
+
+ Returns:
+ AsyncResult: An asynchronous generator yielding results.
"""
raise NotImplementedError() \ No newline at end of file
diff --git a/g4f/Provider/bing/create_images.py b/g4f/Provider/bing/create_images.py
index 29daccbd..060cd184 100644
--- a/g4f/Provider/bing/create_images.py
+++ b/g4f/Provider/bing/create_images.py
@@ -198,7 +198,7 @@ class CreateImagesBing:
_cookies: Dict[str, str] = {}
@classmethod
- def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str]:
+ def create_completion(cls, prompt: str, cookies: Dict[str, str] = None, proxy: str = None) -> Generator[str, None, None]:
"""
Generator for creating imagecompletion based on a prompt.
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py
index f8a0442d..b8bcbde3 100644
--- a/g4f/Provider/create_images.py
+++ b/g4f/Provider/create_images.py
@@ -8,13 +8,31 @@ from ..base_provider import BaseProvider, ProviderType
system_message = """
You can generate custom images with the DALL-E 3 image generator.
-To generate a image with a prompt, do this:
+To generate an image with a prompt, do this:
<img data-prompt=\"keywords for the image\">
Don't use images with data uri. It is important to use a prompt instead.
<img data-prompt=\"image caption\">
"""
class CreateImagesProvider(BaseProvider):
+ """
+ Provider class for creating images based on text prompts.
+
+ This provider handles image creation requests embedded within message content,
+ using provided image creation functions.
+
+ Attributes:
+ provider (ProviderType): The underlying provider to handle non-image related tasks.
+ create_images (callable): A function to create images synchronously.
+ create_images_async (callable): A function to create images asynchronously.
+ system_message (str): A message that explains the image creation capability.
+ include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
+ __name__ (str): Name of the provider.
+ url (str): URL of the provider.
+ working (bool): Indicates if the provider is operational.
+ supports_stream (bool): Indicates if the provider supports streaming.
+ """
+
def __init__(
self,
provider: ProviderType,
@@ -23,6 +41,16 @@ class CreateImagesProvider(BaseProvider):
system_message: str = system_message,
include_placeholder: bool = True
) -> None:
+ """
+ Initializes the CreateImagesProvider.
+
+ Args:
+ provider (ProviderType): The underlying provider.
+ create_images (callable): Function to create images synchronously.
+ create_async (callable): Function to create images asynchronously.
+ system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
+ include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
+ """
self.provider = provider
self.create_images = create_images
self.create_images_async = create_async
@@ -40,6 +68,22 @@ class CreateImagesProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
+ """
+ Creates a completion result, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ stream (bool, optional): Indicates whether to stream the results. Defaults to False.
+ **kwargs: Additional keywordarguments for the provider.
+
+ Yields:
+ CreateResult: Yields chunks of the processed messages, including image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the synchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
buffer = ""
for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
@@ -71,6 +115,21 @@ class CreateImagesProvider(BaseProvider):
messages: Messages,
**kwargs
) -> str:
+ """
+ Asynchronously creates a response, processing any image creation prompts found within the messages.
+
+ Args:
+ model (str): The model to use for creation.
+ messages (Messages): The messages to process, which may contain image prompts.
+ **kwargs: Additional keyword arguments for the provider.
+
+ Returns:
+ str: The processed response string, including asynchronously generated image data if applicable.
+
+ Note:
+ This method processes messages to detect image creation prompts. When such a prompt is found,
+ it calls the asynchronous image creation function and includes the resulting image in the output.
+ """
messages.insert(0, {"role": "system", "content": self.system_message})
response = await self.provider.create_async(model, messages, **kwargs)
matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js
index 7ed9f183..8b9bc181 100644
--- a/g4f/gui/client/js/chat.v1.js
+++ b/g4f/gui/client/js/chat.v1.js
@@ -652,9 +652,9 @@ observer.observe(message_input, { attributes: true });
document.title = 'g4f - gui - ' + versions["version"];
text = "version ~ "
- if (versions["version"] != versions["lastet_version"]) {
- release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["lastet_version"];
- text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["lastet_version"] +'">' + versions["version"] + ' 🆕</a>';
+ if (versions["version"] != versions["latest_version"]) {
+ release_url = 'https://github.com/xtekky/gpt4free/releases/tag/' + versions["latest_version"];
+ text += '<a href="' + release_url +'" target="_blank" title="New version: ' + versions["latest_version"] +'">' + versions["version"] + ' 🆕</a>';
} else {
text += versions["version"];
}
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index 9d12bea5..4a5cafa8 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -1,6 +1,7 @@
import logging
import json
from flask import request, Flask
+from typing import Generator
from g4f import debug, version, models
from g4f import _all_models, get_last_provider, ChatCompletion
from g4f.image import is_allowed_extension, to_image
@@ -11,60 +12,123 @@ from .internet import get_search_message
debug.logging = True
class Backend_Api:
+ """
+ Handles various endpoints in a Flask application for backend operations.
+
+ This class provides methods to interact with models, providers, and to handle
+ various functionalities like conversations, error handling, and version management.
+
+ Attributes:
+ app (Flask): A Flask application instance.
+ routes (dict): A dictionary mapping API endpoints to their respective handlers.
+ """
def __init__(self, app: Flask) -> None:
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
self.app: Flask = app
self.routes = {
'/backend-api/v2/models': {
- 'function': self.models,
- 'methods' : ['GET']
+ 'function': self.get_models,
+ 'methods': ['GET']
},
'/backend-api/v2/providers': {
- 'function': self.providers,
- 'methods' : ['GET']
+ 'function': self.get_providers,
+ 'methods': ['GET']
},
'/backend-api/v2/version': {
- 'function': self.version,
- 'methods' : ['GET']
+ 'function': self.get_version,
+ 'methods': ['GET']
},
'/backend-api/v2/conversation': {
- 'function': self._conversation,
+ 'function': self.handle_conversation,
'methods': ['POST']
},
'/backend-api/v2/gen.set.summarize:title': {
- 'function': self._gen_title,
+ 'function': self.generate_title,
'methods': ['POST']
},
'/backend-api/v2/error': {
- 'function': self.error,
+ 'function': self.handle_error,
'methods': ['POST']
}
}
- def error(self):
+ def handle_error(self):
+ """
+ Initialize the backend API with the given Flask application.
+
+ Args:
+ app (Flask): Flask application instance to attach routes to.
+ """
print(request.json)
-
return 'ok', 200
- def models(self):
+ def get_models(self):
+ """
+ Return a list of all models.
+
+ Fetches and returns a list of all available models in the system.
+
+ Returns:
+ List[str]: A list of model names.
+ """
return _all_models
- def providers(self):
- return [
- provider.__name__ for provider in __providers__ if provider.working
- ]
+ def get_providers(self):
+ """
+ Return a list of all working providers.
+ """
+ return [provider.__name__ for provider in __providers__ if provider.working]
- def version(self):
+ def get_version(self):
+ """
+ Returns the current and latest version of the application.
+
+ Returns:
+ dict: A dictionary containing the current and latest version.
+ """
return {
"version": version.utils.current_version,
- "lastet_version": version.get_latest_version(),
+ "latest_version": version.get_latest_version(),
}
- def _gen_title(self):
- return {
- 'title': ''
- }
+ def generate_title(self):
+ """
+ Generates and returns a title based on the request data.
+
+ Returns:
+ dict: A dictionary with the generated title.
+ """
+ return {'title': ''}
- def _conversation(self):
+ def handle_conversation(self):
+ """
+ Handles conversation requests and streams responses back.
+
+ Returns:
+ Response: A Flask response object for streaming.
+ """
+ kwargs = self._prepare_conversation_kwargs()
+
+ return self.app.response_class(
+ self._create_response_stream(kwargs),
+ mimetype='text/event-stream'
+ )
+
+ def _prepare_conversation_kwargs(self):
+ """
+ Prepares arguments for chat completion based on the request data.
+
+ Reads the request and prepares the necessary arguments for handling
+ a chat completion request.
+
+ Returns:
+ dict: Arguments prepared for chat completion.
+ """
kwargs = {}
if 'image' in request.files:
file = request.files['image']
@@ -87,47 +151,70 @@ class Backend_Api:
messages[-1]["content"] = get_search_message(messages[-1]["content"])
model = json_data.get('model')
model = model if model else models.default
- provider = json_data.get('provider', '').replace('g4f.Provider.', '')
- provider = provider if provider and provider != "Auto" else None
patch = patch_provider if json_data.get('patch_provider') else None
- def try_response():
- try:
- first = True
- for chunk in ChatCompletion.create(
- model=model,
- provider=provider,
- messages=messages,
- stream=True,
- ignore_stream_and_auth=True,
- patch_provider=patch,
- **kwargs
- ):
- if first:
- first = False
- yield json.dumps({
- 'type' : 'provider',
- 'provider': get_last_provider(True)
- }) + "\n"
- if isinstance(chunk, Exception):
- logging.exception(chunk)
- yield json.dumps({
- 'type' : 'message',
- 'message': get_error_message(chunk),
- }) + "\n"
- else:
- yield json.dumps({
- 'type' : 'content',
- 'content': str(chunk),
- }) + "\n"
- except Exception as e:
- logging.exception(e)
- yield json.dumps({
- 'type' : 'error',
- 'error': get_error_message(e)
- })
-
- return self.app.response_class(try_response(), mimetype='text/event-stream')
+ return {
+ "model": model,
+ "provider": provider,
+ "messages": messages,
+ "stream": True,
+ "ignore_stream_and_auth": True,
+ "patch_provider": patch,
+ **kwargs
+ }
+
+ def _create_response_stream(self, kwargs) -> Generator[str, None, None]:
+ """
+ Creates and returns a streaming response for the conversation.
+
+ Args:
+ kwargs (dict): Arguments for creating the chat completion.
+
+ Yields:
+ str: JSON formatted response chunks for the stream.
+
+ Raises:
+ Exception: If an error occurs during the streaming process.
+ """
+ try:
+ first = True
+ for chunk in ChatCompletion.create(**kwargs):
+ if first:
+ first = False
+ yield self._format_json('provider', get_last_provider(True))
+ if isinstance(chunk, Exception):
+ logging.exception(chunk)
+ yield self._format_json('message', get_error_message(chunk))
+ else:
+ yield self._format_json('content', str(chunk))
+ except Exception as e:
+ logging.exception(e)
+ yield self._format_json('error', get_error_message(e))
+
+ def _format_json(self, response_type: str, content) -> str:
+ """
+ Formats and returns a JSON response.
+
+ Args:
+ response_type (str): The type of the response.
+ content: The content to be included in the response.
+
+ Returns:
+ str: A JSON formatted string.
+ """
+ return json.dumps({
+ 'type': response_type,
+ response_type: content
+ }) + "\n"
def get_error_message(exception: Exception) -> str:
+ """
+ Generates a formatted error message from an exception.
+
+ Args:
+ exception (Exception): The exception to format.
+
+ Returns:
+ str: A formatted error message string.
+ """
return f"{get_last_provider().__name__}: {type(exception).__name__}: {exception}" \ No newline at end of file
diff --git a/g4f/version.py b/g4f/version.py
index c976c8fd..9201c75c 100644
--- a/g4f/version.py
+++ b/g4f/version.py
@@ -7,10 +7,16 @@ from .errors import VersionNotFoundError
def get_pypi_version(package_name: str) -> str:
"""
- Get the latest version of a package from PyPI.
+ Retrieves the latest version of a package from PyPI.
- :param package_name: The name of the package.
- :return: The latest version of the package as a string.
+ Args:
+ package_name (str): The name of the package for which to retrieve the version.
+
+ Returns:
+ str: The latest version of the specified package from PyPI.
+
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from PyPI.
"""
try:
response = requests.get(f"https://pypi.org/pypi/{package_name}/json").json()
@@ -20,10 +26,16 @@ def get_pypi_version(package_name: str) -> str:
def get_github_version(repo: str) -> str:
"""
- Get the latest release version from a GitHub repository.
+ Retrieves the latest release version from a GitHub repository.
+
+ Args:
+ repo (str): The name of the GitHub repository.
+
+ Returns:
+ str: The latest release version from the specified GitHub repository.
- :param repo: The name of the GitHub repository.
- :return: The latest release version as a string.
+ Raises:
+ VersionNotFoundError: If there is an error in fetching the version from GitHub.
"""
try:
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
@@ -31,11 +43,16 @@ def get_github_version(repo: str) -> str:
except requests.RequestException as e:
raise VersionNotFoundError(f"Failed to get GitHub release version: {e}")
-def get_latest_version():
+def get_latest_version() -> str:
"""
- Get the latest release version from PyPI or the GitHub repository.
+ Retrieves the latest release version of the 'g4f' package from PyPI or GitHub.
- :return: The latest release version as a string.
+ Returns:
+ str: The latest release version of 'g4f'.
+
+ Note:
+ The function first tries to fetch the version from PyPI. If the package is not found,
+ it retrieves the version from the GitHub repository.
"""
try:
# Is installed via package manager?
@@ -47,14 +64,19 @@ def get_latest_version():
class VersionUtils:
"""
- Utility class for managing and comparing package versions.
+ Utility class for managing and comparing package versions of 'g4f'.
"""
@cached_property
def current_version(self) -> str:
"""
- Get the current version of the g4f package.
+ Retrieves the current version of the 'g4f' package.
+
+ Returns:
+ str: The current version of 'g4f'.
- :return: The current version as a string.
+ Raises:
+ VersionNotFoundError: If the version cannot be determined from the package manager,
+ Docker environment, or git repository.
"""
# Read from package manager
try:
@@ -79,15 +101,19 @@ class VersionUtils:
@cached_property
def latest_version(self) -> str:
"""
- Get the latest version of the g4f package.
+ Retrieves the latest version of the 'g4f' package.
- :return: The latest version as a string.
+ Returns:
+ str: The latest version of 'g4f'.
"""
return get_latest_version()
def check_version(self) -> None:
"""
- Check if the current version is up to date with the latest version.
+ Checks if the current version of 'g4f' is up to date with the latest version.
+
+ Note:
+ If a newer version is available, it prints a message with the new version and update instructions.
"""
try:
if self.current_version != self.latest_version:
diff --git a/g4f/webdriver.py b/g4f/webdriver.py
index e5ecd8bf..9a83215f 100644
--- a/g4f/webdriver.py
+++ b/g4f/webdriver.py
@@ -21,13 +21,16 @@ def get_browser(
options: ChromeOptions = None
) -> WebDriver:
"""
- Creates and returns a Chrome WebDriver with the specified options.
+ Creates and returns a Chrome WebDriver with specified options.
- :param user_data_dir: Directory for user data. If None, uses default directory.
- :param headless: Boolean indicating whether to run the browser in headless mode.
- :param proxy: Proxy settings for the browser.
- :param options: ChromeOptions object with specific browser options.
- :return: An instance of WebDriver.
+ Args:
+ user_data_dir (str, optional): Directory for user data. If None, uses default directory.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions object with specific browser options. Defaults to None.
+
+ Returns:
+ WebDriver: An instance of WebDriver configured with the specified options.
"""
if user_data_dir is None:
user_data_dir = user_config_dir("g4f")
@@ -49,10 +52,13 @@ def get_browser(
def get_driver_cookies(driver: WebDriver) -> dict:
"""
- Retrieves cookies from the given WebDriver.
+ Retrieves cookies from the specified WebDriver.
+
+ Args:
+ driver (WebDriver): The WebDriver instance from which to retrieve cookies.
- :param driver: WebDriver from which to retrieve cookies.
- :return: A dictionary of cookies.
+ Returns:
+ dict: A dictionary containing cookies with their names as keys and values as cookie values.
"""
return {cookie["name"]: cookie["value"] for cookie in driver.get_cookies()}
@@ -60,9 +66,13 @@ def bypass_cloudflare(driver: WebDriver, url: str, timeout: int) -> None:
"""
Attempts to bypass Cloudflare protection when accessing a URL using the provided WebDriver.
- :param driver: The WebDriver to use.
- :param url: URL to access.
- :param timeout: Time in seconds to wait for the page to load.
+ Args:
+ driver (WebDriver): The WebDriver to use for accessing the URL.
+ url (str): The URL to access.
+ timeout (int): Time in seconds to wait for the page to load.
+
+ Raises:
+ Exception: If there is an error while bypassing Cloudflare or loading the page.
"""
driver.get(url)
if driver.find_element(By.TAG_NAME, "body").get_attribute("class") == "no-js":
@@ -86,6 +96,7 @@ class WebDriverSession:
"""
Manages a Selenium WebDriver session, including handling of virtual displays and proxies.
"""
+
def __init__(
self,
webdriver: WebDriver = None,
@@ -95,6 +106,17 @@ class WebDriverSession:
proxy: str = None,
options: ChromeOptions = None
):
+ """
+ Initializes a new instance of the WebDriverSession.
+
+ Args:
+ webdriver (WebDriver, optional): A WebDriver instance for the session. Defaults to None.
+ user_data_dir (str, optional): Directory for user data. Defaults to None.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to False.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to False.
+ proxy (str, optional): Proxy settings for the browser. Defaults to None.
+ options (ChromeOptions, optional): ChromeOptions for the browser. Defaults to None.
+ """
self.webdriver = webdriver
self.user_data_dir = user_data_dir
self.headless = headless
@@ -110,14 +132,17 @@ class WebDriverSession:
virtual_display: bool = False
) -> WebDriver:
"""
- Reopens the WebDriver session with the specified parameters.
+ Reopens the WebDriver session with new settings.
+
+ Args:
+ user_data_dir (str, optional): Directory for user data. Defaults to current value.
+ headless (bool, optional): Whether to run the browser in headless mode. Defaults to current value.
+ virtual_display (bool, optional): Whether to use a virtual display. Defaults to current value.
- :param user_data_dir: Directory for user data.
- :param headless: Boolean indicating whether to run the browser in headless mode.
- :param virtual_display: Boolean indicating whether to use a virtual display.
- :return: An instance of WebDriver.
+ Returns:
+ WebDriver: The reopened WebDriver instance.
"""
- user_data_dir = user_data_dir or self.user_data_dir
+ user_data_dir = user_data_data_dir or self.user_data_dir
if self.default_driver:
self.default_driver.quit()
if not virtual_display and self.virtual_display:
@@ -128,8 +153,10 @@ class WebDriverSession:
def __enter__(self) -> WebDriver:
"""
- Context management method for entering a session.
- :return: An instance of WebDriver.
+ Context management method for entering a session. Initializes and returns a WebDriver instance.
+
+ Returns:
+ WebDriver: An instance of WebDriver for this session.
"""
if self.webdriver:
return self.webdriver
@@ -141,6 +168,14 @@ class WebDriverSession:
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Context management method for exiting a session. Closes and quits the WebDriver.
+
+ Args:
+ exc_type: Exception type.
+ exc_val: Exception value.
+ exc_tb: Exception traceback.
+
+ Note:
+ Closes the WebDriver and stops the virtual display if used.
"""
if self.default_driver:
try: