summaryrefslogblamecommitdiffstats
path: root/g4f/Provider/needs_auth/Cerebras.py
blob: 0f94c476a0c9b460eaf6bce24cd35ae6d9f1444b (plain) (tree)
































































                                                                                                 
from __future__ import annotations

import requests
from aiohttp import ClientSession

from .OpenaiAPI import OpenaiAPI
from ...typing import AsyncResult, Messages, Cookies
from ...requests.raise_for_status import raise_for_status
from ...cookies import get_cookies

class Cerebras(OpenaiAPI):
    """Provider for Cerebras Inference, built on the OpenAI-compatible base.

    Authentication uses either an explicit ``api_key`` or, when none is
    given, the ``demoApiKey`` exposed by the Cerebras web session that is
    reached with the user's ``.cerebras.ai`` cookies.
    """

    label = "Cerebras Inference"
    url = "https://inference.cerebras.ai/"
    working = True
    default_model = "llama3.1-70b"
    # Used whenever the live model list cannot be retrieved.
    fallback_models = [
        "llama3.1-70b",
        "llama3.1-8b",
    ]
    model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}

    @classmethod
    def get_models(cls, api_key: str = None):
        """Return the available model names, caching them on the class.

        The list is fetched once from ``https://api.cerebras.ai/v1/models``.
        Any failure (network error, bad status, unexpected payload) is a
        deliberate best-effort fallback to ``fallback_models``.

        Args:
            api_key: Optional API key sent as a bearer token.

        Returns:
            The list of model identifiers.
        """
        if not cls.models:
            try:
                headers = {}
                if api_key:
                    # Plain f-string placeholder; a stray "$" (JS template
                    # syntax) would corrupt the token and break auth.
                    headers["authorization"] = f"Bearer {api_key}"
                response = requests.get("https://api.cerebras.ai/v1/models", headers=headers)
                raise_for_status(response)
                data = response.json()
                # The endpoint is documented as OpenAI-shaped
                # ({"data": [{"id": ...}]}); the {"models": [{"model": ...}]}
                # shape this code previously expected is still accepted.
                entries = data.get("data") or data.get("models") or []
                names = [entry.get("id") or entry.get("model") for entry in entries]
                cls.models = [name for name in names if name] or cls.fallback_models
            except Exception:
                cls.models = cls.fallback_models
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = "https://api.cerebras.ai/v1",
        api_key: str = None,
        cookies: Cookies = None,
        **kwargs
    ) -> AsyncResult:
        """Stream a completion from Cerebras.

        Args:
            model: Model name forwarded to the parent implementation.
            messages: Conversation messages forwarded to the parent.
            api_base: Base URL of the API.
            api_key: API key. When given it is used as-is; when ``None`` a
                demo key is looked up from the Cerebras web session.
            cookies: Cookies for the web session lookup; when both
                ``api_key`` and ``cookies`` are ``None`` they are loaded via
                ``get_cookies(".cerebras.ai")``.
            **kwargs: Extra options forwarded to the parent implementation.

        Yields:
            Chunks produced by the parent ``create_async_generator``.
        """
        import inspect

        # Only fall back to the cookie-based demo key when the caller did not
        # supply a key; an explicit api_key must never be overwritten.
        if api_key is None:
            if cookies is None:
                cookies = get_cookies(".cerebras.ai")
            async with ClientSession(cookies=cookies) as session:
                async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
                    # raise_for_status is shared with the synchronous
                    # ``requests`` path in get_models; if it hands back an
                    # awaitable for this async response it must be awaited,
                    # otherwise the status check never actually runs.
                    status_check = raise_for_status(response)
                    if inspect.isawaitable(status_check):
                        await status_check
                    data = await response.json()
                    if data:
                        # "user" may be present but null; guard before .get().
                        api_key = (data.get("user") or {}).get("demoApiKey")
        async for chunk in super().create_async_generator(
            model, messages,
            api_base=api_base,
            impersonate="chrome",
            api_key=api_key,
            headers={
                "User-Agent": "ex/JS 1.5.0",
            },
            **kwargs
        ):
            yield chunk