summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/helper.py
blob: fce1ee6fff6cc6587fd54371a346ddfe008eb7f4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from __future__ import annotations

import asyncio
import os
import random
import secrets
import string
from asyncio import AbstractEventLoop, BaseEventLoop
from platformdirs import user_config_dir
from browser_cookie3 import (
    chrome, chromium, opera, opera_gx,
    brave, edge, vivaldi, firefox,
    _LinuxPasswordManager, BrowserCookieError
)
from ..typing import Dict, Messages
from .. import debug

# Module-level cookie cache filled by get_cookies():
# maps a domain name to that domain's {cookie name: cookie value} dict.
_cookies: Dict[str, Dict[str, str]] = dict()

def get_event_loop() -> AbstractEventLoop:
    """
    Get the current asyncio event loop, creating and installing a new one if
    none is set or the current one is closed.

    If a loop is already running in this thread, it is patched with
    'nest_asyncio' so it can be re-entered. Patching is skipped when the loop's
    class is already marked with ``_nest_patched``.

    Returns:
        AbstractEventLoop: The current or new event loop.

    Raises:
        RuntimeError: If called while a loop is running and the
            'nest_asyncio' package is not installed.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No current loop for this thread (newer Python versions raise here).
        loop = None
    # Use the public is_closed() so that closed loops are detected for every
    # AbstractEventLoop implementation, not only BaseEventLoop subclasses.
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Not called from inside a running loop: nothing to patch.
        return loop
    if not hasattr(loop.__class__, "_nest_patched"):
        try:
            import nest_asyncio
        except ImportError as err:
            raise RuntimeError(
                'Use "create_async" instead of "create" function in a running event loop. Or install "nest_asyncio" package.'
            ) from err
        nest_asyncio.apply(loop)
    return loop

# Import-time patch: when DBUS_SESSION_BUS_ADDRESS is exactly "/dev/null"
# (presumably meaning no usable D-Bus session bus), replace browser_cookie3's
# private _LinuxPasswordManager.get_password on the class itself, so every
# instance returns the fixed key b"secret" for the rest of the process.
# The lambda accepts (self, <one argument>) and ignores both.
# NOTE(review): presumably this avoids a keyring lookup over D-Bus and assumes
# the browser's cookie store was encrypted with this same fixed key — confirm.
if os.environ.get('DBUS_SESSION_BUS_ADDRESS') == "/dev/null":
    _LinuxPasswordManager.get_password = lambda a, b: b"secret"

def get_cookies(domain_name: str = '') -> Dict[str, str]:
    """
    Return the cookies for ``domain_name``, scanning the browsers only once.

    The first call for a domain loads cookies via
    ``_load_cookies_from_browsers`` and stores the result in the module-level
    ``_cookies`` cache. That happens even when the result is empty. Later calls
    for the same domain return that same cached dictionary object.

    Args:
        domain_name (str): Domain to look up. Defaults to '', which is passed
            through unchanged to the browser loaders.

    Returns:
        Dict[str, str]: Mapping of cookie name to cookie value.
    """
    if domain_name not in _cookies:
        _cookies[domain_name] = _load_cookies_from_browsers(domain_name)
    return _cookies[domain_name]

def _load_cookies_from_browsers(domain_name: str) -> Dict[str, str]:
    """
    Merge cookies for ``domain_name`` from the g4f profile and each supported browser.

    Sources are queried in a fixed order: g4f, chrome, chromium, opera,
    opera_gx, brave, edge, vivaldi, firefox. When several sources provide a
    cookie with the same name, the value from the earliest source wins.

    Error handling per source:
    - A BrowserCookieError is skipped silently.
    - Any other exception is also skipped, but is printed when
      ``debug.logging`` is enabled.

    Args:
        domain_name (str): Domain to load cookies for.

    Returns:
        Dict[str, str]: Mapping of cookie name to cookie value.
    """
    merged = {}
    sources = (_g4f, chrome, chromium, opera, opera_gx, brave, edge, vivaldi, firefox)
    for cookie_fn in sources:
        try:
            cookie_jar = cookie_fn(domain_name=domain_name)
            if len(cookie_jar) and debug.logging:
                print(f"Read cookies from {cookie_fn.__name__} for {domain_name}")
            for cookie in cookie_jar:
                # setdefault keeps the first value seen for each cookie name.
                merged.setdefault(cookie.name, cookie.value)
        except BrowserCookieError:
            continue
        except Exception as e:
            if debug.logging:
                print(f"Error reading cookies from {cookie_fn.__name__} for {domain_name}: {e}")
    return merged

def _g4f(domain_name: str) -> list:
    """
    Read cookies from the g4f-managed profile, if one exists.

    The cookie database is expected at
    ``<user_config_dir("g4f")>/Default/Cookies``. When that file is present, it
    is handed to browser_cookie3's ``chrome`` reader.

    Args:
        domain_name (str): Domain to load cookies for.

    Returns:
        An empty list when the cookie file does not exist; otherwise whatever
        ``chrome(cookie_file, domain_name)`` returns.
    """
    cookie_file = os.path.join(user_config_dir("g4f"), "Default", "Cookies")
    if not os.path.exists(cookie_file):
        return []
    return chrome(cookie_file, domain_name)

def format_prompt(messages: Messages, add_special_tokens=False) -> str:
    """
    Format a series of messages into a single string, optionally adding special tokens.

    Without special tokens:
    - A single message is returned as its raw content.
    - An empty list yields ''.

    In every other case, each message is rendered as "<Role>: <content>" on its
    own line (role passed through str.capitalize), followed by a final
    "Assistant:" line.

    Args:
        messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
        add_special_tokens (bool): Whether to always use the role-prefixed format.

    Returns:
        str: A formatted string containing all messages.
    """
    if not add_special_tokens and len(messages) <= 1:
        # Guard the empty case: messages[0] would raise IndexError.
        return messages[0]["content"] if messages else ""
    formatted = "\n".join(
        f'{message["role"].capitalize()}: {message["content"]}'
        for message in messages
    )
    return f"{formatted}\nAssistant:"

def get_random_string(length: int = 10) -> str:
    """
    Build a random string from lowercase ASCII letters and digits.

    Each character is drawn independently with ``random.choice``.

    Args:
        length (int, optional): Number of characters to produce. Defaults to 10.

    Returns:
        str: The generated string; empty when ``length`` is 0 or negative.
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)

def get_random_hex(length: int = 32) -> str:
    """
    Generate a random lowercase hexadecimal string using the ``secrets`` module.

    Args:
        length (int, optional): Number of hex characters to return. Defaults to
            32 (16 random bytes), matching the previous fixed output.

    Returns:
        str: A random hexadecimal string of exactly ``length`` characters.

    Raises:
        ValueError: If ``length`` is negative.
    """
    if length < 0:
        raise ValueError(f"length must be non-negative, got {length}")
    # token_hex(n) yields exactly 2*n hex digits, so no zero-padding is needed.
    # Round the byte count up and trim the result, which also supports odd lengths.
    return secrets.token_hex((length + 1) // 2)[:length]