summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Theb.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Theb.py29
1 files changed, 11 insertions, 18 deletions
diff --git a/g4f/Provider/needs_auth/Theb.py b/g4f/Provider/needs_auth/Theb.py
index cf33f0c6..49ee174b 100644
--- a/g4f/Provider/needs_auth/Theb.py
+++ b/g4f/Provider/needs_auth/Theb.py
@@ -4,7 +4,8 @@ import time
from ...typing import CreateResult, Messages
from ..base_provider import BaseProvider
-from ..helper import WebDriver, WebDriverSession, format_prompt
+from ..helper import format_prompt
+from ..webdriver import WebDriver, WebDriverSession
models = {
"theb-ai": "TheB.AI",
@@ -44,14 +45,14 @@ class Theb(BaseProvider):
messages: Messages,
stream: bool,
proxy: str = None,
- web_driver: WebDriver = None,
+ webdriver: WebDriver = None,
virtual_display: bool = True,
**kwargs
) -> CreateResult:
if model in models:
model = models[model]
prompt = format_prompt(messages)
- web_session = WebDriverSession(web_driver, virtual_display=virtual_display, proxy=proxy)
+ web_session = WebDriverSession(webdriver, virtual_display=virtual_display, proxy=proxy)
with web_session as driver:
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
@@ -61,22 +62,16 @@ class Theb(BaseProvider):
# Register fetch hook
script = """
window._fetch = window.fetch;
-window.fetch = (url, options) => {
+window.fetch = async (url, options) => {
// Call parent fetch method
- const result = window._fetch(url, options);
+ const response = await window._fetch(url, options);
if (!url.startsWith("/api/conversation")) {
return result;
}
- // Load response reader
- result.then((response) => {
- if (!response.body.locked) {
- window._reader = response.body.getReader();
- }
- });
- // Return dummy response
- return new Promise((resolve, reject) => {
- resolve(new Response(new ReadableStream()))
- });
+ // Copy response
+ copy = response.clone();
+ window._reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
+ return copy;
}
window._last_message = "";
"""
@@ -97,7 +92,6 @@ window._last_message = "";
wait = WebDriverWait(driver, 240)
wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
- time.sleep(200)
try:
driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
@@ -134,9 +128,8 @@ if(window._reader) {
if (chunk['done']) {
return null;
}
- text = (new TextDecoder()).decode(chunk['value']);
message = '';
- text.split('\\r\\n').forEach((line, index) => {
+ chunk['value'].split('\\r\\n').forEach((line, index) => {
if (line.startsWith('data: ')) {
try {
line = JSON.parse(line.substring('data: '.length));