diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-05-18 07:37:37 +0200 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-05-18 07:37:37 +0200 |
commit | b1dafc0ef79bdd94f69c783877217d8a5524d460 (patch) | |
tree | 7851c85bafa452d0fa9cce74aa7b383aa704f101 | |
parent | Merge pull request #1969 from hlohaus/leech (diff) | |
download | gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar.gz gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar.bz2 gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar.lz gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar.xz gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.tar.zst gpt4free-b1dafc0ef79bdd94f69c783877217d8a5524d460.zip |
Diffstat (limited to '')
-rw-r--r-- | etc/examples/api.py | 2 | ||||
-rw-r--r-- | g4f/Provider/Liaobots.py | 79 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/Gemini.py | 36 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 12 | ||||
-rw-r--r-- | g4f/Provider/openai/har_file.py | 30 | ||||
-rw-r--r-- | g4f/Provider/openai/proofofwork.py | 48 | ||||
-rw-r--r-- | g4f/api/__init__.py | 66 | ||||
-rw-r--r-- | g4f/client/async_client.py | 90 | ||||
-rw-r--r-- | g4f/client/service.py | 6 | ||||
-rw-r--r-- | g4f/client/stubs.py | 23 | ||||
-rw-r--r-- | g4f/image.py | 12 |
11 files changed, 288 insertions, 116 deletions
diff --git a/etc/examples/api.py b/etc/examples/api.py index d4d03a77..1ab9b51b 100644 --- a/etc/examples/api.py +++ b/etc/examples/api.py @@ -3,7 +3,7 @@ import json url = "http://localhost:1337/v1/chat/completions" body = { "model": "", - "provider": "MetaAI", + "provider": "", "stream": True, "messages": [ {"role": "assistant", "content": "What can you do? Who are you?"} diff --git a/g4f/Provider/Liaobots.py b/g4f/Provider/Liaobots.py index deb7899c..75ecf300 100644 --- a/g4f/Provider/Liaobots.py +++ b/g4f/Provider/Liaobots.py @@ -10,6 +10,15 @@ from .helper import get_connector from ..requests import raise_for_status models = { + "gpt-4o": { + "context": "8K", + "id": "gpt-4o-free", + "maxLength": 31200, + "model": "ChatGPT", + "name": "GPT-4o-free", + "provider": "OpenAI", + "tokenLimit": 7800, + }, "gpt-3.5-turbo": { "id": "gpt-3.5-turbo", "name": "GPT-3.5-Turbo", @@ -95,7 +104,7 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): model_aliases = { "claude-v2": "claude-2" } - _auth_code = None + _auth_code = "" _cookie_jar = None @classmethod @@ -120,7 +129,13 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): cookie_jar=cls._cookie_jar, connector=get_connector(connector, proxy, True) ) as session: - cls._auth_code = auth if isinstance(auth, str) else cls._auth_code + data = { + "conversationId": str(uuid.uuid4()), + "model": models[cls.get_model(model)], + "messages": messages, + "key": "", + "prompt": kwargs.get("system_message", "You are a helpful assistant."), + } if not cls._auth_code: async with session.post( "https://liaobots.work/recaptcha/api/login", @@ -128,31 +143,49 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin): verify_ssl=False ) as response: await raise_for_status(response) + try: async with session.post( "https://liaobots.work/api/user", - json={"authcode": ""}, + json={"authcode": cls._auth_code}, verify_ssl=False ) as response: await raise_for_status(response) cls._auth_code = (await response.json(content_type=None))["authCode"] + if not cls._auth_code: + raise RuntimeError("Empty auth code") cls._cookie_jar = session.cookie_jar - - data = { - "conversationId": str(uuid.uuid4()), - "model": models[cls.get_model(model)], - "messages": messages, - "key": "", - "prompt": kwargs.get("system_message", "You are a helpful assistant."), - } - async with session.post( - "https://liaobots.work/api/chat", - json=data, - headers={"x-auth-code": cls._auth_code}, - verify_ssl=False - ) as response: - await raise_for_status(response) - async for chunk in response.content.iter_any(): - if b"<html coupert-item=" in chunk: - raise RuntimeError("Invalid session") - if chunk: - yield chunk.decode(errors="ignore") + async with session.post( + "https://liaobots.work/api/chat", + json=data, + headers={"x-auth-code": cls._auth_code}, + verify_ssl=False + ) as response: + await raise_for_status(response) + async for chunk in response.content.iter_any(): + if b"<html coupert-item=" in chunk: + raise RuntimeError("Invalid session") + if chunk: + yield chunk.decode(errors="ignore") + except: + async with session.post( + "https://liaobots.work/api/user", + json={"authcode": "pTIQr4FTnVRfr"}, + verify_ssl=False + ) as response: + await raise_for_status(response) + cls._auth_code = (await response.json(content_type=None))["authCode"] + if not cls._auth_code: + raise RuntimeError("Empty auth code") + cls._cookie_jar = session.cookie_jar + async with session.post( + "https://liaobots.work/api/chat", + json=data, + headers={"x-auth-code": cls._auth_code}, + verify_ssl=False + ) as response: + await raise_for_status(response) + async for chunk in response.content.iter_any(): + if b"<html coupert-item=" in chunk: + raise RuntimeError("Invalid session") + if chunk: + yield chunk.decode(errors="ignore") diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index e468f64a..f9b1c4a5 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -4,6 +4,7 @@ import os import json import random import re +import base64 from aiohttp import ClientSession, BaseConnector @@ -22,7 +23,7 @@ from ..base_provider import AsyncGeneratorProvider from ..helper import format_prompt, get_cookies from ...requests.raise_for_status import raise_for_status from ...errors import MissingAuthError, MissingRequirementsError -from ...image import to_bytes, to_data_uri, ImageResponse +from ...image import to_bytes, ImageResponse, ImageDataResponse from ...webdriver import get_browser, get_driver_cookies REQUEST_HEADERS = { @@ -122,6 +123,7 @@ class Gemini(AsyncGeneratorProvider): connector: BaseConnector = None, image: ImageType = None, image_name: str = None, + response_format: str = None, **kwargs ) -> AsyncResult: prompt = format_prompt(messages) @@ -192,22 +194,22 @@ class Gemini(AsyncGeneratorProvider): if image_prompt: images = [image[0][3][3] for image in response_part[4][0][12][7][0]] resolved_images = [] - preview = [] - for image in images: - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - resolved_images.append(image) - preview.append(image.replace('=s512', '=s200')) - # preview_url = image.replace('=s512', '=s200') - # async with client.get(preview_url) as fetch: - # preview_data = to_data_uri(await fetch.content.read()) - # async with client.get(image) as fetch: - # data = to_data_uri(await fetch.content.read()) - # preview.append(preview_data) - # resolved_images.append(data) - yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) + if response_format == "b64_json": + for image in images: + async with client.get(image) as response: + data = base64.b64encode(await response.content.read()).decode() + resolved_images.append(data) + yield ImageDataResponse(resolved_images, image_prompt) + else: + preview = [] + for image in images: + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + resolved_images.append(image) + preview.append(image.replace('=s512', '=s200')) + yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) def build_request( prompt: str, diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index b4b8bb02..28d0558b 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -61,7 +61,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): supports_system_message = True default_model = None default_vision_model = "gpt-4o" - models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"] + models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o", "auto"] model_aliases = { "text-davinci-002-render-sha": "gpt-3.5-turbo", "": "gpt-3.5-turbo", @@ -394,10 +394,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): print(f"{e.__class__.__name__}: {e}") arkose_token = None + proofTokens = None if cls.default_model is None: error = None try: - arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy) + arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy) cls._create_request_args(cookies, headers) cls._set_api_key(api_key) except NoValidHarFileError as e: @@ -413,17 +414,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if cls._api_key is None else f"{cls.url}/backend-api/sentinel/chat-requirements", json={"conversation_mode_kind": "primary_assistant"}, + #json={"p": generate_proof_token(True, user_agent=cls._headers["user-agent"], proofTokens=proofTokens)}, headers=cls._headers ) as response: cls._update_request_args(session) await raise_for_status(response) data = await response.json() - blob = data["arkose"]["dx"] - need_arkose = data["arkose"]["required"] + need_arkose = data.get("arkose", {}).get("required") chat_token = data["token"] proofofwork = "" if "proofofwork" in data: - proofofwork = generate_proof_token(**data["proofofwork"], user_agent=cls._headers["user-agent"]) + proofofwork = generate_proof_token(**data["proofofwork"], user_agent=cls._headers["user-agent"], proofTokens=proofTokens) if need_arkose and arkose_token is None: arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy) @@ -435,7 +436,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if debug.logging: print( 'Arkose:', False if not need_arkose else arkose_token[:12]+"...", - 'Turnstile:', data["turnstile"]["required"], 'Proofofwork:', False if proofofwork is None else proofofwork[:12]+"...", ) diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py index 220c20bf..eefe305f 100644 --- a/g4f/Provider/openai/har_file.py +++ b/g4f/Provider/openai/har_file.py @@ -12,6 +12,7 @@ from copy import deepcopy from .crypt import decrypt, encrypt from ...requests import StreamSession +from ... import debug class NoValidHarFileError(Exception): ... @@ -31,6 +32,7 @@ chatArk: arkReq = None accessToken: str = None cookies: dict = None headers: dict = None +proofTokens: list = [] def readHAR(): dirPath = "./" @@ -54,6 +56,15 @@ def readHAR(): # Error: not a HAR file! continue for v in harFile['log']['entries']: + v_headers = get_headers(v) + try: + if "openai-sentinel-proof-token" in v_headers: + proofTokens.append(json.loads(base64.b64decode( + v_headers["openai-sentinel-proof-token"].split("gAAAAAB", 1)[-1].encode() + ).decode())) + except Exception as e: + if debug.logging: + print(f"Read proof token: {e}") if arkPreURL in v['request']['url']: chatArks.append(parseHAREntry(v)) elif v['request']['url'] == sessionUrl: @@ -61,13 +72,13 @@ def readHAR(): accessToken = json.loads(v["response"]["content"]["text"]).get("accessToken") except KeyError: continue - cookies = {c['name']: c['value'] for c in v['request']['cookies']} - headers = get_headers(v) + cookies = {c['name']: c['value'] for c in v['request']['cookies'] if c['name'] != "oai-did"} + headers = v_headers if not accessToken: raise NoValidHarFileError("No accessToken found in .har files") if not chatArks: - return None, accessToken, cookies, headers - return chatArks.pop(), accessToken, cookies, headers + return None, accessToken, cookies, headers, proofTokens + return chatArks.pop(), accessToken, cookies, headers, proofTokens def get_headers(entry) -> dict: return {h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')} @@ -101,7 +112,8 @@ def genArkReq(chatArk: arkReq) -> arkReq: async def sendRequest(tmpArk: arkReq, proxy: str = None): async with StreamSession(headers=tmpArk.arkHeader, cookies=tmpArk.arkCookies, proxies={"https": proxy}) as session: async with session.post(tmpArk.arkURL, data=tmpArk.arkBody) as response: - arkose = (await response.json()).get("token") + data = await response.json() + arkose = data.get("token") if "sup=1|rid=" not in arkose: return RuntimeError("No valid arkose token generated") return arkose @@ -131,10 +143,10 @@ def getN() -> str: return base64.b64encode(timestamp.encode()).decode() async def getArkoseAndAccessToken(proxy: str) -> tuple[str, str, dict, dict]: - global chatArk, accessToken, cookies, headers + global chatArk, accessToken, cookies, headers, proofTokens if chatArk is None or accessToken is None: - chatArk, accessToken, cookies, headers = readHAR() + chatArk, accessToken, cookies, headers, proofTokens = readHAR() if chatArk is None: - return None, accessToken, cookies, headers + return None, accessToken, cookies, headers, proofTokens newReq = genArkReq(chatArk) - return await sendRequest(newReq, proxy), accessToken, cookies, headers + return await sendRequest(newReq, proxy), accessToken, cookies, headers, proofTokens diff --git a/g4f/Provider/openai/proofofwork.py b/g4f/Provider/openai/proofofwork.py index 51d96bc4..cbce153f 100644 --- a/g4f/Provider/openai/proofofwork.py +++ b/g4f/Provider/openai/proofofwork.py @@ -2,35 +2,51 @@ import random import hashlib import json import base64 -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone -def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent: str): +proof_token_cache: dict = {} + +def generate_proof_token(required: bool, seed: str = None, difficulty: str = None, user_agent: str = None, proofTokens: list = None): if not required: return - - cores = [8, 12, 16, 24] - screens = [3000, 4000, 6000] - - core = random.choice(cores) - screen = random.choice(screens) + if seed is not None and seed in proof_token_cache: + return proof_token_cache[seed] # Get current UTC time now_utc = datetime.now(timezone.utc) parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT') - config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"plugins−[object PluginArray]","","alert"] - - diff_len = len(difficulty) // 2 - + if proofTokens: + config = random.choice(proofTokens) + else: + screen = random.choice([3008, 4010, 6000]) * random.choice([1, 2, 4]) + config = [ + screen, parse_time, + None, 0, user_agent, + "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js", + "dpl=1440a687921de39ff5ee56b92807faaadce73f13","en","en-US", + None, + "plugins−[object PluginArray]", + random.choice(["_reactListeningcfilawjnerp", "_reactListening9ne2dfo1i47", "_reactListening410nzwhan2a"]), + random.choice(["alert", "ontransitionend", "onprogress"]) + ] + + config[1] = parse_time + config[4] = user_agent + config[7] = random.randint(101, 2100) + + diff_len = None if difficulty is None else len(difficulty) for i in range(100000): config[3] = i json_data = json.dumps(config) base = base64.b64encode(json_data.encode()).decode() - hash_value = hashlib.sha3_512((seed + base).encode()).digest() + hash_value = hashlib.sha3_512((seed or "" + base).encode()).digest() - if hash_value.hex()[:diff_len] <= difficulty: - result = "gAAAAAB" + base - return result + if difficulty is None or hash_value.hex()[:diff_len] <= difficulty: + if seed is None: + return "gAAAAAC" + base + proof_token_cache[seed] = "gAAAAAB" + base + return proof_token_cache[seed] fallback_base = base64.b64encode(f'"{seed}"'.encode()).decode() return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallback_base diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index 7296a542..acb27e9c 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -47,6 +47,14 @@ class ChatCompletionsForm(BaseModel): web_search: Optional[bool] = None proxy: Optional[str] = None +class ImagesGenerateForm(BaseModel): + model: Optional[str] = None + provider: Optional[str] = None + prompt: str + response_format: Optional[str] = None + api_key: Optional[str] = None + proxy: Optional[str] = None + class AppConfig(): list_ignored_providers: Optional[list[str]] = None g4f_api_key: Optional[str] = None @@ -149,37 +157,53 @@ class Api: if auth_header and auth_header != "Bearer": config.api_key = auth_header response = self.client.chat.completions.create( - **{ - **AppConfig.defaults, - **config.dict(exclude_none=True), - }, - + **{ + **AppConfig.defaults, + **config.dict(exclude_none=True), + }, ignored=AppConfig.list_ignored_providers ) + if not config.stream: + return JSONResponse((await response).to_json()) + + async def streaming(): + try: + async for chunk in response: + yield f"data: {json.dumps(chunk.to_json())}\n\n" + except GeneratorExit: + pass + except Exception as e: + logging.exception(e) + yield f'data: {format_exception(e, config)}\n\n' + yield "data: [DONE]\n\n" + return StreamingResponse(streaming(), media_type="text/event-stream") + except Exception as e: logging.exception(e) return Response(content=format_exception(e, config), status_code=500, media_type="application/json") - if not config.stream: - return JSONResponse((await response).to_json()) - - async def streaming(): - try: - async for chunk in response: - yield f"data: {json.dumps(chunk.to_json())}\n\n" - except GeneratorExit: - pass - except Exception as e: - logging.exception(e) - yield f'data: {format_exception(e, config)}\n\n' - yield "data: [DONE]\n\n" - - return StreamingResponse(streaming(), media_type="text/event-stream") - @self.app.post("/v1/completions") async def completions(): return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json") + @self.app.post("/v1/images/generations") + async def images_generate(config: ImagesGenerateForm, request: Request = None, provider: str = None): + try: + config.provider = provider if config.provider is None else config.provider + if config.api_key is None and request is not None: + auth_header = request.headers.get("Authorization") + if auth_header is not None: + auth_header = auth_header.split(None, 1)[-1] + if auth_header and auth_header != "Bearer": + config.api_key = auth_header + response = self.client.images.generate( + **config.dict(exclude_none=True), + ) + return JSONResponse((await response).to_json()) + except Exception as e: + logging.exception(e) + return Response(content=format_exception(e, config), status_code=500, media_type="application/json") + def format_exception(e: Exception, config: ChatCompletionsForm) -> str: last_provider = g4f.get_last_provider(True) return json.dumps({ diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 07ad3357..1508e566 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -3,6 +3,9 @@ from __future__ import annotations import time import random import string +import asyncio +import base64 +from aiohttp import ClientSession, BaseConnector from .types import Client as BaseClient from .types import ProviderType, FinishReason @@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider from .image_models import ImageModels from .helper import filter_json, find_stop, filter_none, cast_iter_async from .service import get_last_provider, get_model_and_provider +from ..Provider import ProviderUtils from ..typing import Union, Messages, AsyncIterator, ImageType -from ..errors import NoImageResponseError -from ..image import ImageResponse as ImageProviderResponse +from ..errors import NoImageResponseError, ProviderNotFoundError +from ..requests.aiohttp import get_connector +from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse try: anext @@ -156,12 +161,28 @@ class Chat(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: +async def iter_image_response( + response: AsyncIterator, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None +) -> Union[ImagesResponse, None]: async for chunk in response: if isinstance(chunk, ImageProviderResponse): - return ImagesResponse([Image(image) for image in chunk.get_list()]) + if response_format == "b64_json": + async with ClientSession( + connector=get_connector(connector, proxy) + ) as session: + async def fetch_image(image): + async with session.get(image) as response: + return base64.b64encode(await response.content.read()).decode() + images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()]) + return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time())) + return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time())) + elif isinstance(chunk, ImageDataResponse): + return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time())) -def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: +def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: prompt = f"create a image with: {prompt}" if provider.__name__ == "You": kwargs["chat_mode"] = "create" @@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model model, [{"role": "user", "content": prompt}], stream=True, - proxy=client.get_proxy(), **kwargs ) @@ -179,31 +199,71 @@ class Images(): self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: - provider = self.models.get(model, self.provider) + def get_provider(self, model: str, provider: ProviderType = None): + if isinstance(provider, str): + if provider in ProviderUtils.convert: + provider = ProviderUtils.convert[provider] + else: + raise ProviderNotFoundError(f'Provider not found: {provider}') + else: + provider = self.models.get(model, self.provider) + return provider + + async def generate( + self, + prompt, + model: str = "", + provider: ProviderType = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ) -> ImagesResponse: + provider = self.get_provider(model, provider) if hasattr(provider, "create_async_generator"): - response = create_image(self.client, provider, prompt, **kwargs) + response = create_image( + provider, + prompt, + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), + **kwargs + ) else: response = await provider.create_async(prompt) return ImagesResponse([Image(image) for image in response.get_list()]) - image = await iter_image_response(response) + image = await iter_image_response(response, response_format, connector, proxy) if image is None: raise NoImageResponseError() return image - async def create_variation(self, image: ImageType, model: str = None, **kwargs): - provider = self.models.get(model, self.provider) + async def create_variation( + self, + image: ImageType, + model: str = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ): + provider = self.get_provider(model, provider) result = None if hasattr(provider, "create_async_generator"): response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], - True, + stream=True, image=image, - proxy=self.client.get_proxy(), + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), **kwargs ) - result = iter_image_response(response) + result = iter_image_response(response, response_format, connector, proxy) if result is None: raise NoImageResponseError() return result diff --git a/g4f/client/service.py b/g4f/client/service.py index dd6bf4b6..5fdb150c 100644 --- a/g4f/client/service.py +++ b/g4f/client/service.py @@ -4,7 +4,7 @@ from typing import Union from .. import debug, version from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorkingError, StreamNotSupportedError -from ..models import Model, ModelUtils +from ..models import Model, ModelUtils, default from ..Provider import ProviderUtils from ..providers.types import BaseRetryProvider, ProviderType from ..providers.retry_provider import IterProvider @@ -60,7 +60,9 @@ def get_model_and_provider(model : Union[Model, str], model = ModelUtils.convert[model] if not provider: - if isinstance(model, str): + if not model: + model = default + elif isinstance(model, str): raise ModelNotFoundError(f'Model not found: {model}') provider = model.best_provider diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py index ceb51b83..8cf2bcba 100644 --- a/g4f/client/stubs.py +++ b/g4f/client/stubs.py @@ -96,13 +96,24 @@ class ChatCompletionDeltaChoice(Model): } class Image(Model): - url: str + def __init__(self, url: str = None, b64_json: str = None, revised_prompt: str = None) -> None: + if url is not None: + self.url = url + if b64_json is not None: + self.b64_json = b64_json + if revised_prompt is not None: + self.revised_prompt = revised_prompt - def __init__(self, url: str) -> None: - self.url = url + def to_json(self): + return self.__dict__ class ImagesResponse(Model): - data: list[Image] - - def __init__(self, data: list) -> None: + def __init__(self, data: list[Image], created: int = 0) -> None: self.data = data + self.created = created + + def to_json(self): + return { + **self.__dict__, + "data": [image.to_json() for image in self.data] + }
\ No newline at end of file diff --git a/g4f/image.py b/g4f/image.py index 270b59ad..3d339266 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -275,6 +275,18 @@ class ImagePreview(ImageResponse): def to_string(self): return super().__str__() +class ImageDataResponse(): + def __init__( + self, + images: Union[str, list], + alt: str, + ): + self.images = images + self.alt = alt + + def get_list(self) -> list[str]: + return [self.images] if isinstance(self.images, str) else self.images + class ImageRequest: def __init__( self, |