diff options
author | Tekky <98614666+xtekky@users.noreply.github.com> | 2024-10-03 13:20:46 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-03 13:20:46 +0200 |
commit | 6d19ba695655c85916deb2b6ba67c831fcf4c885 (patch) | |
tree | 467f56307f7a538b48eb9ffd6e318a9ff33e0159 /g4f/client | |
parent | ~ (diff) | |
parent | feat(g4f/models.py): enhance llama_3_1_405b with Blackbox provider (diff) | |
download | gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar.gz gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar.bz2 gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar.lz gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar.xz gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.tar.zst gpt4free-6d19ba695655c85916deb2b6ba67c831fcf4c885.zip |
Diffstat (limited to 'g4f/client')
-rw-r--r-- | g4f/client/async_client.py | 151 |
1 files changed, 95 insertions, 56 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 9caa74b2..b4d52a60 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -33,6 +33,12 @@ except NameError: except StopAsyncIteration: raise StopIteration +async def safe_aclose(generator): + try: + await generator.aclose() + except Exception as e: + logging.warning(f"Error while closing generator: {e}") + async def iter_response( response: AsyncIterator[str], stream: bool, @@ -45,48 +51,56 @@ async def iter_response( completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) idx = 0 - async for chunk in response: - if isinstance(chunk, FinishReason): - finish_reason = chunk.reason - break - elif isinstance(chunk, BaseConversation): - yield chunk - continue + try: + async for chunk in response: + if isinstance(chunk, FinishReason): + finish_reason = chunk.reason + break + elif isinstance(chunk, BaseConversation): + yield chunk + continue - content += str(chunk) - idx += 1 + content += str(chunk) + idx += 1 - if max_tokens is not None and idx >= max_tokens: - finish_reason = "length" + if max_tokens is not None and idx >= max_tokens: + finish_reason = "length" - first, content, chunk = find_stop(stop, content, chunk if stream else None) + first, content, chunk = find_stop(stop, content, chunk if stream else None) - if first != -1: - finish_reason = "stop" + if first != -1: + finish_reason = "stop" - if stream: - yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) + if stream: + yield ChatCompletionChunk(chunk, None, completion_id, int(time.time())) - if finish_reason is not None: - break + if finish_reason is not None: + break - finish_reason = "stop" if finish_reason is None else finish_reason + finish_reason = "stop" if finish_reason is None else finish_reason - if stream: - yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) - else: - if response_format is not None and "type" in response_format: - if response_format["type"] == "json_object": - content = filter_json(content) - yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) + if stream: + yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time())) + else: + if response_format is not None and "type" in response_format: + if response_format["type"] == "json_object": + content = filter_json(content) + yield ChatCompletion(content, finish_reason, completion_id, int(time.time())) + finally: + if hasattr(response, 'aclose'): + await safe_aclose(response) async def iter_append_model_and_provider(response: AsyncIterator) -> AsyncIterator: last_provider = None - async for chunk in response: - last_provider = get_last_provider(True) if last_provider is None else last_provider - chunk.model = last_provider.get("model") - chunk.provider = last_provider.get("name") - yield chunk + try: + async for chunk in response: + last_provider = get_last_provider(True) if last_provider is None else last_provider + chunk.model = last_provider.get("model") + chunk.provider = last_provider.get("name") + yield chunk + finally: + if hasattr(response, 'aclose'): + await safe_aclose(response) class AsyncClient(BaseClient): def __init__( @@ -158,8 +172,6 @@ class Completions: response = iter_append_model_and_provider(response) return response if stream else await anext(response) - - class Chat: completions: Completions @@ -168,14 +180,18 @@ class Chat: async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: logging.info("Starting iter_image_response") - async for chunk in response: - logging.info(f"Processing chunk: {chunk}") - if isinstance(chunk, ImageProviderResponse): - logging.info("Found ImageProviderResponse") - return ImagesResponse([Image(image) for image in chunk.get_list()]) - - logging.warning("No ImageProviderResponse found in the response") - return None + try: + async for chunk in response: + logging.info(f"Processing chunk: {chunk}") + if isinstance(chunk, ImageProviderResponse): + logging.info("Found ImageProviderResponse") + return ImagesResponse([Image(image) for image in chunk.get_list()]) + + logging.warning("No ImageProviderResponse found in the response") + return None + finally: + if hasattr(response, 'aclose'): + await safe_aclose(response) async def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: logging.info(f"Creating image with provider: {provider}, model: {model}, prompt: {prompt}") @@ -220,12 +236,25 @@ class Images: if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): logging.info("Using AsyncGeneratorProvider") messages = [{"role": "user", "content": prompt}] - async for response in provider.create_async_generator(model, messages, **kwargs): - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], prompt) - return self._process_image_response(image_response) + generator = None + try: + generator = provider.create_async_generator(model, messages, **kwargs) + async for response in generator: + logging.debug(f"Received response: {type(response)}") + if isinstance(response, ImageResponse): + return self._process_image_response(response) + elif isinstance(response, str): + image_response = ImageResponse([response], prompt) + return self._process_image_response(image_response) + except RuntimeError as e: + if "async generator ignored GeneratorExit" in str(e): + logging.warning("Generator ignored GeneratorExit, handling gracefully") + else: + raise + finally: + if generator and hasattr(generator, 'aclose'): + await safe_aclose(generator) + logging.info("AsyncGeneratorProvider processing completed") elif hasattr(provider, 'create'): logging.info("Using provider's create method") async_create = asyncio.iscoroutinefunction(provider.create) @@ -241,7 +270,7 @@ class Images: return self._process_image_response(image_response) elif hasattr(provider, 'create_completion'): logging.info("Using provider's create_completion method") - response = await create_image(provider, prompt, model, **kwargs) + response = await create_image(self.client, provider, prompt, model, **kwargs) async for chunk in response: if isinstance(chunk, ImageProviderResponse): logging.info("Found ImageProviderResponse") @@ -277,12 +306,24 @@ class Images: if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider): messages = [{"role": "user", "content": "create a variation of this image"}] image_data = to_data_uri(image) - async for response in provider.create_async_generator(model, messages, image=image_data, **kwargs): - if isinstance(response, ImageResponse): - return self._process_image_response(response) - elif isinstance(response, str): - image_response = ImageResponse([response], "Image variation") - return self._process_image_response(image_response) + generator = None + try: + generator = provider.create_async_generator(model, messages, image=image_data, **kwargs) + async for response in generator: + if isinstance(response, ImageResponse): + return self._process_image_response(response) + elif isinstance(response, str): + image_response = ImageResponse([response], "Image variation") + return self._process_image_response(image_response) + except RuntimeError as e: + if "async generator ignored GeneratorExit" in str(e): + logging.warning("Generator ignored GeneratorExit in create_variation, handling gracefully") + else: + raise + finally: + if generator and hasattr(generator, 'aclose'): + await safe_aclose(generator) + logging.info("AsyncGeneratorProvider processing completed in create_variation") elif hasattr(provider, 'create_variation'): if asyncio.iscoroutinefunction(provider.create_variation): response = await provider.create_variation(image, **kwargs) @@ -296,5 +337,3 @@ class Images: return self._process_image_response(image_response) else: raise ValueError(f"Provider {provider} does not support image variation") - - raise NoImageResponseError("Failed to create image variation") |