summaryrefslogtreecommitdiffstats
path: root/g4f/locals/provider.py
blob: d9d7345597a036ca6b7b3d1b6b83d16d5ff55ee0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from __future__ import annotations

import os

from gpt4all import GPT4All
from .models import get_models
from ..typing import Messages

# Lazily-filled cache of get_models(): model name -> metadata dict (whose
# "path" entry names the model file). Stays None until the first call to
# LocalProvider.create_completion loads it.
MODEL_LIST: dict[str, dict] | None = None

def find_model_dir(model_file: str) -> str:
    """Locate the directory containing *model_file*.

    Candidates are checked in this order:

    1. ``models/`` two directories above this module's own directory,
    2. ``models/`` next to this module,
    3. every directory under the current working directory (via ``os.walk``).

    The first directory that contains the file wins. If none does, the
    first candidate is returned as the default location.
    """
    package_dir = os.path.dirname(os.path.abspath(__file__))
    default_dir = os.path.join(os.path.dirname(os.path.dirname(package_dir)), "models")
    legacy_dir = os.path.join(package_dir, "models")

    for candidate in (default_dir, legacy_dir):
        if os.path.isfile(os.path.join(candidate, model_file)):
            return candidate

    # NOTE: this walks the whole tree below the working directory and can be
    # slow when that tree is large; it stops at the first match.
    match = next(
        (root for root, _subdirs, names in os.walk("./") if model_file in names),
        None,
    )
    return default_dir if match is None else match

class LocalProvider:
    """Chat-completion provider backed by locally stored GPT4All models."""

    # System prompt used when the caller's messages contain no "system" entry.
    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious user and an artificial intelligence assistant."

    @staticmethod
    def _build_prompt(messages: "Messages") -> tuple[str, str]:
        """Split *messages* into ``(system_message, conversation)``.

        ``system_message`` joins the content of every ``"system"`` message with
        newlines, falling back to ``DEFAULT_SYSTEM_MESSAGE`` only when there are
        none. ``conversation`` renders each remaining message as
        ``"ROLE: content"`` (role upper-cased, one per line) and ends with an
        open ``"ASSISTANT: "`` turn for the model to complete.
        """
        system_message = "\n".join(
            message["content"] for message in messages if message["role"] == "system"
        )
        # Only use the default when the caller supplied no system prompt;
        # an explicit system prompt must be passed through unchanged.
        if not system_message:
            system_message = LocalProvider.DEFAULT_SYSTEM_MESSAGE

        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "
        return system_message, conversation

    @staticmethod
    def create_completion(model: str, messages: "Messages", stream: bool = False, **kwargs):
        """Generate a reply to *messages* with the local model named *model*.

        This is a generator. With *stream* true it yields the tokens produced
        by ``GPT4All.generate(..., streaming=True)``; otherwise it yields the
        single value returned by ``GPT4All.generate``. Because it is a
        generator, nothing here runs (and nothing is raised) until the caller
        starts iterating.

        If the model file is missing, the user is asked on stdin (``input``)
        whether to download it.

        Raises:
            ValueError: if *model* is not in the model list, or its file is
                missing and the user declines the download.
        """
        global MODEL_LIST
        if MODEL_LIST is None:
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')

        model_file = MODEL_LIST[model]["path"]
        model_dir = find_model_dir(model_file)
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            print(f'Model file "models/{model_file}" not found.')
            answer = input(f"Do you want to download {model_file}? [y/n]: ")
            if answer.strip().lower() in ("y", "yes"):
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')

        llm = GPT4All(
            model_name=model_file,
            verbose=False,
            allow_download=False,  # download (if any) was handled explicitly above
            model_path=model_dir,
        )

        system_message, conversation = LocalProvider._build_prompt(messages)
        prompt_template = "USER: {0}\nASSISTANT: "
        # NOTE(review): `conversation` already carries "USER:"/"ASSISTANT:"
        # markers and a trailing "ASSISTANT: ", and chat_session presumably
        # wraps each prompt in `prompt_template` as well, so the model may see
        # doubled role markers -- confirm against gpt4all's chat_session docs.

        def should_not_stop(token_id: int, token: str) -> bool:
            # Generation callback; presumably gpt4all stops when it returns
            # False. Stop once the model starts writing a new "USER" turn.
            return "USER" not in token

        with llm.chat_session(system_message, prompt_template):
            if stream:
                yield from llm.generate(conversation, streaming=True, callback=should_not_stop)
            else:
                yield llm.generate(conversation, callback=should_not_stop)