From 03fd5ac99a828bd2637cf5be43a98157113527fb Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Thu, 28 Mar 2024 11:36:25 +0100 Subject: Fix history support for OpenaiChat --- g4f/gui/server/api.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'g4f/gui/server/api.py') diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index da934d57..b4e2b3d4 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -41,7 +41,7 @@ from g4f.providers.base_provider import ProviderModelMixin from g4f.Provider.bing.create_images import patch_provider from g4f.providers.conversation import BaseConversation -conversations: dict[str, BaseConversation] = {} +conversations: dict[dict[str, BaseConversation]] = {} class Api(): @@ -106,7 +106,8 @@ class Api(): kwargs["image"] = open(self.image, "rb") for message in self._create_response_stream( self._prepare_conversation_kwargs(options, kwargs), - options.get("conversation_id") + options.get("conversation_id"), + options.get('provider') ): if not window.evaluate_js(f"if (!this.abort) this.add_message_chunk({json.dumps(message)}); !this.abort && !this.error;"): break @@ -193,8 +194,8 @@ class Api(): messages[-1]["content"] = get_search_message(messages[-1]["content"]) conversation_id = json_data.get("conversation_id") - if conversation_id and conversation_id in conversations: - kwargs["conversation"] = conversations[conversation_id] + if conversation_id and provider in conversations and conversation_id in conversations[provider]: + kwargs["conversation"] = conversations[provider][conversation_id] model = json_data.get('model') model = model if model else models.default @@ -211,7 +212,7 @@ class Api(): **kwargs } - def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator: + def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator: """ Creates and returns a streaming response for the conversation. @@ -231,7 +232,9 @@ class Api(): first = False yield self._format_json("provider", get_last_provider(True)) if isinstance(chunk, BaseConversation): - conversations[conversation_id] = chunk + if provider not in conversations: + conversations[provider] = {} + conversations[provider][conversation_id] = chunk yield self._format_json("conversation", conversation_id) elif isinstance(chunk, Exception): logging.exception(chunk) -- cgit v1.2.3