From 91feb34054f529c37e10d98d2471c8c0c6780147 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 23 Jan 2024 19:44:48 +0100 Subject: Add ProviderModelMixin for model selection --- g4f/Provider/base_provider.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) (limited to 'g4f/Provider/base_provider.py') diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index bc47a1fa..e1dcd24d 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -8,7 +8,7 @@ from inspect import signature, Parameter from .helper import get_cookies, format_prompt from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider -from ..errors import NestAsyncioError +from ..errors import NestAsyncioError, ModelNotSupportedError if sys.version_info < (3, 10): NoneType = type(None) @@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: AsyncResult: An asynchronous generator yielding results. """ - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() + +class ProviderModelMixin: + default_model: str + models: list[str] = [] + model_aliases: dict[str, str] = {} + + @classmethod + def get_models(cls) -> list[str]: + return cls.models + + @classmethod + def get_model(cls, model: str) -> str: + if not model: + return cls.default_model + elif model in cls.model_aliases: + return cls.model_aliases[model] + elif model not in cls.get_models(): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + return model \ No newline at end of file -- cgit v1.2.3