summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-05-18 08:12:17 +0200
committerGitHub <noreply@github.com>2024-05-18 08:12:17 +0200
commit4c3472f5417281ecf902edaec390bb1f7bafd808 (patch)
tree7851c85bafa452d0fa9cce74aa7b383aa704f101
parentMerge pull request #1969 from hlohaus/leech (diff)
parentImprove Liabots provider, Add image api support (diff)
downloadgpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar.gz
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar.bz2
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar.lz
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar.xz
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.tar.zst
gpt4free-4c3472f5417281ecf902edaec390bb1f7bafd808.zip
Diffstat (limited to '')
-rw-r--r--etc/examples/api.py2
-rw-r--r--g4f/Provider/Liaobots.py79
-rw-r--r--g4f/Provider/needs_auth/Gemini.py36
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py12
-rw-r--r--g4f/Provider/openai/har_file.py30
-rw-r--r--g4f/Provider/openai/proofofwork.py48
-rw-r--r--g4f/api/__init__.py66
-rw-r--r--g4f/client/async_client.py90
-rw-r--r--g4f/client/service.py6
-rw-r--r--g4f/client/stubs.py23
-rw-r--r--g4f/image.py12
11 files changed, 288 insertions, 116 deletions
diff --git a/etc/examples/api.py b/etc/examples/api.py
index d4d03a77..1ab9b51b 100644
--- a/etc/examples/api.py
+++ b/etc/examples/api.py
@@ -3,7 +3,7 @@ import json
url = "http://localhost:1337/v1/chat/completions"
body = {
"model": "",
- "provider": "MetaAI",
+ "provider": "",
"stream": True,
"messages": [
{"role": "assistant", "content": "What can you do? Who are you?"}
diff --git a/g4f/Provider/Liaobots.py b/g4f/Provider/Liaobots.py
index deb7899c..75ecf300 100644
--- a/g4f/Provider/Liaobots.py
+++ b/g4f/Provider/Liaobots.py
@@ -10,6 +10,15 @@ from .helper import get_connector
from ..requests import raise_for_status
models = {
+ "gpt-4o": {
+ "context": "8K",
+ "id": "gpt-4o-free",
+ "maxLength": 31200,
+ "model": "ChatGPT",
+ "name": "GPT-4o-free",
+ "provider": "OpenAI",
+ "tokenLimit": 7800,
+ },
"gpt-3.5-turbo": {
"id": "gpt-3.5-turbo",
"name": "GPT-3.5-Turbo",
@@ -95,7 +104,7 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
model_aliases = {
"claude-v2": "claude-2"
}
- _auth_code = None
+ _auth_code = ""
_cookie_jar = None
@classmethod
@@ -120,7 +129,13 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
cookie_jar=cls._cookie_jar,
connector=get_connector(connector, proxy, True)
) as session:
- cls._auth_code = auth if isinstance(auth, str) else cls._auth_code
+ data = {
+ "conversationId": str(uuid.uuid4()),
+ "model": models[cls.get_model(model)],
+ "messages": messages,
+ "key": "",
+ "prompt": kwargs.get("system_message", "You are a helpful assistant."),
+ }
if not cls._auth_code:
async with session.post(
"https://liaobots.work/recaptcha/api/login",
@@ -128,31 +143,49 @@ class Liaobots(AsyncGeneratorProvider, ProviderModelMixin):
verify_ssl=False
) as response:
await raise_for_status(response)
+ try:
async with session.post(
"https://liaobots.work/api/user",
- json={"authcode": ""},
+ json={"authcode": cls._auth_code},
verify_ssl=False
) as response:
await raise_for_status(response)
cls._auth_code = (await response.json(content_type=None))["authCode"]
+ if not cls._auth_code:
+ raise RuntimeError("Empty auth code")
cls._cookie_jar = session.cookie_jar
-
- data = {
- "conversationId": str(uuid.uuid4()),
- "model": models[cls.get_model(model)],
- "messages": messages,
- "key": "",
- "prompt": kwargs.get("system_message", "You are a helpful assistant."),
- }
- async with session.post(
- "https://liaobots.work/api/chat",
- json=data,
- headers={"x-auth-code": cls._auth_code},
- verify_ssl=False
- ) as response:
- await raise_for_status(response)
- async for chunk in response.content.iter_any():
- if b"<html coupert-item=" in chunk:
- raise RuntimeError("Invalid session")
- if chunk:
- yield chunk.decode(errors="ignore")
+ async with session.post(
+ "https://liaobots.work/api/chat",
+ json=data,
+ headers={"x-auth-code": cls._auth_code},
+ verify_ssl=False
+ ) as response:
+ await raise_for_status(response)
+ async for chunk in response.content.iter_any():
+ if b"<html coupert-item=" in chunk:
+ raise RuntimeError("Invalid session")
+ if chunk:
+ yield chunk.decode(errors="ignore")
+ except:
+ async with session.post(
+ "https://liaobots.work/api/user",
+ json={"authcode": "pTIQr4FTnVRfr"},
+ verify_ssl=False
+ ) as response:
+ await raise_for_status(response)
+ cls._auth_code = (await response.json(content_type=None))["authCode"]
+ if not cls._auth_code:
+ raise RuntimeError("Empty auth code")
+ cls._cookie_jar = session.cookie_jar
+ async with session.post(
+ "https://liaobots.work/api/chat",
+ json=data,
+ headers={"x-auth-code": cls._auth_code},
+ verify_ssl=False
+ ) as response:
+ await raise_for_status(response)
+ async for chunk in response.content.iter_any():
+ if b"<html coupert-item=" in chunk:
+ raise RuntimeError("Invalid session")
+ if chunk:
+ yield chunk.decode(errors="ignore")
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index e468f64a..f9b1c4a5 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -4,6 +4,7 @@ import os
import json
import random
import re
+import base64
from aiohttp import ClientSession, BaseConnector
@@ -22,7 +23,7 @@ from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
-from ...image import to_bytes, to_data_uri, ImageResponse
+from ...image import to_bytes, ImageResponse, ImageDataResponse
from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
@@ -122,6 +123,7 @@ class Gemini(AsyncGeneratorProvider):
connector: BaseConnector = None,
image: ImageType = None,
image_name: str = None,
+ response_format: str = None,
**kwargs
) -> AsyncResult:
prompt = format_prompt(messages)
@@ -192,22 +194,22 @@ class Gemini(AsyncGeneratorProvider):
if image_prompt:
images = [image[0][3][3] for image in response_part[4][0][12][7][0]]
resolved_images = []
- preview = []
- for image in images:
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- async with client.get(image, allow_redirects=False) as fetch:
- image = fetch.headers["location"]
- resolved_images.append(image)
- preview.append(image.replace('=s512', '=s200'))
- # preview_url = image.replace('=s512', '=s200')
- # async with client.get(preview_url) as fetch:
- # preview_data = to_data_uri(await fetch.content.read())
- # async with client.get(image) as fetch:
- # data = to_data_uri(await fetch.content.read())
- # preview.append(preview_data)
- # resolved_images.append(data)
- yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
+ if response_format == "b64_json":
+ for image in images:
+ async with client.get(image) as response:
+ data = base64.b64encode(await response.content.read()).decode()
+ resolved_images.append(data)
+ yield ImageDataResponse(resolved_images, image_prompt)
+ else:
+ preview = []
+ for image in images:
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ async with client.get(image, allow_redirects=False) as fetch:
+ image = fetch.headers["location"]
+ resolved_images.append(image)
+ preview.append(image.replace('=s512', '=s200'))
+ yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
def build_request(
prompt: str,
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index b4b8bb02..28d0558b 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -61,7 +61,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
default_model = None
default_vision_model = "gpt-4o"
- models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"]
+ models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o", "auto"]
model_aliases = {
"text-davinci-002-render-sha": "gpt-3.5-turbo",
"": "gpt-3.5-turbo",
@@ -394,10 +394,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print(f"{e.__class__.__name__}: {e}")
arkose_token = None
+ proofTokens = None
if cls.default_model is None:
error = None
try:
- arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
+ arkose_token, api_key, cookies, headers, proofTokens = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies, headers)
cls._set_api_key(api_key)
except NoValidHarFileError as e:
@@ -413,17 +414,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if cls._api_key is None else
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"conversation_mode_kind": "primary_assistant"},
+ #json={"p": generate_proof_token(True, user_agent=cls._headers["user-agent"], proofTokens=proofTokens)},
headers=cls._headers
) as response:
cls._update_request_args(session)
await raise_for_status(response)
data = await response.json()
- blob = data["arkose"]["dx"]
- need_arkose = data["arkose"]["required"]
+ need_arkose = data.get("arkose", {}).get("required")
chat_token = data["token"]
proofofwork = ""
if "proofofwork" in data:
- proofofwork = generate_proof_token(**data["proofofwork"], user_agent=cls._headers["user-agent"])
+ proofofwork = generate_proof_token(**data["proofofwork"], user_agent=cls._headers["user-agent"], proofTokens=proofTokens)
if need_arkose and arkose_token is None:
arkose_token, api_key, cookies, headers = await getArkoseAndAccessToken(proxy)
@@ -435,7 +436,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if debug.logging:
print(
'Arkose:', False if not need_arkose else arkose_token[:12]+"...",
- 'Turnstile:', data["turnstile"]["required"],
'Proofofwork:', False if proofofwork is None else proofofwork[:12]+"...",
)
diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py
index 220c20bf..eefe305f 100644
--- a/g4f/Provider/openai/har_file.py
+++ b/g4f/Provider/openai/har_file.py
@@ -12,6 +12,7 @@ from copy import deepcopy
from .crypt import decrypt, encrypt
from ...requests import StreamSession
+from ... import debug
class NoValidHarFileError(Exception):
...
@@ -31,6 +32,7 @@ chatArk: arkReq = None
accessToken: str = None
cookies: dict = None
headers: dict = None
+proofTokens: list = []
def readHAR():
dirPath = "./"
@@ -54,6 +56,15 @@ def readHAR():
# Error: not a HAR file!
continue
for v in harFile['log']['entries']:
+ v_headers = get_headers(v)
+ try:
+ if "openai-sentinel-proof-token" in v_headers:
+ proofTokens.append(json.loads(base64.b64decode(
+ v_headers["openai-sentinel-proof-token"].split("gAAAAAB", 1)[-1].encode()
+ ).decode()))
+ except Exception as e:
+ if debug.logging:
+ print(f"Read proof token: {e}")
if arkPreURL in v['request']['url']:
chatArks.append(parseHAREntry(v))
elif v['request']['url'] == sessionUrl:
@@ -61,13 +72,13 @@ def readHAR():
accessToken = json.loads(v["response"]["content"]["text"]).get("accessToken")
except KeyError:
continue
- cookies = {c['name']: c['value'] for c in v['request']['cookies']}
- headers = get_headers(v)
+ cookies = {c['name']: c['value'] for c in v['request']['cookies'] if c['name'] != "oai-did"}
+ headers = v_headers
if not accessToken:
raise NoValidHarFileError("No accessToken found in .har files")
if not chatArks:
- return None, accessToken, cookies, headers
- return chatArks.pop(), accessToken, cookies, headers
+ return None, accessToken, cookies, headers, proofTokens
+ return chatArks.pop(), accessToken, cookies, headers, proofTokens
def get_headers(entry) -> dict:
return {h['name'].lower(): h['value'] for h in entry['request']['headers'] if h['name'].lower() not in ['content-length', 'cookie'] and not h['name'].startswith(':')}
@@ -101,7 +112,8 @@ def genArkReq(chatArk: arkReq) -> arkReq:
async def sendRequest(tmpArk: arkReq, proxy: str = None):
async with StreamSession(headers=tmpArk.arkHeader, cookies=tmpArk.arkCookies, proxies={"https": proxy}) as session:
async with session.post(tmpArk.arkURL, data=tmpArk.arkBody) as response:
- arkose = (await response.json()).get("token")
+ data = await response.json()
+ arkose = data.get("token")
if "sup=1|rid=" not in arkose:
return RuntimeError("No valid arkose token generated")
return arkose
@@ -131,10 +143,10 @@ def getN() -> str:
return base64.b64encode(timestamp.encode()).decode()
async def getArkoseAndAccessToken(proxy: str) -> tuple[str, str, dict, dict]:
- global chatArk, accessToken, cookies, headers
+ global chatArk, accessToken, cookies, headers, proofTokens
if chatArk is None or accessToken is None:
- chatArk, accessToken, cookies, headers = readHAR()
+ chatArk, accessToken, cookies, headers, proofTokens = readHAR()
if chatArk is None:
- return None, accessToken, cookies, headers
+ return None, accessToken, cookies, headers, proofTokens
newReq = genArkReq(chatArk)
- return await sendRequest(newReq, proxy), accessToken, cookies, headers
+ return await sendRequest(newReq, proxy), accessToken, cookies, headers, proofTokens
diff --git a/g4f/Provider/openai/proofofwork.py b/g4f/Provider/openai/proofofwork.py
index 51d96bc4..cbce153f 100644
--- a/g4f/Provider/openai/proofofwork.py
+++ b/g4f/Provider/openai/proofofwork.py
@@ -2,35 +2,51 @@ import random
import hashlib
import json
import base64
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timezone
-def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent: str):
+proof_token_cache: dict = {}
+
+def generate_proof_token(required: bool, seed: str = None, difficulty: str = None, user_agent: str = None, proofTokens: list = None):
if not required:
return
-
- cores = [8, 12, 16, 24]
- screens = [3000, 4000, 6000]
-
- core = random.choice(cores)
- screen = random.choice(screens)
+ if seed is not None and seed in proof_token_cache:
+ return proof_token_cache[seed]
# Get current UTC time
now_utc = datetime.now(timezone.utc)
parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT')
- config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"plugins−[object PluginArray]","","alert"]
-
- diff_len = len(difficulty) // 2
-
+ if proofTokens:
+ config = random.choice(proofTokens)
+ else:
+ screen = random.choice([3008, 4010, 6000]) * random.choice([1, 2, 4])
+ config = [
+ screen, parse_time,
+ None, 0, user_agent,
+ "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js",
+ "dpl=1440a687921de39ff5ee56b92807faaadce73f13","en","en-US",
+ None,
+ "plugins−[object PluginArray]",
+ random.choice(["_reactListeningcfilawjnerp", "_reactListening9ne2dfo1i47", "_reactListening410nzwhan2a"]),
+ random.choice(["alert", "ontransitionend", "onprogress"])
+ ]
+
+ config[1] = parse_time
+ config[4] = user_agent
+ config[7] = random.randint(101, 2100)
+
+ diff_len = None if difficulty is None else len(difficulty)
for i in range(100000):
config[3] = i
json_data = json.dumps(config)
base = base64.b64encode(json_data.encode()).decode()
- hash_value = hashlib.sha3_512((seed + base).encode()).digest()
+ hash_value = hashlib.sha3_512((seed or "" + base).encode()).digest()
- if hash_value.hex()[:diff_len] <= difficulty:
- result = "gAAAAAB" + base
- return result
+ if difficulty is None or hash_value.hex()[:diff_len] <= difficulty:
+ if seed is None:
+ return "gAAAAAC" + base
+ proof_token_cache[seed] = "gAAAAAB" + base
+ return proof_token_cache[seed]
fallback_base = base64.b64encode(f'"{seed}"'.encode()).decode()
return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallback_base
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 7296a542..acb27e9c 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -47,6 +47,14 @@ class ChatCompletionsForm(BaseModel):
web_search: Optional[bool] = None
proxy: Optional[str] = None
+class ImagesGenerateForm(BaseModel):
+ model: Optional[str] = None
+ provider: Optional[str] = None
+ prompt: str
+ response_format: Optional[str] = None
+ api_key: Optional[str] = None
+ proxy: Optional[str] = None
+
class AppConfig():
list_ignored_providers: Optional[list[str]] = None
g4f_api_key: Optional[str] = None
@@ -149,37 +157,53 @@ class Api:
if auth_header and auth_header != "Bearer":
config.api_key = auth_header
response = self.client.chat.completions.create(
- **{
- **AppConfig.defaults,
- **config.dict(exclude_none=True),
- },
-
+ **{
+ **AppConfig.defaults,
+ **config.dict(exclude_none=True),
+ },
ignored=AppConfig.list_ignored_providers
)
+ if not config.stream:
+ return JSONResponse((await response).to_json())
+
+ async def streaming():
+ try:
+ async for chunk in response:
+ yield f"data: {json.dumps(chunk.to_json())}\n\n"
+ except GeneratorExit:
+ pass
+ except Exception as e:
+ logging.exception(e)
+ yield f'data: {format_exception(e, config)}\n\n'
+ yield "data: [DONE]\n\n"
+ return StreamingResponse(streaming(), media_type="text/event-stream")
+
except Exception as e:
logging.exception(e)
return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
- if not config.stream:
- return JSONResponse((await response).to_json())
-
- async def streaming():
- try:
- async for chunk in response:
- yield f"data: {json.dumps(chunk.to_json())}\n\n"
- except GeneratorExit:
- pass
- except Exception as e:
- logging.exception(e)
- yield f'data: {format_exception(e, config)}\n\n'
- yield "data: [DONE]\n\n"
-
- return StreamingResponse(streaming(), media_type="text/event-stream")
-
@self.app.post("/v1/completions")
async def completions():
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
+ @self.app.post("/v1/images/generations")
+ async def images_generate(config: ImagesGenerateForm, request: Request = None, provider: str = None):
+ try:
+ config.provider = provider if config.provider is None else config.provider
+ if config.api_key is None and request is not None:
+ auth_header = request.headers.get("Authorization")
+ if auth_header is not None:
+ auth_header = auth_header.split(None, 1)[-1]
+ if auth_header and auth_header != "Bearer":
+ config.api_key = auth_header
+ response = self.client.images.generate(
+ **config.dict(exclude_none=True),
+ )
+ return JSONResponse((await response).to_json())
+ except Exception as e:
+ logging.exception(e)
+ return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
+
def format_exception(e: Exception, config: ChatCompletionsForm) -> str:
last_provider = g4f.get_last_provider(True)
return json.dumps({
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 07ad3357..1508e566 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -3,6 +3,9 @@ from __future__ import annotations
import time
import random
import string
+import asyncio
+import base64
+from aiohttp import ClientSession, BaseConnector
from .types import Client as BaseClient
from .types import ProviderType, FinishReason
@@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider
+from ..Provider import ProviderUtils
from ..typing import Union, Messages, AsyncIterator, ImageType
-from ..errors import NoImageResponseError
-from ..image import ImageResponse as ImageProviderResponse
+from ..errors import NoImageResponseError, ProviderNotFoundError
+from ..requests.aiohttp import get_connector
+from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse
try:
anext
@@ -156,12 +161,28 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider)
-async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
+async def iter_image_response(
+ response: AsyncIterator,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None
+) -> Union[ImagesResponse, None]:
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
- return ImagesResponse([Image(image) for image in chunk.get_list()])
+ if response_format == "b64_json":
+ async with ClientSession(
+ connector=get_connector(connector, proxy)
+ ) as session:
+ async def fetch_image(image):
+ async with session.get(image) as response:
+ return base64.b64encode(await response.content.read()).decode()
+ images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()])
+ return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time()))
+ return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time()))
+ elif isinstance(chunk, ImageDataResponse):
+ return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time()))
-def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
+def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator:
prompt = f"create a image with: {prompt}"
if provider.__name__ == "You":
kwargs["chat_mode"] = "create"
@@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model
model,
[{"role": "user", "content": prompt}],
stream=True,
- proxy=client.get_proxy(),
**kwargs
)
@@ -179,31 +199,71 @@ class Images():
self.provider: ImageProvider = provider
self.models: ImageModels = ImageModels(client)
- async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
- provider = self.models.get(model, self.provider)
+ def get_provider(self, model: str, provider: ProviderType = None):
+ if isinstance(provider, str):
+ if provider in ProviderUtils.convert:
+ provider = ProviderUtils.convert[provider]
+ else:
+ raise ProviderNotFoundError(f'Provider not found: {provider}')
+ else:
+ provider = self.models.get(model, self.provider)
+ return provider
+
+ async def generate(
+ self,
+ prompt,
+ model: str = "",
+ provider: ProviderType = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ) -> ImagesResponse:
+ provider = self.get_provider(model, provider)
if hasattr(provider, "create_async_generator"):
- response = create_image(self.client, provider, prompt, **kwargs)
+ response = create_image(
+ provider,
+ prompt,
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
+ **kwargs
+ )
else:
response = await provider.create_async(prompt)
return ImagesResponse([Image(image) for image in response.get_list()])
- image = await iter_image_response(response)
+ image = await iter_image_response(response, response_format, connector, proxy)
if image is None:
raise NoImageResponseError()
return image
- async def create_variation(self, image: ImageType, model: str = None, **kwargs):
- provider = self.models.get(model, self.provider)
+ async def create_variation(
+ self,
+ image: ImageType,
+ model: str = None,
+ response_format: str = None,
+ connector: BaseConnector = None,
+ proxy: str = None,
+ **kwargs
+ ):
+ provider = self.get_provider(model, provider)
result = None
if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator(
"",
[{"role": "user", "content": "create a image like this"}],
- True,
+ stream=True,
image=image,
- proxy=self.client.get_proxy(),
+ **filter_none(
+ response_format=response_format,
+ connector=connector,
+ proxy=self.client.get_proxy() if proxy is None else proxy,
+ ),
**kwargs
)
- result = iter_image_response(response)
+ result = iter_image_response(response, response_format, connector, proxy)
if result is None:
raise NoImageResponseError()
return result
diff --git a/g4f/client/service.py b/g4f/client/service.py
index dd6bf4b6..5fdb150c 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -4,7 +4,7 @@ from typing import Union
from .. import debug, version
from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorkingError, StreamNotSupportedError
-from ..models import Model, ModelUtils
+from ..models import Model, ModelUtils, default
from ..Provider import ProviderUtils
from ..providers.types import BaseRetryProvider, ProviderType
from ..providers.retry_provider import IterProvider
@@ -60,7 +60,9 @@ def get_model_and_provider(model : Union[Model, str],
model = ModelUtils.convert[model]
if not provider:
- if isinstance(model, str):
+ if not model:
+ model = default
+ elif isinstance(model, str):
raise ModelNotFoundError(f'Model not found: {model}')
provider = model.best_provider
diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py
index ceb51b83..8cf2bcba 100644
--- a/g4f/client/stubs.py
+++ b/g4f/client/stubs.py
@@ -96,13 +96,24 @@ class ChatCompletionDeltaChoice(Model):
}
class Image(Model):
- url: str
+ def __init__(self, url: str = None, b64_json: str = None, revised_prompt: str = None) -> None:
+ if url is not None:
+ self.url = url
+ if b64_json is not None:
+ self.b64_json = b64_json
+ if revised_prompt is not None:
+ self.revised_prompt = revised_prompt
- def __init__(self, url: str) -> None:
- self.url = url
+ def to_json(self):
+ return self.__dict__
class ImagesResponse(Model):
- data: list[Image]
-
- def __init__(self, data: list) -> None:
+ def __init__(self, data: list[Image], created: int = 0) -> None:
self.data = data
+ self.created = created
+
+ def to_json(self):
+ return {
+ **self.__dict__,
+ "data": [image.to_json() for image in self.data]
+ } \ No newline at end of file
diff --git a/g4f/image.py b/g4f/image.py
index 270b59ad..3d339266 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -275,6 +275,18 @@ class ImagePreview(ImageResponse):
def to_string(self):
return super().__str__()
+class ImageDataResponse():
+ def __init__(
+ self,
+ images: Union[str, list],
+ alt: str,
+ ):
+ self.images = images
+ self.alt = alt
+
+ def get_list(self) -> list[str]:
+ return [self.images] if isinstance(self.images, str) else self.images
+
class ImageRequest:
def __init__(
self,