summaryrefslogtreecommitdiffstats
path: root/g4f/gui/server/api.py
blob: bb5d0b5a9d32c6525ea281f9bd813ee7a18ea92f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
from __future__ import annotations

import logging
import os
import os.path
import uuid
import asyncio
import time
import base64
from aiohttp import ClientSession
from typing import Iterator, Optional
from flask import send_from_directory

from g4f import version, models
from g4f import get_last_provider, ChatCompletion
from g4f.errors import VersionNotFoundError
from g4f.typing import Cookies
from g4f.image import ImagePreview, ImageResponse, is_accepted_format
from g4f.requests.aiohttp import get_connector
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin, FinishReason
from g4f.providers.conversation import BaseConversation

# Live provider conversations, keyed first by provider name and then by the
# client-supplied conversation id. Filled by Api._create_response_stream and
# read back by Api._prepare_conversation_kwargs.
conversations: dict[str, dict[str, BaseConversation]] = {}
# Directory (relative to the process working directory) where copies of
# generated images are stored and from which Api.serve_images serves them.
images_dir = "./generated_images"

class Api:
    """
    Model/provider queries, image serving and chat-completion streaming
    helpers used by the GUI server.
    """

    @staticmethod
    def get_models() -> list[str]:
        """
        Return a list of all models.

        Fetches and returns a list of all available models in the system.

        Returns:
            List[str]: A list of model names.
        """
        return models._all_models

    @staticmethod
    def get_provider_models(provider: str) -> Optional[list[dict]]:
        """
        Return the models offered by a single provider.

        Args:
            provider (str): Provider name as registered in ``__map__``.

        Returns:
            list[dict] | None: Entries of the form ``{"model": str, "default": bool}``.
            An empty list when the provider advertises no models; ``None`` when
            the provider name is unknown.
        """
        if provider not in __map__:
            return None
        provider_cls: ProviderType = __map__[provider]
        if issubclass(provider_cls, ProviderModelMixin):
            return [
                {"model": model, "default": model == provider_cls.default_model}
                for model in provider_cls.get_models()
            ]
        if provider_cls.supports_gpt_35_turbo or provider_cls.supports_gpt_4:
            # NOTE(review): the gpt-4 entry is only added when supports_gpt_4 is
            # truthy, so its "default" is always False here — confirm intended.
            return [
                *([{"model": "gpt-4", "default": not provider_cls.supports_gpt_4}] if provider_cls.supports_gpt_4 else []),
                *([{"model": "gpt-3.5-turbo", "default": not provider_cls.supports_gpt_4}] if provider_cls.supports_gpt_35_turbo else []),
            ]
        return []

    @staticmethod
    def get_image_models() -> list[dict]:
        """
        List providers that generate images and/or accept images as input.

        Providers exposing ``image_models`` contribute one entry per image model,
        reported under their ``parent`` provider when they declare one. Providers
        that only expose ``default_vision_model`` contribute a single entry with
        ``image_model`` set to ``None``. Each provider name is reported at most
        once per listing branch.

        Returns:
            list[dict]: Dicts with keys ``provider``, ``url``, ``label``,
            ``image_model`` and ``vision_model``.
        """
        image_models = []
        seen = set()
        for provider in __providers__:
            if hasattr(provider, "image_models"):
                if hasattr(provider, "get_models"):
                    # Called for its side effect; presumably it refreshes
                    # provider.image_models — verify against the providers.
                    provider.get_models()
                parent = __map__[provider.parent] if hasattr(provider, "parent") else provider
                if parent.__name__ not in seen:
                    for model in provider.image_models:
                        image_models.append({
                            "provider": parent.__name__,
                            "url": parent.url,
                            "label": parent.label if hasattr(parent, "label") else None,
                            "image_model": model,
                            "vision_model": parent.default_vision_model if hasattr(parent, "default_vision_model") else None
                        })
                        seen.add(parent.__name__)
            elif hasattr(provider, "default_vision_model") and provider.__name__ not in seen:
                image_models.append({
                    "provider": provider.__name__,
                    "url": provider.url,
                    "label": provider.label if hasattr(provider, "label") else None,
                    "image_model": None,
                    "vision_model": provider.default_vision_model
                })
                seen.add(provider.__name__)
        return image_models

    @staticmethod
    def get_providers() -> dict[str, str]:
        """
        Return all working providers.

        Returns:
            dict[str, str]: Maps provider class name to a display name. The
            display name is the provider's ``label`` (or class name), suffixed
            with " (WebDriver)" when it takes a ``webdriver`` parameter and
            " (Auth)" when it needs authentication.
        """
        return {
            provider.__name__: (provider.label
                if hasattr(provider, "label")
                else provider.__name__) +
                (" (WebDriver)"
                if "webdriver" in provider.get_parameters()
                else "") +
                (" (Auth)"
                if provider.needs_auth
                else "")
            for provider in __providers__
            if provider.working
        }

    @staticmethod
    def get_version():
        """
        Returns the current and latest version of the application.

        Returns:
            dict: A dictionary containing the current and latest version.
            ``version`` is ``None`` when the current version cannot be determined.
        """
        try:
            current_version = version.utils.current_version
        except VersionNotFoundError:
            current_version = None
        return {
            "version": current_version,
            "latest_version": version.utils.latest_version,
        }

    def serve_images(self, name):
        """Send the file ``name`` from the generated-images directory."""
        return send_from_directory(os.path.abspath(images_dir), name)

    def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
        """
        Prepare arguments for a chat completion from the request data.

        Args:
            json_data (dict): Request payload. ``messages`` is required; ``model``,
                ``provider``, ``api_key``, ``web_search`` and ``conversation_id``
                are optional.
            kwargs (dict): Extra arguments to forward. This dict is updated in
                place with ``api_key``, ``web_search`` and ``conversation``
                when applicable.

        Returns:
            dict: Arguments for ``ChatCompletion.create``.
        """
        model = json_data.get('model') or models.default
        provider = json_data.get('provider')
        messages = json_data['messages']
        api_key = json_data.get("api_key")
        if api_key is not None:
            kwargs["api_key"] = api_key
        if json_data.get('web_search'):
            if provider in ("Bing", "HuggingChat"):
                # These providers handle web search natively.
                kwargs['web_search'] = True
            elif messages:
                # Otherwise, augment the last message with search results.
                # Guarded so an empty message list doesn't raise IndexError.
                from .internet import get_search_message
                messages[-1]["content"] = get_search_message(messages[-1]["content"])

        conversation_id = json_data.get("conversation_id")
        if conversation_id and provider in conversations and conversation_id in conversations[provider]:
            kwargs["conversation"] = conversations[provider][conversation_id]

        return {
            "model": model,
            "provider": provider,
            "messages": messages,
            "stream": True,
            "ignore_stream": True,
            "return_conversation": True,
            **kwargs
        }

    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
        """
        Stream a chat completion as a sequence of response dicts.

        Args:
            kwargs (dict): Arguments for ``ChatCompletion.create``.
            conversation_id (str): Id under which a returned conversation is stored.
            provider (str): Key under which the conversation is stored in
                ``conversations``.

        Yields:
            dict: Items shaped ``{"type": t, t: content}`` (see ``_format_json``).
            The first item has type ``provider``.

        Errors raised while streaming are not propagated. They are logged and
        yielded as a final ``error`` item instead.
        """
        try:
            first = True
            for chunk in ChatCompletion.create(**kwargs):
                if first:
                    first = False
                    yield self._format_json("provider", get_last_provider(True))
                if isinstance(chunk, BaseConversation):
                    conversations.setdefault(provider, {})[conversation_id] = chunk
                    yield self._format_json("conversation", conversation_id)
                elif isinstance(chunk, Exception):
                    # logging.exception() only finds a traceback inside an
                    # ``except`` block; pass the yielded exception explicitly.
                    logging.error(chunk, exc_info=chunk)
                    yield self._format_json("message", get_error_message(chunk))
                elif isinstance(chunk, ImagePreview):
                    yield self._format_json("preview", chunk.to_string())
                elif isinstance(chunk, ImageResponse):
                    images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies")))
                    yield self._format_json("content", str(ImageResponse(images, chunk.alt)))
                elif not isinstance(chunk, FinishReason):
                    yield self._format_json("content", str(chunk))
        except Exception as e:
            logging.exception(e)
            yield self._format_json('error', get_error_message(e))

    async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None) -> list[Optional[str]]:
        """
        Store local copies of ``images`` and return their ``/images/...`` URLs.

        Args:
            images (list[str]): Data URIs or remote URLs.
            cookies (Cookies, optional): Cookies sent with remote downloads.

        Returns:
            list: One entry per input, in order. The entry is ``None`` for a data URI
            that does not split into exactly a header and a payload.
        """
        # open(..., "wb") fails with FileNotFoundError if the directory is missing.
        os.makedirs(images_dir, exist_ok=True)
        async with ClientSession(
            connector=get_connector(None, os.environ.get("G4F_PROXY")),
            cookies=cookies
        ) as session:
            return await asyncio.gather(*[self._copy_image(session, image) for image in images])

    async def _copy_image(self, session: ClientSession, image: str) -> Optional[str]:
        """Save one image (data URI or remote URL) into ``images_dir``; return its URL."""
        if image.startswith("data:"):
            parts = image.split(",")
            if len(parts) != 2:
                return None
            content_type, base64_data = parts
            # "data:image/png;base64" -> "png"
            extension = content_type.split("/")[-1].split(";")[0]
            target = os.path.join(images_dir, f"{int(time.time())}_{uuid.uuid4()}.{extension}")
            with open(target, "wb") as f:
                f.write(base64.b64decode(base64_data))
            return f"/images/{os.path.basename(target)}"

        # Remote URL: download without an extension, then name the file from
        # its magic bytes.
        target = os.path.join(images_dir, f"{int(time.time())}_{uuid.uuid4()}")
        try:
            async with session.get(image) as response:
                # Don't store an HTTP error page as if it were an image.
                response.raise_for_status()
                with open(target, "wb") as f:
                    async for data in response.content.iter_any():
                        f.write(data)
            with open(target, "rb") as f:
                extension = is_accepted_format(f.read(12)).split("/")[-1]
        except BaseException:
            # Don't leave a partial or unidentified download behind.
            if os.path.exists(target):
                os.remove(target)
            raise
        extension = "jpg" if extension == "jpeg" else extension
        new_target = f"{target}.{extension}"
        os.rename(target, new_target)
        return f"/images/{os.path.basename(new_target)}"

    def _format_json(self, response_type: str, content):
        """
        Wrap ``content`` in a typed response dict.

        Args:
            response_type (str): The type of the response.
            content: The content to be included in the response.

        Returns:
            dict: ``{"type": response_type, response_type: content}``.
        """
        return {
            'type': response_type,
            response_type: content
        }

def get_error_message(exception: Exception) -> str:
    """
    Build a readable error string for an exception.

    The text is ``"<ExceptionType>: <message>"``. When ``get_last_provider()``
    reports a provider, the provider's class name is prepended as
    ``"<ProviderName>: "``.

    Args:
        exception (Exception): The exception to describe.

    Returns:
        str: The formatted error message.
    """
    text = f"{type(exception).__name__}: {exception}"
    last_provider = get_last_provider()
    return text if last_provider is None else f"{last_provider.__name__}: {text}"