summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/Cloudflare.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/Cloudflare.py136
1 files changed, 37 insertions, 99 deletions
diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py
index 8fb37bef..825c5027 100644
--- a/g4f/Provider/Cloudflare.py
+++ b/g4f/Provider/Cloudflare.py
@@ -1,72 +1,52 @@
from __future__ import annotations
-from aiohttp import ClientSession
import asyncio
import json
import uuid
-import cloudscraper
-from typing import AsyncGenerator
-from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt
+from ..typing import AsyncResult, Messages, Cookies
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
+from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
label = "Cloudflare AI"
url = "https://playground.ai.cloudflare.com"
api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
+ models_url = "https://playground.ai.cloudflare.com/api/models"
working = True
supports_stream = True
supports_system_message = True
supports_message_history = True
-
- default_model = '@cf/meta/llama-3.1-8b-instruct-awq'
- models = [
- '@cf/meta/llama-2-7b-chat-fp16',
- '@cf/meta/llama-2-7b-chat-int8',
-
- '@cf/meta/llama-3-8b-instruct',
- '@cf/meta/llama-3-8b-instruct-awq',
- '@hf/meta-llama/meta-llama-3-8b-instruct',
-
- default_model,
- '@cf/meta/llama-3.1-8b-instruct-fp8',
-
- '@cf/meta/llama-3.2-1b-instruct',
-
- '@hf/mistral/mistral-7b-instruct-v0.2',
-
- '@cf/qwen/qwen1.5-7b-chat-awq',
-
- '@cf/defog/sqlcoder-7b-2',
- ]
-
+ default_model = "@cf/meta/llama-3.1-8b-instruct"
model_aliases = {
"llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
"llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
-
"llama-3-8b": "@cf/meta/llama-3-8b-instruct",
"llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
"llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
-
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
"llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
-
"llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
-
"qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
-
- #"sqlcoder-7b": "@cf/defog/sqlcoder-7b-2",
}
+ _args: dict = None
@classmethod
- def get_model(cls, model: str) -> str:
- if model in cls.models:
- return model
- elif model in cls.model_aliases:
- return cls.model_aliases[model]
- else:
- return cls.default_model
+ def get_models(cls) -> str:
+ if not cls.models:
+ if cls._args is None:
+ get_running_loop(check_nested=True)
+ args = get_args_from_nodriver(cls.url, cookies={
+ '__cf_bm': uuid.uuid4().hex,
+ })
+ cls._args = asyncio.run(args)
+ with Session(**cls._args) as session:
+ response = session.get(cls.models_url)
+ raise_for_status(response)
+ json_data = response.json()
+ cls.models = [model.get("name") for model in json_data.get("models")]
+ cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
+ return cls.models
@classmethod
async def create_async_generator(
@@ -75,76 +55,34 @@ class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
proxy: str = None,
max_tokens: int = 2048,
+ cookies: Cookies = None,
+ timeout: int = 300,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
-
- headers = {
- 'Accept': 'text/event-stream',
- 'Accept-Language': 'en-US,en;q=0.9',
- 'Cache-Control': 'no-cache',
- 'Content-Type': 'application/json',
- 'Origin': cls.url,
- 'Pragma': 'no-cache',
- 'Referer': f'{cls.url}/',
- 'Sec-Ch-Ua': '"Chromium";v="129", "Not=A?Brand";v="8"',
- 'Sec-Ch-Ua-Mobile': '?0',
- 'Sec-Ch-Ua-Platform': '"Linux"',
- 'Sec-Fetch-Dest': 'empty',
- 'Sec-Fetch-Mode': 'cors',
- 'Sec-Fetch-Site': 'same-origin',
- 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36',
- }
-
- cookies = {
- '__cf_bm': uuid.uuid4().hex,
- }
-
- scraper = cloudscraper.create_scraper()
-
+ if cls._args is None:
+ cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
data = {
- "messages": [
- {"role": "user", "content": format_prompt(messages)}
- ],
+ "messages": messages,
"lora": None,
"model": model,
"max_tokens": max_tokens,
"stream": True
}
-
- max_retries = 3
- full_response = ""
-
- for attempt in range(max_retries):
- try:
- response = scraper.post(
- cls.api_endpoint,
- headers=headers,
- cookies=cookies,
- json=data,
- stream=True,
- proxies={'http': proxy, 'https': proxy} if proxy else None
- )
-
- if response.status_code == 403:
- await asyncio.sleep(2 ** attempt)
- continue
-
- response.raise_for_status()
-
- for line in response.iter_lines():
+ async with StreamSession(**cls._args) as session:
+ async with session.post(
+ cls.api_endpoint,
+ json=data,
+ ) as response:
+ await raise_for_status(response)
+ cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
+ async for line in response.iter_lines():
if line.startswith(b'data: '):
if line == b'data: [DONE]':
- if full_response:
- yield full_response
break
try:
- content = json.loads(line[6:].decode('utf-8'))
- if 'response' in content and content['response'] != '</s>':
+ content = json.loads(line[6:].decode())
+ if content.get("response") and content.get("response") != '</s>':
yield content['response']
except Exception:
- continue
- break
- except Exception as e:
- if attempt == max_retries - 1:
- raise
+ continue \ No newline at end of file