From 58fa409eefcc8ae0233967dc807b046ad77bf6fa Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Wed, 20 Nov 2024 02:34:47 +0100 Subject: Add Cerebras and HuggingFace2 provider, Fix RubiksAI provider Add support for image generation in Copilot provider --- g4f/gui/client/index.html | 8 +++- g4f/gui/client/static/css/style.css | 4 +- g4f/gui/client/static/js/chat.v1.js | 87 ++++++++++++++++++++----------------- g4f/gui/server/api.py | 5 ++- g4f/gui/server/backend.py | 3 +- 5 files changed, 60 insertions(+), 47 deletions(-) (limited to 'g4f/gui') diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html index 3a2197de..48214093 100644 --- a/g4f/gui/client/index.html +++ b/g4f/gui/client/index.html @@ -128,6 +128,10 @@ +
+ + +
@@ -142,7 +146,7 @@
- +
@@ -192,7 +196,7 @@
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css index e435094f..76399703 100644 --- a/g4f/gui/client/static/css/style.css +++ b/g4f/gui/client/static/css/style.css @@ -512,9 +512,7 @@ body { @media only screen and (min-width: 40em) { .stop_generating { - left: 50%; - transform: translateX(-50%); - right: auto; + right: 4px; } .toolbar .regenerate span { display: block; diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js index 51bf8b81..a3e94ee2 100644 --- a/g4f/gui/client/static/js/chat.v1.js +++ b/g4f/gui/client/static/js/chat.v1.js @@ -215,7 +215,6 @@ const register_message_buttons = async () => { const message_el = el.parentElement.parentElement.parentElement; el.classList.add("clicked"); setTimeout(() => el.classList.remove("clicked"), 1000); - await hide_message(window.conversation_id, message_el.dataset.index); await ask_gpt(message_el.dataset.index, get_message_id()); }) } @@ -317,6 +316,7 @@ async function remove_cancel_button() { regenerate.addEventListener("click", async () => { regenerate.classList.add("regenerate-hidden"); + setTimeout(()=>regenerate.classList.remove("regenerate-hidden"), 3000); stop_generating.classList.remove("stop_generating-hidden"); await hide_message(window.conversation_id); await ask_gpt(-1, get_message_id()); @@ -383,12 +383,12 @@ const prepare_messages = (messages, message_index = -1) => { return new_messages; } -async function add_message_chunk(message, message_index) { - content_map = content_storage[message_index]; +async function add_message_chunk(message, message_id) { + content_map = content_storage[message_id]; if (message.type == "conversation") { console.info("Conversation used:", message.conversation) } else if (message.type == "provider") { - provider_storage[message_index] = message.provider; + provider_storage[message_id] = message.provider; content_map.content.querySelector('.provider').innerHTML = ` ${message.provider.label ? message.provider.label : message.provider.name} @@ -398,7 +398,7 @@ async function add_message_chunk(message, message_index) { } else if (message.type == "message") { console.error(message.message) } else if (message.type == "error") { - error_storage[message_index] = message.error + error_storage[message_id] = message.error console.error(message.error); content_map.inner.innerHTML += `

An error occured: ${message.error}

`; let p = document.createElement("p"); @@ -407,8 +407,8 @@ async function add_message_chunk(message, message_index) { } else if (message.type == "preview") { content_map.inner.innerHTML = markdown_render(message.preview); } else if (message.type == "content") { - message_storage[message_index] += message.content; - html = markdown_render(message_storage[message_index]); + message_storage[message_id] += message.content; + html = markdown_render(message_storage[message_id]); let lastElement, lastIndex = null; for (element of ['

', '', '

\n\n', '\n', '\n']) { const index = html.lastIndexOf(element) @@ -421,7 +421,7 @@ async function add_message_chunk(message, message_index) { html = html.substring(0, lastIndex) + '' + lastElement; } content_map.inner.innerHTML = html; - content_map.count.innerText = count_words_and_tokens(message_storage[message_index], provider_storage[message_index]?.model); + content_map.count.innerText = count_words_and_tokens(message_storage[message_id], provider_storage[message_id]?.model); highlight(content_map.inner); } else if (message.type == "log") { let p = document.createElement("p"); @@ -453,7 +453,7 @@ const ask_gpt = async (message_index = -1, message_id) => { let total_messages = messages.length; messages = prepare_messages(messages, message_index); message_index = total_messages - message_storage[message_index] = ""; + message_storage[message_id] = ""; stop_generating.classList.remove(".stop_generating-hidden"); message_box.scrollTop = message_box.scrollHeight; @@ -477,10 +477,10 @@ const ask_gpt = async (message_index = -1, message_id) => {
`; - controller_storage[message_index] = new AbortController(); + controller_storage[message_id] = new AbortController(); let content_el = document.getElementById(`gpt_${message_id}`) - let content_map = content_storage[message_index] = { + let content_map = content_storage[message_id] = { content: content_el, inner: content_el.querySelector('.content_inner'), count: content_el.querySelector('.count'), @@ -492,12 +492,7 @@ const ask_gpt = async (message_index = -1, message_id) => { const file = input && input.files.length > 0 ? input.files[0] : null; const provider = providerSelect.options[providerSelect.selectedIndex].value; const auto_continue = document.getElementById("auto_continue")?.checked; - let api_key = null; - if (provider) { - api_key = document.getElementById(`${provider}-api_key`)?.value || null; - if (api_key == null) - api_key = document.querySelector(`.${provider}-api_key`)?.value || null; - } + let api_key = get_api_key_by_provider(provider); await api("conversation", { id: message_id, conversation_id: window.conversation_id, @@ -506,10 +501,10 @@ const ask_gpt = async (message_index = -1, message_id) => { provider: provider, messages: messages, auto_continue: auto_continue, - api_key: api_key - }, file, message_index); - if (!error_storage[message_index]) { - html = markdown_render(message_storage[message_index]); + api_key: api_key, + }, file, message_id); + if (!error_storage[message_id]) { + html = markdown_render(message_storage[message_id]); content_map.inner.innerHTML = html; highlight(content_map.inner); @@ -520,14 +515,14 @@ const ask_gpt = async (message_index = -1, message_id) => { } catch (e) { console.error(e); if (e.name != "AbortError") { - error_storage[message_index] = true; + error_storage[message_id] = true; content_map.inner.innerHTML += `

An error occured: ${e}

`; } } - delete controller_storage[message_index]; - if (!error_storage[message_index] && message_storage[message_index]) { - const message_provider = message_index in provider_storage ? provider_storage[message_index] : null; - await add_message(window.conversation_id, "assistant", message_storage[message_index], message_provider); + delete controller_storage[message_id]; + if (!error_storage[message_id] && message_storage[message_id]) { + const message_provider = message_id in provider_storage ? provider_storage[message_id] : null; + await add_message(window.conversation_id, "assistant", message_storage[message_id], message_provider); await safe_load_conversation(window.conversation_id); } else { let cursorDiv = message_box.querySelector(".cursor"); @@ -1156,7 +1151,7 @@ async function on_api() { evt.preventDefault(); console.log("pressed enter"); prompt_lock = true; - setTimeout(()=>prompt_lock=false, 3); + setTimeout(()=>prompt_lock=false, 3000); await handle_ask(); } else { messageInput.style.removeProperty("height"); @@ -1167,7 +1162,7 @@ async function on_api() { console.log("clicked send"); if (prompt_lock) return; prompt_lock = true; - setTimeout(()=>prompt_lock=false, 3); + setTimeout(()=>prompt_lock=false, 3000); await handle_ask(); }); messageInput.focus(); @@ -1189,8 +1184,8 @@ async function on_api() { providerSelect.appendChild(option); }) - await load_provider_models(appStorage.getItem("provider")); await load_settings_storage() + await load_provider_models(appStorage.getItem("provider")); const hide_systemPrompt = document.getElementById("hide-systemPrompt") const slide_systemPrompt_icon = document.querySelector(".slide-systemPrompt i"); @@ -1316,7 +1311,7 @@ function get_selected_model() { } } -async function api(ressource, args=null, file=null, message_index=null) { +async function api(ressource, args=null, file=null, message_id=null) { if (window?.pywebview) { if (args !== null) { if (ressource == "models") { @@ -1326,15 +1321,19 @@ async function api(ressource, args=null, file=null, message_index=null) { } return pywebview.api[`get_${ressource}`](); } + let api_key; if (ressource == "models" && args) { + api_key = get_api_key_by_provider(args); ressource = `${ressource}/${args}`; } const url = `/backend-api/v2/${ressource}`; + const headers = {}; + if (api_key) { + headers.authorization = `Bearer ${api_key}`; + } if (ressource == "conversation") { let body = JSON.stringify(args); - const headers = { - accept: 'text/event-stream' - } + headers.accept = 'text/event-stream'; if (file !== null) { const formData = new FormData(); formData.append('file', file); @@ -1345,17 +1344,17 @@ async function api(ressource, args=null, file=null, message_index=null) { } response = await fetch(url, { method: 'POST', - signal: controller_storage[message_index].signal, + signal: controller_storage[message_id].signal, headers: headers, - body: body + body: body, }); - return read_response(response, message_index); + return read_response(response, message_id); } - response = await fetch(url); + response = await fetch(url, {headers: headers}); return await response.json(); } -async function read_response(response, message_index) { +async function read_response(response, message_id) { const reader = response.body.pipeThrough(new TextDecoderStream()).getReader(); let buffer = "" while (true) { @@ -1368,7 +1367,7 @@ async function read_response(response, message_index) { continue; } try { - add_message_chunk(JSON.parse(buffer + line), message_index); + add_message_chunk(JSON.parse(buffer + line), message_id); buffer = ""; } catch { buffer += line @@ -1377,6 +1376,16 @@ async function read_response(response, message_index) { } } +function get_api_key_by_provider(provider) { + let api_key = null; + if (provider) { + api_key = document.getElementById(`${provider}-api_key`)?.value || null; + if (api_key == null) + api_key = document.querySelector(`.${provider}-api_key`)?.value || null; + } + return api_key; +} + async function load_provider_models(providerIndex=null) { if (!providerIndex) { providerIndex = providerSelect.selectedIndex; diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 6be77d09..2d871ff3 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -38,10 +38,11 @@ class Api: return models._all_models @staticmethod - def get_provider_models(provider: str) -> list[dict]: + def get_provider_models(provider: str, api_key: str = None) -> list[dict]: if provider in __map__: provider: ProviderType = __map__[provider] if issubclass(provider, ProviderModelMixin): + models = provider.get_models() if api_key is None else provider.get_models(api_key=api_key) return [ { "model": model, @@ -49,7 +50,7 @@ class Api: "vision": getattr(provider, "default_vision_model", None) == model or model in getattr(provider, "vision_models", []), "image": model in getattr(provider, "image_models", []), } - for model in provider.get_models() + for model in models ] return [] diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index dc1b1080..020e49ef 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -94,7 +94,8 @@ class Backend_Api(Api): ) def get_provider_models(self, provider: str): - models = super().get_provider_models(provider) + api_key = None if request.authorization is None else request.authorization.token + models = super().get_provider_models(provider, api_key) if models is None: return 404, "Provider not found" return models -- cgit v1.2.3