diff options
author | Heiner Lohaus <heiner@lohaus.eu> | 2023-09-18 07:15:43 +0200 |
---|---|---|
committer | Heiner Lohaus <heiner@lohaus.eu> | 2023-09-18 07:15:43 +0200 |
commit | 3b8dfff974618499b177c5a724638919b93b702e (patch) | |
tree | e48ce2533f3cff7b1c45a3e789f63da01e541ac7 | |
parent | Add GptGo Provider, Fix AItianhu Provider (diff) | |
download | gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar.gz gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar.bz2 gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar.lz gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar.xz gpt4free-3b8dfff974618499b177c5a724638919b93b702e.tar.zst gpt4free-3b8dfff974618499b177c5a724638919b93b702e.zip |
-rw-r--r-- | g4f/Provider/Ylokh.py | 5 | ||||
-rw-r--r-- | g4f/Provider/base_provider.py | 81 | ||||
-rw-r--r-- | testing/test_needs_auth.py | 10 |
3 files changed, 49 insertions, 47 deletions
diff --git a/g4f/Provider/Ylokh.py b/g4f/Provider/Ylokh.py index 1986b6d3..c7b92089 100644 --- a/g4f/Provider/Ylokh.py +++ b/g4f/Provider/Ylokh.py @@ -51,7 +51,9 @@ class Ylokh(AsyncGeneratorProvider): if stream: async for line in response.content: line = line.decode() - if line.startswith("data: ") and not line.startswith("data: [DONE]"): + if line.startswith("data: "): + if line.startswith("data: [DONE]"): + break line = json.loads(line[6:-1]) content = line["choices"][0]["delta"].get("content") if content: @@ -71,6 +73,7 @@ class Ylokh(AsyncGeneratorProvider): ("stream", "bool"), ("proxy", "str"), ("temperature", "float"), + ("top_p", "float"), ] param = ", ".join([": ".join(p) for p in params]) return f"g4f.provider.{cls.__name__} supports: ({param})"
\ No newline at end of file diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 0f499c8c..79f8f617 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -35,30 +35,6 @@ class BaseProvider(ABC): ] param = ", ".join([": ".join(p) for p in params]) return f"g4f.provider.{cls.__name__} supports: ({param})" - - -_cookies = {} - -def get_cookies(cookie_domain: str) -> dict: - if cookie_domain not in _cookies: - _cookies[cookie_domain] = {} - try: - for cookie in browser_cookie3.load(cookie_domain): - _cookies[cookie_domain][cookie.name] = cookie.value - except: - pass - return _cookies[cookie_domain] - - -def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): - if add_special_tokens or len(messages) > 1: - formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] - ) - return f"{formatted}\nAssistant:" - else: - return messages.pop()["content"] - class AsyncProvider(BaseProvider): @@ -67,8 +43,9 @@ class AsyncProvider(BaseProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = False, **kwargs: Any) -> CreateResult: - + stream: bool = False, + **kwargs + ) -> CreateResult: yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -90,7 +67,20 @@ class AsyncGeneratorProvider(AsyncProvider): stream: bool = True, **kwargs ) -> CreateResult: - yield from run_generator(cls.create_async_generator(model, messages, stream=stream, **kwargs)) + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) + gen = generator.__aiter__() + while True: + try: + yield loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + break + finally: + asyncio.set_event_loop(None) + loop.close() + @classmethod async def create_async( @@ -99,27 +89,36 @@ class AsyncGeneratorProvider(AsyncProvider): messages: list[dict[str, str]], **kwargs ) -> str: - chunks = [chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)] - if chunks: - return "".join(chunks) + return "".join([chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)]) @staticmethod @abstractmethod def create_async_generator( - model: str, - messages: list[dict[str, str]], - **kwargs - ) -> AsyncGenerator: + model: str, + messages: list[dict[str, str]], + **kwargs + ) -> AsyncGenerator: raise NotImplementedError() -def run_generator(generator: AsyncGenerator[Union[Any, str], Any]): - loop = asyncio.new_event_loop() - gen = generator.__aiter__() +_cookies = {} - while True: +def get_cookies(cookie_domain: str) -> dict: + if cookie_domain not in _cookies: + _cookies[cookie_domain] = {} try: - yield loop.run_until_complete(gen.__anext__()) + for cookie in browser_cookie3.load(cookie_domain): + _cookies[cookie_domain][cookie.name] = cookie.value + except: + pass + return _cookies[cookie_domain] - except StopAsyncIteration: - break + +def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): + if add_special_tokens or len(messages) > 1: + formatted = "\n".join( + ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] + ) + return f"{formatted}\nAssistant:" + else: + return messages[0]["content"]
\ No newline at end of file diff --git a/testing/test_needs_auth.py b/testing/test_needs_auth.py index 3cef1c61..26630e23 100644 --- a/testing/test_needs_auth.py +++ b/testing/test_needs_auth.py @@ -17,7 +17,7 @@ _providers = [ g4f.Provider.Bard ] -_instruct = "Hello, tell about you in one sentence." +_instruct = "Hello, are you GPT 4?." _example = """ OpenaiChat: Hello! How can I assist you today? 2.0 secs @@ -39,14 +39,14 @@ No Stream Total: 10.14 secs print("Bing: ", end="") for response in log_time_yield( g4f.ChatCompletion.create, - model=g4f.models.gpt_35_turbo, + model=g4f.models.default, messages=[{"role": "user", "content": _instruct}], provider=g4f.Provider.Bing, #cookies=g4f.get_cookies(".huggingface.co"), - #stream=True, + stream=True, auth=True ): - print(response, end="") + print(response, end="", flush=True) print() print() @@ -75,7 +75,7 @@ def run_stream(): model=None, messages=[{"role": "user", "content": _instruct}], ): - print(response, end="") + print(response, end="", flush=True) print() print("Stream Total:", log_time(run_stream)) print() |