diff options
Diffstat (limited to '')
-rw-r--r-- | etc/tool/create_provider.py | 51 | ||||
-rw-r--r-- | etc/tool/improve_code.py | 4 |
2 files changed, 39 insertions, 16 deletions
diff --git a/etc/tool/create_provider.py b/etc/tool/create_provider.py index ff04f961..7a9827a8 100644 --- a/etc/tool/create_provider.py +++ b/etc/tool/create_provider.py @@ -33,14 +33,35 @@ from __future__ import annotations from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import format_prompt -class ChatGpt(AsyncGeneratorProvider): - url = "https://chat-gpt.com" +class {name}(AsyncGeneratorProvider, ProviderModelMixin): + label = "" + url = "https://example.com" + api_endpoint = "https://example.com/api/completion" working = True - supports_gpt_35_turbo = True + needs_auth = False + supports_stream = True + supports_system_message = True + supports_message_history = True + + default_model = '' + models = ['', ''] + + model_aliases = { + "alias1": "model1", + } + + @classmethod + def get_model(cls, model: str) -> str: + if model in cls.models: + return model + elif model in cls.model_aliases: + return cls.model_aliases[model] + else: + return cls.default_model @classmethod async def create_async_generator( @@ -50,19 +71,21 @@ class ChatGpt(AsyncGeneratorProvider): proxy: str = None, **kwargs ) -> AsyncResult: - headers = { - "authority": "chat-gpt.com", + model = cls.get_model(model) + + headers = {{ + "authority": "example.com", "accept": "application/json", "origin": cls.url, - "referer": f"{cls.url}/chat", - } + "referer": f"{{cls.url}}/chat", + }} async with ClientSession(headers=headers) as session: prompt = format_prompt(messages) - data = { + data = {{ "prompt": prompt, - "purpose": "", - } - async with session.post(f"{cls.url}/api/chat", json=data, proxy=proxy) as response: + "model": model, + }} + async with session.post(f"{{cls.url}}/api/chat", json=data, proxy=proxy) as response: response.raise_for_status() async for chunk in response.content: if chunk: @@ -78,7 +101,7 @@ Create a provider from a cURL command. The command is: {command} ``` A example for a provider: -```py +```python {example} ``` The name for the provider class: @@ -90,7 +113,7 @@ And replace "gpt-3.5-turbo" with `model`. print("Create code...") response = [] for chunk in g4f.ChatCompletion.create( - model=g4f.models.gpt_35_long, + model=g4f.models.default, messages=[{"role": "user", "content": prompt}], timeout=300, stream=True, diff --git a/etc/tool/improve_code.py b/etc/tool/improve_code.py index b2e36f86..8578b478 100644 --- a/etc/tool/improve_code.py +++ b/etc/tool/improve_code.py @@ -30,7 +30,7 @@ Don't remove license comments. print("Create code...") response = [] for chunk in g4f.ChatCompletion.create( - model=g4f.models.gpt_35_long, + model=g4f.models.default, messages=[{"role": "user", "content": prompt}], timeout=300, stream=True @@ -42,4 +42,4 @@ response = "".join(response) if code := read_code(response): with open(path, "w") as file: - file.write(code)
\ No newline at end of file + file.write(code) |