summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-04-17 10:35:08 +0200
committerGitHub <noreply@github.com>2024-04-17 10:35:08 +0200
commit0f04dacdbdba067152ebd4c1f7c23df8c9422295 (patch)
tree4e80161355a866d6e533c74dcca84494b6c2a5bf
parentadd cohere provider. (diff)
parentUpdate event loop on windows only for old curl_cffi (diff)
downloadgpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar.gz
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar.bz2
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar.lz
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar.xz
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.tar.zst
gpt4free-0f04dacdbdba067152ebd4c1f7c23df8c9422295.zip
-rw-r--r--g4f/locals/provider.py7
-rw-r--r--g4f/providers/base_provider.py9
2 files changed, 12 insertions, 4 deletions
diff --git a/g4f/locals/provider.py b/g4f/locals/provider.py
index 45041539..d9d73455 100644
--- a/g4f/locals/provider.py
+++ b/g4f/locals/provider.py
@@ -66,9 +66,12 @@ class LocalProvider:
if message["role"] != "system"
) + "\nASSISTANT: "
+ def should_not_stop(token_id: int, token: str):
+ return "USER" not in token
+
with model.chat_session(system_message, prompt_template):
if stream:
- for token in model.generate(conversation, streaming=True):
+ for token in model.generate(conversation, streaming=True, callback=should_not_stop):
yield token
else:
- yield model.generate(conversation) \ No newline at end of file
+ yield model.generate(conversation, callback=should_not_stop) \ No newline at end of file
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 86789ec2..cb60d78f 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -19,8 +19,13 @@ else:
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
- if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
- asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ try:
+ from curl_cffi import aio
+ if not hasattr(aio, "_get_selector"):
+ if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
+ except ImportError:
+ pass
def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try: