diff options
Diffstat (limited to 'g4f/api/__init__.py')
-rw-r--r-- | g4f/api/__init__.py | 103 |
1 files changed, 85 insertions, 18 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 292164fa..628d7512 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -5,8 +5,10 @@ import json import uvicorn import secrets import os +import shutil -from fastapi import FastAPI, Response, Request +import os.path +from fastapi import FastAPI, Response, Request, UploadFile from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse from fastapi.exceptions import RequestValidationError from fastapi.security import APIKeyHeader @@ -16,16 +18,17 @@ from fastapi.encoders import jsonable_encoder from fastapi.middleware.cors import CORSMiddleware from starlette.responses import FileResponse from pydantic import BaseModel -from typing import Union, Optional +from typing import Union, Optional, List import g4f import g4f.debug -from g4f.client import AsyncClient, ChatCompletion +from g4f.client import AsyncClient, ChatCompletion, convert_to_provider from g4f.providers.response import BaseConversation from g4f.client.helper import filter_none from g4f.image import is_accepted_format, images_dir from g4f.typing import Messages -from g4f.cookies import read_cookie_files +from g4f.errors import ProviderNotFoundError +from g4f.cookies import read_cookie_files, get_cookies_dir from g4f.Provider import ProviderType, ProviderUtils, __providers__ logger = logging.getLogger(__name__) @@ -78,6 +81,18 @@ class ImageGenerationConfig(BaseModel): api_key: Optional[str] = None proxy: Optional[str] = None +class ProviderResponseModel(BaseModel): + id: str + object: str = "provider" + created: int + owned_by: Optional[str] + +class ModelResponseModel(BaseModel): + id: str + object: str = "model" + created: int + owned_by: Optional[str] + class AppConfig: ignored_providers: Optional[list[str]] = None g4f_api_key: Optional[str] = None @@ -109,7 +124,7 @@ class Api: def register_authorization(self): @self.app.middleware("http") async def authorization(request: Request, call_next): - if self.g4f_api_key and request.url.path in ["/v1/chat/completions", "/v1/completions", "/v1/images/generate"]: + if self.g4f_api_key and request.url.path not in ("/", "/v1"): try: user_g4f_api_key = await self.get_g4f_api_key(request) except HTTPException as e: @@ -123,9 +138,7 @@ class Api: status_code=HTTP_403_FORBIDDEN, content=jsonable_encoder({"detail": "Invalid G4F API key"}), ) - - response = await call_next(request) - return response + return await call_next(request) def register_validation_exception_handler(self): @self.app.exception_handler(RequestValidationError) @@ -158,22 +171,21 @@ class Api: '<a href="/docs">/docs</a>') @self.app.get("/v1/models") - async def models(): + async def models() -> list[ModelResponseModel]: model_list = dict( (model, g4f.models.ModelUtils.convert[model]) for model in g4f.Model.__all__() ) - model_list = [{ + return [{ 'id': model_id, 'object': 'model', 'created': 0, 'owned_by': model.base_provider } for model_id, model in model_list.items()] - return JSONResponse(model_list) @self.app.get("/v1/models/{model_name}") async def model_info(model_name: str): - try: + if model_name in g4f.models.ModelUtils.convert: model_info = g4f.models.ModelUtils.convert[model_name] return JSONResponse({ 'id': model_name, @@ -181,8 +193,7 @@ class Api: 'created': 0, 'owned_by': model_info.base_provider }) - except: - return JSONResponse({"error": "The model does not exist."}) + return JSONResponse({"error": "The model does not exist."}, 404) @self.app.post("/v1/chat/completions") async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None): @@ -277,12 +288,68 @@ class Api: logger.exception(e) return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json") - @self.app.post("/v1/completions") - async def completions(): - return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + @self.app.get("/v1/providers") + async def providers() -> list[ProviderResponseModel]: + return [{ + 'id': provider.__name__, + 'object': 'provider', + 'created': 0, + 'url': provider.url, + 'label': getattr(provider, "label", None), + } for provider in __providers__ if provider.working] + + @self.app.get("/v1/providers/{provider}") + async def providers_info(provider: str) -> ProviderResponseModel: + if provider not in ProviderUtils.convert: + return JSONResponse({"error": "The provider does not exist."}, 404) + provider: ProviderType = ProviderUtils.convert[provider] + def safe_get_models(provider: ProviderType) -> list[str]: + try: + return provider.get_models() if hasattr(provider, "get_models") else [] + except: + return [] + return { + 'id': provider.__name__, + 'object': 'provider', + 'created': 0, + 'url': provider.url, + 'label': getattr(provider, "label", None), + 'models': safe_get_models(provider), + 'image_models': getattr(provider, "image_models", []) or [], + 'vision_models': [model for model in [getattr(provider, "default_vision_model", None)] if model], + 'params': [*provider.get_parameters()] if hasattr(provider, "get_parameters") else [] + } + + @self.app.post("/v1/upload_cookies") + def upload_cookies(files: List[UploadFile]): + response_data = [] + for file in files: + try: + if file and file.filename.endswith(".json") or file.filename.endswith(".har"): + filename = os.path.basename(file.filename) + with open(os.path.join(get_cookies_dir(), filename), 'wb') as f: + shutil.copyfileobj(file.file, f) + response_data.append({"filename": filename}) + finally: + file.file.close() + return response_data + + @self.app.get("/v1/synthesize/{provider}") + async def synthesize(request: Request, provider: str): + try: + provider_handler = convert_to_provider(provider) + except ProviderNotFoundError: + return Response("Provider not found", 404) + if not hasattr(provider_handler, "synthesize"): + return Response("Provider doesn't support synthesize", 500) + if len(request.query_params) == 0: + return Response("Missing query params", 500) + response_data = provider_handler.synthesize({**request.query_params}) + content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream") + return StreamingResponse(response_data, media_type=content_type) @self.app.get("/images/{filename}") - async def get_image(filename): + async def get_image(filename) -> FileResponse: target = os.path.join(images_dir, filename) if not os.path.isfile(target): |