From 91feb34054f529c37e10d98d2471c8c0c6780147 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 23 Jan 2024 19:44:48 +0100 Subject: Add ProviderModelMixin for model selection --- g4f/Provider/DeepInfra.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) (limited to 'g4f/Provider/DeepInfra.py') diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py index acde1200..2f34b679 100644 --- a/g4f/Provider/DeepInfra.py +++ b/g4f/Provider/DeepInfra.py @@ -1,18 +1,27 @@ from __future__ import annotations import json -from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider -from ..requests import StreamSession +import requests +from ..typing import AsyncResult, Messages +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..requests import StreamSession -class DeepInfra(AsyncGeneratorProvider): +class DeepInfra(AsyncGeneratorProvider, ProviderModelMixin): url = "https://deepinfra.com" working = True supports_stream = True supports_message_history = True - + default_model = 'meta-llama/Llama-2-70b-chat-hf' + @staticmethod + def get_models(): + url = 'https://api.deepinfra.com/models/featured' + models = requests.get(url).json() + return [model['model_name'] for model in models] + + @classmethod async def create_async_generator( + cls, model: str, messages: Messages, stream: bool, @@ -21,8 +30,6 @@ class DeepInfra(AsyncGeneratorProvider): auth: str = None, **kwargs ) -> AsyncResult: - if not model: - model = 'meta-llama/Llama-2-70b-chat-hf' headers = { 'Accept-Encoding': 'gzip, deflate, br', 'Accept-Language': 'en-US', @@ -49,7 +56,7 @@ class DeepInfra(AsyncGeneratorProvider): impersonate="chrome110" ) as session: json_data = { - 'model' : model, + 'model' : cls.get_model(model), 'messages': messages, 'stream' : True } @@ -70,7 +77,8 @@ class DeepInfra(AsyncGeneratorProvider): if token: if first: token = token.lstrip() + if token: first = False - yield token + yield token except Exception: raise RuntimeError(f"Response: {line}") -- cgit v1.2.3