summary | refs | log | tree | commit | diff | stats
path: root/g4f/Provider/DeepInfra.py
diff options
context:
space:
mode:
author: H Lohaus <hlohaus@users.noreply.github.com> 2024-01-23 20:08:41 +0100
committer: GitHub <noreply@github.com> 2024-01-23 20:08:41 +0100
commit: 2b140a32554c1e94d095c55599a2f93e86f957cf (patch)
tree: e2770d97f0242a0b99a3af68ea4fcf25227dfcc8 /g4f/Provider/DeepInfra.py
parent~ (diff)
parentAdd ProviderModelMixin for model selection (diff)
download: gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.gz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.bz2
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.lz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.xz
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.tar.zst
gpt4free-2b140a32554c1e94d095c55599a2f93e86f957cf.zip
Diffstat (limited to 'g4f/Provider/DeepInfra.py')
-rw-r--r-- g4f/Provider/DeepInfra.py | 26
1 files changed, 17 insertions, 9 deletions
diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py
index acde1200..2f34b679 100644
--- a/g4f/Provider/DeepInfra.py
+++ b/g4f/Provider/DeepInfra.py
@@ -1,18 +1,27 @@
from __future__ import annotations
import json
-from ..typing import AsyncResult, Messages
-from .base_provider import AsyncGeneratorProvider
-from ..requests import StreamSession
+import requests
+from ..typing import AsyncResult, Messages
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from ..requests import StreamSession
-class DeepInfra(AsyncGeneratorProvider):
+class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://deepinfra.com"
working = True
supports_stream = True
supports_message_history = True
-
+ default_model = 'meta-llama/Llama-2-70b-chat-hf'
+
@staticmethod
+ def get_models():
+ url = 'https://api.deepinfra.com/models/featured'
+ models = requests.get(url).json()
+ return [model['model_name'] for model in models]
+
+ @classmethod
async def create_async_generator(
+ cls,
model: str,
messages: Messages,
stream: bool,
@@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider):
auth: str = None,
**kwargs
) -> AsyncResult:
- if not model:
- model = 'meta-llama/Llama-2-70b-chat-hf'
headers = {
'Accept-Encoding': 'gzip, deflate, br',
'Accept-Language': 'en-US',
@@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider):
impersonate="chrome110"
) as session:
json_data = {
- 'model' : model,
+ 'model' : cls.get_model(model),
'messages': messages,
'stream' : True
}
@@ -70,7 +77,8 @@ class DeepInfra(AsyncGeneratorProvider):
if token:
if first:
token = token.lstrip()
+ if token:
first = False
- yield token
+ yield token
except Exception:
raise RuntimeError(f"Response: {line}")