summaryrefslogtreecommitdiffstats
path: root/g4f/gui/server/api.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/gui/server/api.py')
-rw-r--r--g4f/gui/server/api.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py
index f03d2048..ed8454c3 100644
--- a/g4f/gui/server/api.py
+++ b/g4f/gui/server/api.py
@@ -42,7 +42,12 @@ class Api:
provider: ProviderType = __map__[provider]
if issubclass(provider, ProviderModelMixin):
return [
- {"model": model, "default": model == provider.default_model}
+ {
+ "model": model,
+ "default": model == provider.default_model,
+ "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []),
+ "image": model in getattr(provider, "image_models", []),
+ }
for model in provider.get_models()
]
return []
@@ -65,7 +70,7 @@ class Api:
"url": parent.url,
"label": parent.label if hasattr(parent, "label") else None,
"image_model": model,
- "vision_model": parent.default_vision_model if hasattr(parent, "default_vision_model") else None
+ "vision_model": getattr(parent, "default_vision_model", None)
})
index.append(parent.__name__)
elif hasattr(provider, "default_vision_model") and provider.__name__ not in index:
@@ -82,13 +87,11 @@ class Api:
@staticmethod
def get_providers() -> list[str]:
return {
- provider.__name__: (
- provider.label if hasattr(provider, "label") else provider.__name__
- ) + (
- " (WebDriver)" if "webdriver" in provider.get_parameters() else ""
- ) + (
- " (Auth)" if provider.needs_auth else ""
- )
+ provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
+ + (" (Image Generation)" if hasattr(provider, "image_models") else "")
+ + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
+ + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
+ + (" (Auth)" if provider.needs_auth else "")
for provider in __providers__
if provider.working
}