summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/Local.py
blob: c08064b9195cb9d5f761a3f98f5a49f2832bcffe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from __future__ import annotations

from ..locals.models import get_models
try:
    from ..locals.provider import LocalProvider
    has_requirements = True
except ModuleNotFoundError:
    has_requirements = False

from ..typing import Messages, CreateResult
from ..providers.base_provider import AbstractProvider, ProviderModelMixin
from ..errors import MissingRequirementsError

class Local(AbstractProvider, ProviderModelMixin):
    """Provider backed by the optional local ``gpt4all`` integration.

    The backend (``LocalProvider``) is imported lazily at module load inside a
    ``try``/``except ModuleNotFoundError``; when it is unavailable this class
    still imports, and ``create_completion`` raises ``MissingRequirementsError``
    telling the user to install ``g4f[local]``.
    """

    label = "gpt4all"
    # Capability flags advertised to the provider framework.
    working = True
    supports_message_history = True
    supports_system_message = True
    supports_stream = True

    @classmethod
    def get_models(cls) -> list[str]:
        """Return the names of the models known to the local backend.

        The list is built from ``..locals.models.get_models()`` and cached on
        the class; its first entry becomes ``default_model``. An empty result
        leaves ``cls.models`` falsy, so the lookup is retried on the next call.
        """
        if not cls.models:
            models = list(get_models())
            cls.models = models
            # Only pick a default when something is available: indexing an
            # empty list would raise IndexError for callers that merely list
            # the provider's models.
            if models:
                cls.default_model = models[0]
        return cls.models

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        **kwargs
    ) -> CreateResult:
        """Delegate a completion request to the local backend.

        Args:
            model: Requested model name; resolved through ``cls.get_model``
                (inherited from ``ProviderModelMixin`` — presumably maps
                empty/unknown names onto ``default_model``; confirm there).
            messages: Conversation messages passed through to the backend.
            stream: Stream flag passed through to the backend.
            **kwargs: Forwarded unchanged to ``LocalProvider.create_completion``.

        Returns:
            The result of ``LocalProvider.create_completion``.

        Raises:
            MissingRequirementsError: If the optional ``gpt4all`` dependencies
                could not be imported when this module was loaded.
        """
        if not has_requirements:
            raise MissingRequirementsError('Install "gpt4all" package | pip install -U g4f[local]')
        return LocalProvider.create_completion(
            cls.get_model(model),
            messages,
            stream,
            **kwargs
        )