diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-03-28 17:17:59 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-28 17:17:59 +0100 |
commit | 64e07b7fbf810176d66506786a946a3122ea7fc4 (patch) | |
tree | 1cf10ab4f117583fdb4a98712c18052e5a42cdf2 /g4f | |
parent | Merge pull request #1758 from Zero6992/main (diff) | |
parent | Fix history support for OpenaiChat (diff) | |
download | gpt4free-0.2.7.2.tar gpt4free-0.2.7.2.tar.gz gpt4free-0.2.7.2.tar.bz2 gpt4free-0.2.7.2.tar.lz gpt4free-0.2.7.2.tar.xz gpt4free-0.2.7.2.tar.zst gpt4free-0.2.7.2.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 11 | ||||
-rw-r--r-- | g4f/gui/server/api.py | 15 | ||||
-rw-r--r-- | g4f/gui/server/backend.py | 2 |
3 files changed, 15 insertions, 13 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 72f9f224..396d73dd 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -389,19 +389,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): print(f"{e.__class__.__name__}: {e}") model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha") - fields = Conversation() if conversation is None else copy(conversation) + fields = Conversation(conversation_id, parent_id) if conversation is None else copy(conversation) fields.finish_reason = None while fields.finish_reason is None: - conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id - parent_id = parent_id if fields.message_id is None else fields.message_id websocket_request_id = str(uuid.uuid4()) data = { "action": action, "conversation_mode": {"kind": "primary_assistant"}, "force_paragen": False, "force_rate_limit": False, - "conversation_id": conversation_id, - "parent_message_id": parent_id, + "conversation_id": fields.conversation_id, + "parent_message_id": fields.message_id, "model": model, "history_and_training_disabled": history_disabled and not auto_continue and not return_conversation, "websocket_request_id": websocket_request_id @@ -425,6 +423,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await raise_for_status(response) async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields): if return_conversation: + history_disabled = False return_conversation = False yield fields yield chunk @@ -432,7 +431,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): break action = "continue" await asyncio.sleep(5) - if history_disabled and auto_continue and not return_conversation: + if history_disabled and auto_continue: await cls.delete_conversation(session, cls._headers, fields.conversation_id) @staticmethod diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index da934d57..b4e2b3d4 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -41,7 +41,7 @@ from g4f.providers.base_provider import ProviderModelMixin from g4f.Provider.bing.create_images import patch_provider from g4f.providers.conversation import BaseConversation -conversations: dict[str, BaseConversation] = {} +conversations: dict[dict[str, BaseConversation]] = {} class Api(): @@ -106,7 +106,8 @@ class Api(): kwargs["image"] = open(self.image, "rb") for message in self._create_response_stream( self._prepare_conversation_kwargs(options, kwargs), - options.get("conversation_id") + options.get("conversation_id"), + options.get('provider') ): if not window.evaluate_js(f"if (!this.abort) this.add_message_chunk({json.dumps(message)}); !this.abort && !this.error;"): break @@ -193,8 +194,8 @@ class Api(): messages[-1]["content"] = get_search_message(messages[-1]["content"]) conversation_id = json_data.get("conversation_id") - if conversation_id and conversation_id in conversations: - kwargs["conversation"] = conversations[conversation_id] + if conversation_id and provider in conversations and conversation_id in conversations[provider]: + kwargs["conversation"] = conversations[provider][conversation_id] model = json_data.get('model') model = model if model else models.default @@ -211,7 +212,7 @@ class Api(): **kwargs } - def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator: + def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator: """ Creates and returns a streaming response for the conversation. @@ -231,7 +232,9 @@ class Api(): first = False yield self._format_json("provider", get_last_provider(True)) if isinstance(chunk, BaseConversation): - conversations[conversation_id] = chunk + if provider not in conversations: + conversations[provider] = {} + conversations[provider][conversation_id] = chunk yield self._format_json("conversation", conversation_id) elif isinstance(chunk, Exception): logging.exception(chunk) diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index fb8404d4..d30b97d9 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -85,7 +85,7 @@ class Backend_Api(Api): kwargs = self._prepare_conversation_kwargs(json_data, kwargs) return self.app.response_class( - self._create_response_stream(kwargs, json_data.get("conversation_id")), + self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")), mimetype='text/event-stream' ) |