diff options
Diffstat (limited to 'g4f/Provider/HuggingChat.py')
-rw-r--r-- | g4f/Provider/HuggingChat.py | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/g4f/Provider/HuggingChat.py b/g4f/Provider/HuggingChat.py index 668ce4b1..527f0a56 100644 --- a/g4f/Provider/HuggingChat.py +++ b/g4f/Provider/HuggingChat.py @@ -6,12 +6,14 @@ from aiohttp import ClientSession, BaseConnector from ..typing import AsyncResult, Messages from ..requests.raise_for_status import raise_for_status +from ..providers.conversation import BaseConversation from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from .helper import format_prompt, get_connector +from .helper import format_prompt, get_connector, get_cookies class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): url = "https://huggingface.co/chat" working = True + needs_auth = True default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1" models = [ "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", @@ -22,9 +24,6 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): 'mistralai/Mistral-7B-Instruct-v0.2', 'meta-llama/Meta-Llama-3-70B-Instruct' ] - model_aliases = { - "openchat/openchat_3.5": "openchat/openchat-3.5-0106", - } @classmethod def get_models(cls): @@ -45,9 +44,16 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): connector: BaseConnector = None, web_search: bool = False, cookies: dict = None, + conversation: Conversation = None, + return_conversation: bool = False, + delete_conversation: bool = True, **kwargs ) -> AsyncResult: options = {"model": cls.get_model(model)} + if cookies is None: + cookies = get_cookies("huggingface.co", False) + if return_conversation: + delete_conversation = False system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"]) if system_prompt: @@ -61,9 +67,14 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): headers=headers, connector=get_connector(connector, proxy) ) as session: - async with session.post(f"{cls.url}/conversation", json=options) as response: - await raise_for_status(response) - conversation_id = (await response.json())["conversationId"] + if conversation is None: + async with session.post(f"{cls.url}/conversation", json=options) as response: + await raise_for_status(response) + conversation_id = (await response.json())["conversationId"] + if return_conversation: + yield Conversation(conversation_id) + else: + conversation_id = conversation.conversation_id async with session.get(f"{cls.url}/conversation/{conversation_id}/__data.json") as response: await raise_for_status(response) data: list = (await response.json())["nodes"][1]["data"] @@ -72,7 +83,7 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): message_id: str = data[message_keys["id"]] options = { "id": message_id, - "inputs": format_prompt(messages), + "inputs": format_prompt(messages) if conversation is None else messages[-1]["content"], "is_continue": False, "is_retry": False, "web_search": web_search @@ -92,5 +103,10 @@ class HuggingChat(AsyncGeneratorProvider, ProviderModelMixin): yield token elif line["type"] == "finalAnswer": break - async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response: - await raise_for_status(response) + if delete_conversation: + async with session.delete(f"{cls.url}/conversation/{conversation_id}") as response: + await raise_for_status(response) + +class Conversation(BaseConversation): + def __init__(self, conversation_id: str) -> None: + self.conversation_id = conversation_id
\ No newline at end of file |