summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-01-23 20:08:41 +0100
committerGitHub <noreply@github.com>2024-01-23 20:08:41 +0100
commit2b140a32554c1e94d095c55599a2f93e86f957cf (patch)
treee2770d97f0242a0b99a3af68ea4fcf25227dfcc8 /g4f/Provider/base_provider.py
parent~ (diff)
parentAdd ProviderModelMixin for model selection (diff)
downloadgpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.gz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.bz2
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.lz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.xz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.zst
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.zip
Diffstat (limited to '')
-rw-r--r--g4f/Provider/base_provider.py23
1 files changed, 21 insertions, 2 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index bc47a1fa..e1dcd24d 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -8,7 +8,7 @@ from inspect import signature, Parameter
from .helper import get_cookies, format_prompt
from ..typing import CreateResult, AsyncResult, Messages, Union
from ..base_provider import BaseProvider
-from ..errors import NestAsyncioError
+from ..errors import NestAsyncioError, ModelNotSupportedError
if sys.version_info < (3, 10):
NoneType = type(None)
@@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
AsyncResult: An asynchronous generator yielding results.
"""
- raise NotImplementedError() \ No newline at end of file
+ raise NotImplementedError()
+
+class ProviderModelMixin:
+ default_model: str
+ models: list[str] = []
+ model_aliases: dict[str, str] = {}
+
+ @classmethod
+ def get_models(cls) -> list[str]:
+ return cls.models
+
+ @classmethod
+ def get_model(cls, model: str) -> str:
+ if not model:
+ return cls.default_model
+ elif model in cls.model_aliases:
+ return cls.model_aliases[model]
+ elif model not in cls.get_models():
+ raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
+ return model \ No newline at end of file