summaryrefslogblamecommitdiffstats
path: root/g4f/Provider/Local.py
blob: c08064b9195cb9d5f761a3f98f5a49f2832bcffe (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14













                                                                          
                     



























                                                                                                   
from __future__ import annotations

from ..locals.models import get_models
try:
    from ..locals.provider import LocalProvider
    has_requirements = True
except ModuleNotFoundError:
    has_requirements = False

from ..typing import Messages, CreateResult
from ..providers.base_provider import AbstractProvider, ProviderModelMixin
from ..errors import MissingRequirementsError

class Local(AbstractProvider, ProviderModelMixin):
    """Provider that runs models locally through the gpt4all backend.

    Completion requests are delegated to ``LocalProvider`` from
    ``..locals.provider``. That import is optional: when it failed at module
    import time (``has_requirements`` is False), ``create_completion`` raises
    ``MissingRequirementsError`` instead of calling the backend.
    """

    label = "gpt4all"
    working = True
    supports_message_history = True
    supports_system_message = True
    supports_stream = True

    @classmethod
    def get_models(cls):
        """Return the locally available model names, loading them lazily.

        On the first call (while ``cls.models`` is empty) the list is built
        from ``..locals.models.get_models()`` and cached on the class. Its
        first entry becomes ``cls.default_model``. If no models are found,
        the empty list is returned, ``default_model`` is left unchanged, and
        the next call queries again.
        """
        if not cls.models:
            # This calls the module-level ``get_models`` imported from
            # ..locals.models, not this classmethod: class-body names are not
            # in scope inside a method body.
            # NOTE(review): list() keeps whatever iterating the result yields
            # (the keys, if it is a mapping) -- confirm against locals.models.
            models = list(get_models())
            cls.models = models
            # Guard the empty case, which would otherwise raise IndexError.
            if models:
                cls.default_model = models[0]
        return cls.models

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Generate a completion with a local gpt4all model.

        Args:
            model: Requested model name, resolved through ``cls.get_model``
                (provided by ``ProviderModelMixin``).
            messages: Chat messages, passed through to the backend.
            stream: Forwarded unchanged to ``LocalProvider.create_completion``.
            **kwargs: Extra options, forwarded unchanged.

        Returns:
            The result of ``LocalProvider.create_completion``
            (typed as ``CreateResult``).

        Raises:
            MissingRequirementsError: if ``..locals.provider`` could not be
                imported, i.e. the optional gpt4all dependency is missing.
        """
        if not has_requirements:
            raise MissingRequirementsError('Install "gpt4all" package | pip install -U g4f[local]')
        return LocalProvider.create_completion(
            cls.get_model(model),
            messages,
            stream,
            **kwargs
        )