diff options
Diffstat (limited to '')
-rw-r--r-- | etc/unittest/asyncio.py | 10 | ||||
-rw-r--r-- | etc/unittest/client.py | 54 | ||||
-rw-r--r-- | etc/unittest/mocks.py | 19 | ||||
-rw-r--r-- | g4f/Provider/base_provider.py | 23 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 2 | ||||
-rw-r--r-- | g4f/client.py | 151 | ||||
-rw-r--r-- | g4f/stubs.py | 44 |
7 files changed, 170 insertions, 133 deletions
diff --git a/etc/unittest/asyncio.py b/etc/unittest/asyncio.py index e886c43a..57a1fb7d 100644 --- a/etc/unittest/asyncio.py +++ b/etc/unittest/asyncio.py @@ -8,6 +8,7 @@ import unittest import g4f from g4f import ChatCompletion +from g4f.client import Client from .mocks import ProviderMock, AsyncProviderMock, AsyncGeneratorProviderMock DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] @@ -24,11 +25,16 @@ class TestChatCompletion(unittest.TestCase): def test_create(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncProviderMock) - self.assertEqual("Mock",result) + self.assertEqual("Mock", result) def test_create_generator(self): result = ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, AsyncGeneratorProviderMock) - self.assertEqual("Mock",result) + self.assertEqual("Mock", result) + + def test_await_callback(self): + client = Client(provider=AsyncGeneratorProviderMock) + response = client.chat.completions.create(DEFAULT_MESSAGES, "", max_tokens=0) + self.assertEqual("Mock", response.choices[0].message.content) class TestChatCompletionAsync(unittest.IsolatedAsyncioTestCase): diff --git a/etc/unittest/client.py b/etc/unittest/client.py new file mode 100644 index 00000000..c63edbd2 --- /dev/null +++ b/etc/unittest/client.py @@ -0,0 +1,54 @@ +import unittest + +from g4f.client import Client, ChatCompletion, ChatCompletionChunk +from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock + +DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}] + +class TestPassModel(unittest.TestCase): + + def test_response(self): + client = Client(provider=AsyncGeneratorProviderMock) + response = client.chat.completions.create(DEFAULT_MESSAGES, "") + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("Mock", response.choices[0].message.content) + + def test_pass_model(self): + client = Client(provider=ModelProviderMock) + response = client.chat.completions.create(DEFAULT_MESSAGES, "Hello") + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("Hello", response.choices[0].message.content) + + def test_max_tokens(self): + client = Client(provider=YieldProviderMock) + messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]] + response = client.chat.completions.create(messages, "Hello", max_tokens=1) + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("How ", response.choices[0].message.content) + response = client.chat.completions.create(messages, "Hello", max_tokens=2) + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("How are ", response.choices[0].message.content) + + def test_max_stream(self): + client = Client(provider=YieldProviderMock) + messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]] + response = client.chat.completions.create(messages, "Hello", stream=True) + for chunk in response: + self.assertIsInstance(chunk, ChatCompletionChunk) + self.assertIsInstance(chunk.choices[0].delta.content, str) + messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]] + response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2) + response = list(response) + self.assertEqual(len(response), 2) + for chunk in response: + self.assertEqual(chunk.choices[0].delta.content, "You ") + + def no_test_stop(self): + client = Client(provider=YieldProviderMock) + messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]] + response = client.chat.completions.create(messages, "Hello", stop=["and"]) + self.assertIsInstance(response, ChatCompletion) + self.assertEqual("How are you?", response.choices[0].message.content) + +if __name__ == '__main__': + unittest.main()
\ No newline at end of file diff --git a/etc/unittest/mocks.py b/etc/unittest/mocks.py index 885bdaee..8a67aaf7 100644 --- a/etc/unittest/mocks.py +++ b/etc/unittest/mocks.py @@ -7,10 +7,10 @@ class ProviderMock(AbstractProvider): model, messages, stream, **kwargs ): yield "Mock" - + class AsyncProviderMock(AsyncProvider): working = True - + async def create_async( model, messages, **kwargs ): @@ -18,16 +18,25 @@ class AsyncProviderMock(AsyncProvider): class AsyncGeneratorProviderMock(AsyncGeneratorProvider): working = True - + async def create_async_generator( model, messages, stream, **kwargs ): yield "Mock" - + class ModelProviderMock(AbstractProvider): working = True def create_completion( model, messages, stream, **kwargs ): - yield model
\ No newline at end of file + yield model + +class YieldProviderMock(AsyncGeneratorProvider): + working = True + + async def create_async_generator( + model, messages, stream, **kwargs + ): + for message in messages: + yield message["content"]
\ No newline at end of file diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 4b312ffc..8659f506 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -196,15 +196,20 @@ class AsyncGeneratorProvider(AsyncProvider): generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) gen = generator.__aiter__() - while True: - try: - yield loop.run_until_complete(gen.__anext__()) - except StopAsyncIteration: - break - - if new_loop: - loop.close() - asyncio.set_event_loop(None) + # Fix for RuntimeError: async generator ignored GeneratorExit + async def await_callback(callback): + return await callback() + + try: + while True: + yield loop.run_until_complete(await_callback(gen.__anext__)) + except StopAsyncIteration: + ... + # Fix for: ResourceWarning: unclosed event loop + finally: + if new_loop: + loop.close() + asyncio.set_event_loop(None) @classmethod async def create_async( diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 9e0edd8a..b3577ad5 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -385,7 +385,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): } ) as response: if not response.ok: - raise RuntimeError(f"Response {response.status_code}: {await response.text()}") + raise RuntimeError(f"Response {response.status}: {await response.text()}") last_message: int = 0 async for line in response.iter_lines(): if not line.startswith(b"data: "): diff --git a/g4f/client.py b/g4f/client.py index 03b0eda3..a1494d47 100644 --- a/g4f/client.py +++ b/g4f/client.py @@ -2,9 +2,9 @@ from __future__ import annotations import re -from .typing import Union, Generator, AsyncGenerator, Messages, ImageType +from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse +from .typing import Union, Generator, Messages, ImageType from .base_provider import BaseProvider, ProviderType -from .Provider.base_provider import AsyncGeneratorProvider from .image import ImageResponse as ImageProviderResponse from .Provider import BingCreateImages, Gemini, OpenaiChat from .errors import NoImageResponseError @@ -36,14 +36,14 @@ def iter_response( stop: list = None ) -> Generator: content = "" - idx = 1 - chunk = None - finish_reason = "stop" + finish_reason = None + last_chunk = None for idx, chunk in enumerate(response): + if last_chunk is not None: + yield ChatCompletionChunk(last_chunk, finish_reason) content += str(chunk) - if max_tokens is not None and idx > max_tokens: + if max_tokens is not None and idx + 1 >= max_tokens: finish_reason = "max_tokens" - break first = -1 word = None if stop is not None: @@ -52,98 +52,30 @@ def iter_response( if first != -1: content = content[:first] break - if stream: + if stream and first != -1: + first = chunk.find(word) if first != -1: - first = chunk.find(word) - if first != -1: - chunk = chunk[:first] - else: - first = 0 - yield ChatCompletionChunk([ChatCompletionDeltaChoice(ChatCompletionDelta(chunk))]) + chunk = chunk[:first] + else: + first = 0 if first != -1: + finish_reason = "stop" + if stream: + last_chunk = chunk + if finish_reason is not None: break + if last_chunk is not None: + yield ChatCompletionChunk(last_chunk, finish_reason) if not stream: if response_format is not None and "type" in response_format: if response_format["type"] == "json_object": response = read_json(response) - yield ChatCompletion([ChatCompletionChoice(ChatCompletionMessage(response, finish_reason))]) - -async def aiter_response( - response: aiter, - stream: bool, - response_format: dict = None, - max_tokens: int = None, - stop: list = None -) -> AsyncGenerator: - content = "" - try: - idx = 0 - chunk = None - async for chunk in response: - content += str(chunk) - if max_tokens is not None and idx > max_tokens: - break - first = -1 - word = None - if stop is not None: - for word in list(stop): - first = content.find(word) - if first != -1: - content = content[:first] - break - if stream: - if first != -1: - first = chunk.find(word) - if first != -1: - chunk = chunk[:first] - else: - first = 0 - yield ChatCompletionChunk([ChatCompletionDeltaChoice(ChatCompletionDelta(chunk))]) - if first != -1: - break - idx += 1 - except: - ... - if not stream: - if response_format is not None and "type" in response_format: - if response_format["type"] == "json_object": - response = read_json(response) - yield ChatCompletion([ChatCompletionChoice(ChatCompletionMessage(response))]) - -class Model(): - def __getitem__(self, item): - return getattr(self, item) - -class ChatCompletion(Model): - def __init__(self, choices: list): - self.choices = choices - -class ChatCompletionChunk(Model): - def __init__(self, choices: list): - self.choices = choices - -class ChatCompletionChoice(Model): - def __init__(self, message: ChatCompletionMessage): - self.message = message - -class ChatCompletionMessage(Model): - def __init__(self, content: str, finish_reason: str): - self.content = content - self.finish_reason = finish_reason - self.index = 0 - self.logprobs = None - -class ChatCompletionDelta(Model): - def __init__(self, content: str): - self.content = content - -class ChatCompletionDeltaChoice(Model): - def __init__(self, delta: ChatCompletionDelta): - self.delta = delta + yield ChatCompletion(content, finish_reason) class Client(): proxies: Proxies = None chat: Chat + images: Images def __init__( self, @@ -152,9 +84,9 @@ class Client(): proxies: Proxies = None, **kwargs ) -> None: - self.proxies: Proxies = proxies - self.images = Images(self, image_provider) self.chat = Chat(self, provider) + self.images = Images(self, image_provider) + self.proxies: Proxies = proxies def get_proxy(self) -> Union[str, None]: if isinstance(self.proxies, str) or self.proxies is None: @@ -178,13 +110,13 @@ class Completions(): stream: bool = False, response_format: dict = None, max_tokens: int = None, - stop: list = None, + stop: Union[list. str] = None, **kwargs - ) -> Union[dict, Generator]: + ) -> Union[ChatCompletion, Generator[ChatCompletionChunk]]: if max_tokens is not None: kwargs["max_tokens"] = max_tokens if stop: - kwargs["stop"] = list(stop) + kwargs["stop"] = stop model, provider = get_model_and_provider( model, self.provider if provider is None else provider, @@ -192,10 +124,8 @@ class Completions(): **kwargs ) response = provider.create_completion(model, messages, stream=stream, **kwargs) - if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): - response = iter_response(response, stream, response_format) # max_tokens, stop - else: - response = iter_response(response, stream, response_format, max_tokens, stop) + stop = [stop] if isinstance(stop, str) else stop + response = iter_response(response, stream, response_format, max_tokens, stop) return response if stream else next(response) class Chat(): @@ -203,7 +133,7 @@ class Chat(): def __init__(self, client: Client, provider: ProviderType = None): self.completions = Completions(client, provider) - + class ImageModels(): gemini = Gemini openai = OpenaiChat @@ -212,21 +142,9 @@ class ImageModels(): self.client = client self.default = BingCreateImages(proxy=self.client.get_proxy()) - def get(self, name: str) -> ImageProvider: - return getattr(self, name) if hasattr(self, name) else self.default + def get(self, name: str, default: ImageProvider = None) -> ImageProvider: + return getattr(self, name) if hasattr(self, name) else default or self.default -class ImagesResponse(Model): - data: list[Image] - - def __init__(self, data: list) -> None: - self.data = data - -class Image(Model): - url: str - - def __init__(self, url: str) -> None: - self.url = url - class Images(): def __init__(self, client: Client, provider: ImageProvider = None): self.client: Client = client @@ -234,7 +152,7 @@ class Images(): self.models: ImageModels = ImageModels(client) def generate(self, prompt, model: str = None, **kwargs): - provider = self.models.get(model) if model else self.provider or self.models.get(model) + provider = self.models.get(model, self.provider) if isinstance(provider, BaseProvider) or isinstance(provider, type) and issubclass(provider, BaseProvider): prompt = f"create a image: {prompt}" response = provider.create_completion( @@ -246,14 +164,15 @@ class Images(): ) else: response = provider.create(prompt) - + for chunk in response: if isinstance(chunk, ImageProviderResponse): - return ImagesResponse([Image(image)for image in list(chunk.images)]) + images = [chunk.images] if isinstance(chunk.images, str) else chunk.images + return ImagesResponse([Image(image) for image in images]) raise NoImageResponseError() def create_variation(self, image: ImageType, model: str = None, **kwargs): - provider = self.models.get(model) if model else self.provider + provider = self.models.get(model, self.provider) result = None if isinstance(provider, type) and issubclass(provider, BaseProvider): response = provider.create_completion( diff --git a/g4f/stubs.py b/g4f/stubs.py new file mode 100644 index 00000000..1cbbb134 --- /dev/null +++ b/g4f/stubs.py @@ -0,0 +1,44 @@ + +from __future__ import annotations + +class Model(): + def __getitem__(self, item): + return getattr(self, item) + +class ChatCompletion(Model): + def __init__(self, content: str, finish_reason: str): + self.choices = [ChatCompletionChoice(ChatCompletionMessage(content, finish_reason))] + +class ChatCompletionChunk(Model): + def __init__(self, content: str, finish_reason: str): + self.choices = [ChatCompletionDeltaChoice(ChatCompletionDelta(content, finish_reason))] + +class ChatCompletionMessage(Model): + def __init__(self, content: str, finish_reason: str): + self.content = content + self.finish_reason = finish_reason + +class ChatCompletionChoice(Model): + def __init__(self, message: ChatCompletionMessage): + self.message = message + +class ChatCompletionDelta(Model): + def __init__(self, content: str, finish_reason: str): + self.content = content + self.finish_reason = finish_reason + +class ChatCompletionDeltaChoice(Model): + def __init__(self, delta: ChatCompletionDelta): + self.delta = delta + +class Image(Model): + url: str + + def __init__(self, url: str) -> None: + self.url = url + +class ImagesResponse(Model): + data: list[Image] + + def __init__(self, data: list) -> None: + self.data = data
\ No newline at end of file |