diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-02-22 00:16:58 +0100 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-02-22 00:16:58 +0100 |
commit | e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6 (patch) | |
tree | 70e9f87cb18f165428492d53a1c6e28c0828490a /g4f/Provider/base_provider.py | |
parent | Update docs / readme, Improve Gemini auth (diff) | |
download | gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar.gz gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar.bz2 gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar.lz gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar.xz gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.tar.zst gpt4free-e5b7f72b719814ffa2748e8e8ed1c6713a24e1a6.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/base_provider.py | 283 |
1 files changed, 2 insertions, 281 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 8659f506..8e761dba 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,281 +1,2 @@ -from __future__ import annotations - -import sys -import asyncio -from asyncio import AbstractEventLoop -from concurrent.futures import ThreadPoolExecutor -from abc import abstractmethod -from inspect import signature, Parameter -from .helper import get_cookies, format_prompt -from ..typing import CreateResult, AsyncResult, Messages, Union -from ..base_provider import BaseProvider -from ..errors import NestAsyncioError, ModelNotSupportedError -from .. import debug - -if sys.version_info < (3, 10): - NoneType = type(None) -else: - from types import NoneType - -# Set Windows event loop policy for better compatibility with asyncio and curl_cffi -if sys.platform == 'win32': - if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -def get_running_loop() -> Union[AbstractEventLoop, None]: - try: - loop = asyncio.get_running_loop() - if not hasattr(loop.__class__, "_nest_patched"): - raise NestAsyncioError( - 'Use "create_async" instead of "create" function in a running event loop. Or use "nest_asyncio" package.' - ) - return loop - except RuntimeError: - pass - -class AbstractProvider(BaseProvider): - """ - Abstract class for providing asynchronous functionality to derived classes. - """ - - @classmethod - async def create_async( - cls, - model: str, - messages: Messages, - *, - loop: AbstractEventLoop = None, - executor: ThreadPoolExecutor = None, - **kwargs - ) -> str: - """ - Asynchronously creates a result based on the given model and messages. - - Args: - cls (type): The class on which this method is called. - model (str): The model to use for creation. - messages (Messages): The messages to process. - loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. - executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None. - **kwargs: Additional keyword arguments. - - Returns: - str: The created result as a string. - """ - loop = loop or asyncio.get_running_loop() - - def create_func() -> str: - return "".join(cls.create_completion(model, messages, False, **kwargs)) - - return await asyncio.wait_for( - loop.run_in_executor(executor, create_func), - timeout=kwargs.get("timeout") - ) - - @classmethod - @property - def params(cls) -> str: - """ - Returns the parameters supported by the provider. - - Args: - cls (type): The class on which this property is called. - - Returns: - str: A string listing the supported parameters. - """ - sig = signature( - cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else - cls.create_async if issubclass(cls, AsyncProvider) else - cls.create_completion - ) - - def get_type_name(annotation: type) -> str: - return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation) - - args = "" - for name, param in sig.parameters.items(): - if name in ("self", "kwargs") or (name == "stream" and not cls.supports_stream): - continue - args += f"\n {name}" - args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else "" - args += f' = "{param.default}"' if param.default == "" else f" = {param.default}" if param.default is not Parameter.empty else "" - - return f"g4f.Provider.{cls.__name__} supports: ({args}\n)" - - -class AsyncProvider(AbstractProvider): - """ - Provides asynchronous functionality for creating completions. - """ - - @classmethod - def create_completion( - cls, - model: str, - messages: Messages, - stream: bool = False, - **kwargs - ) -> CreateResult: - """ - Creates a completion result synchronously. - - Args: - cls (type): The class on which this method is called. - model (str): The model to use for creation. - messages (Messages): The messages to process. - stream (bool): Indicates whether to stream the results. Defaults to False. - loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. - **kwargs: Additional keyword arguments. - - Returns: - CreateResult: The result of the completion creation. - """ - get_running_loop() - yield asyncio.run(cls.create_async(model, messages, **kwargs)) - - @staticmethod - @abstractmethod - async def create_async( - model: str, - messages: Messages, - **kwargs - ) -> str: - """ - Abstract method for creating asynchronous results. - - Args: - model (str): The model to use for creation. - messages (Messages): The messages to process. - **kwargs: Additional keyword arguments. - - Raises: - NotImplementedError: If this method is not overridden in derived classes. - - Returns: - str: The created result as a string. - """ - raise NotImplementedError() - - -class AsyncGeneratorProvider(AsyncProvider): - """ - Provides asynchronous generator functionality for streaming results. - """ - supports_stream = True - - @classmethod - def create_completion( - cls, - model: str, - messages: Messages, - stream: bool = True, - **kwargs - ) -> CreateResult: - """ - Creates a streaming completion result synchronously. - - Args: - cls (type): The class on which this method is called. - model (str): The model to use for creation. - messages (Messages): The messages to process. - stream (bool): Indicates whether to stream the results. Defaults to True. - loop (AbstractEventLoop, optional): The event loop to use. Defaults to None. - **kwargs: Additional keyword arguments. - - Returns: - CreateResult: The result of the streaming completion creation. - """ - loop = get_running_loop() - new_loop = False - if not loop: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - new_loop = True - - generator = cls.create_async_generator(model, messages, stream=stream, **kwargs) - gen = generator.__aiter__() - - # Fix for RuntimeError: async generator ignored GeneratorExit - async def await_callback(callback): - return await callback() - - try: - while True: - yield loop.run_until_complete(await_callback(gen.__anext__)) - except StopAsyncIteration: - ... - # Fix for: ResourceWarning: unclosed event loop - finally: - if new_loop: - loop.close() - asyncio.set_event_loop(None) - - @classmethod - async def create_async( - cls, - model: str, - messages: Messages, - **kwargs - ) -> str: - """ - Asynchronously creates a result from a generator. - - Args: - cls (type): The class on which this method is called. - model (str): The model to use for creation. - messages (Messages): The messages to process. - **kwargs: Additional keyword arguments. - - Returns: - str: The created result as a string. - """ - return "".join([ - chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs) - if not isinstance(chunk, Exception) - ]) - - @staticmethod - @abstractmethod - async def create_async_generator( - model: str, - messages: Messages, - stream: bool = True, - **kwargs - ) -> AsyncResult: - """ - Abstract method for creating an asynchronous generator. - - Args: - model (str): The model to use for creation. - messages (Messages): The messages to process. - stream (bool): Indicates whether to stream the results. Defaults to True. - **kwargs: Additional keyword arguments. - - Raises: - NotImplementedError: If this method is not overridden in derived classes. - - Returns: - AsyncResult: An asynchronous generator yielding results. - """ - raise NotImplementedError() - -class ProviderModelMixin: - default_model: str - models: list[str] = [] - model_aliases: dict[str, str] = {} - - @classmethod - def get_models(cls) -> list[str]: - return cls.models - - @classmethod - def get_model(cls, model: str) -> str: - if not model: - model = cls.default_model - elif model in cls.model_aliases: - model = cls.model_aliases[model] - elif model not in cls.get_models(): - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") - debug.last_model = model - return model
\ No newline at end of file +from ..providers.base_provider import * +from .helper import get_cookies, format_prompt
\ No newline at end of file |