| author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-03-11 02:41:59 +0100 |
|---|---|---|
| committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-03-11 02:41:59 +0100 |
| commit | ec51e9c76433b9e9498ade1dbe5de2268ad84218 (patch) | |
| tree | 2876e373144f025949477601e89b8d478a49fafc /g4f/Provider/HuggingFace.py | |
| parent | Add word count from iG8R (diff) | |
Diffstat

| -rw-r--r-- | g4f/Provider/HuggingFace.py | 75 |

1 file changed, 75 insertions, 0 deletions
```diff
diff --git a/g4f/Provider/HuggingFace.py b/g4f/Provider/HuggingFace.py
new file mode 100644
index 00000000..a73411ce
--- /dev/null
+++ b/g4f/Provider/HuggingFace.py
@@ -0,0 +1,75 @@
+from __future__ import annotations
+
+import json
+from aiohttp import ClientSession, BaseConnector
+
+from ..typing import AsyncResult, Messages
+from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
+from .helper import get_connector
+from ..errors import RateLimitError, ModelNotFoundError
+
+class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
+    url = "https://huggingface.co/chat"
+    working = True
+    supports_message_history = True
+    default_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+
+    @classmethod
+    async def create_async_generator(
+        cls,
+        model: str,
+        messages: Messages,
+        stream: bool = True,
+        proxy: str = None,
+        connector: BaseConnector = None,
+        api_base: str = "https://api-inference.huggingface.co",
+        api_key: str = None,
+        max_new_tokens: int = 1024,
+        temperature: float = 0.7,
+        **kwargs
+    ) -> AsyncResult:
+        model = cls.get_model(model)
+        headers = {}
+        if api_key is not None:
+            headers["Authorization"] = f"Bearer {api_key}"
+        params = {
+            "return_full_text": False,
+            "max_new_tokens": max_new_tokens,
+            "temperature": temperature,
+            **kwargs
+        }
+        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
+        async with ClientSession(
+            headers=headers,
+            connector=get_connector(connector, proxy)
+        ) as session:
+            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
+                if response.status == 429:
+                    raise RateLimitError("Rate limit reached. Set an api_key")
+                elif response.status == 404:
+                    raise ModelNotFoundError(f"Model is not supported: {model}")
+                elif response.status != 200:
+                    raise RuntimeError(f"Response {response.status}: {await response.text()}")
+                if stream:
+                    first = True
+                    async for line in response.content:
+                        if line.startswith(b"data:"):
+                            data = json.loads(line[5:])
+                            if not data["token"]["special"]:
+                                chunk = data["token"]["text"]
+                                if first:
+                                    first = False
+                                    chunk = chunk.lstrip()
+                                yield chunk
+                else:
+                    yield (await response.json())[0]["generated_text"].strip()
+
+def format_prompt(messages: Messages) -> str:
+    system_messages = [message["content"] for message in messages if message["role"] == "system"]
+    question = " ".join([messages[-1]["content"], *system_messages])
+    history = "".join([
+        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
+        for idx, message in enumerate(messages)
+        if message["role"] == "assistant"
+    ])
+    return f"{history}<s>[INST] {question} [/INST]"
\ No newline at end of file
```
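A minimal usage sketch of the new provider, not part of the commit: it assumes the package is installed as `g4f` and that `HuggingFace` is importable from the new module path. Because `create_async_generator` contains `yield`, calling it returns an async generator that can be iterated directly with `async for`:

```python
import asyncio

from g4f.Provider.HuggingFace import HuggingFace

async def main() -> None:
    messages = [{"role": "user", "content": "Hello, who are you?"}]
    # Iterate the streamed text chunks as they arrive. Passing an
    # api_key is optional but avoids the 429 RateLimitError path.
    async for chunk in HuggingFace.create_async_generator(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        messages=messages,
        max_new_tokens=256,
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())
```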
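And a hedged illustration of the Mixtral-style instruction prompt that the module-level `format_prompt` helper assembles; the conversation below is invented, and the import assumes the same installed-package layout as above:

```python
from g4f.Provider.HuggingFace import format_prompt

# Invented conversation showing the <s>[INST] ... [/INST]</s> layout:
# each assistant turn becomes a closed history pair, and system content
# is appended after the final user question.
messages = [
    {"role": "system", "content": "You are concise."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2+2?"},
]
print(format_prompt(messages))
# <s>[INST]Hi [/INST] Hello!</s><s>[INST] What is 2+2? You are concise. [/INST]
```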