# K-Sort-Arena / model/models/generate_image_cache.py
import concurrent.futures
import io
import json
import os
import time

import requests
from PIL import Image

from huggingface_models import load_huggingface_model
from openai_api_models import load_openai_model
from other_api_models import load_other_model
from replicate_api_models import load_replicate_model
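
# Models to query for this cache run; commented-out entries are skipped.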
IMAGE_GENERATION_MODELS = [
# 'replicate_SDXL_text2image',
# 'replicate_SD-v3.0_text2image',
# 'replicate_SD-v2.1_text2image',
# 'replicate_SD-v1.5_text2image',
# 'replicate_SDXL-Lightning_text2image',
# 'replicate_Kandinsky-v2.0_text2image',
# 'replicate_Kandinsky-v2.2_text2image',
# 'replicate_Proteus-v0.2_text2image',
# 'replicate_Playground-v2.0_text2image',
# 'replicate_Playground-v2.5_text2image',
# 'replicate_Dreamshaper-xl-turbo_text2image',
# 'replicate_SDXL-Deepcache_text2image',
# 'replicate_Openjourney-v4_text2image',
# 'replicate_LCM-v1.5_text2image',
# 'replicate_Realvisxl-v3.0_text2image',
# 'replicate_Realvisxl-v2.0_text2image',
# 'replicate_Pixart-Sigma_text2image',
# 'replicate_SSD-1b_text2image',
# 'replicate_Open-Dalle-v1.1_text2image',
# 'replicate_Deepfloyd-IF_text2image',
# 'huggingface_SD-turbo_text2image',
# 'huggingface_SDXL-turbo_text2image',
# 'huggingface_Stable-cascade_text2image',
# 'openai_Dalle-2_text2image',
# 'openai_Dalle-3_text2image',
'other_Midjourney-v6.0_text2image',
'other_Midjourney-v5.0_text2image',
# "replicate_FLUX.1-schnell_text2image",
# "replicate_FLUX.1-pro_text2image",
# "replicate_FLUX.1-dev_text2image",
]
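
# Prompts to cache; commented-out prompts are skipped. The list is currently
# empty, so nothing is generated until at least one prompt is uncommented.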
Prompts = [
# 'An aerial view of someone walking through a forest alone in the style of Romanticism.',
# 'With dark tones and backlit resolution, this oil painting depicts a thunderstorm over a cityscape.',
# 'The rendering depicts a futuristic train station with volumetric lighting in an Art Nouveau style.',
# 'An Impressionist illustration depicts a river winding through a meadow.', # featuring a thick black outline
# 'Photo of a black and white picture of a person facing the sunset from a bench.',
# 'The skyline of a city is painted in bright, high-resolution colors.',
# 'A sketch shows two robots talking to each other, featuring a surreal look and narrow aspect ratio.',
# 'An abstract Dadaist collage in neon tones and 4K resolutions of a post-apocalyptic world.',
# 'With abstract elements and a rococo style, the painting depicts a garden in high resolution.',
# 'A picture of a senior man walking in the rain and looking directly at the camera from a medium distance.',
]


def load_pipeline(model_name):
    """Resolve a '<source>_<model>_<task>' identifier to a generation pipeline."""
    model_source, model_name, model_type = model_name.split("_")
    if model_source == "replicate":
        pipe = load_replicate_model(model_name, model_type)
    elif model_source == "huggingface":
        pipe = load_huggingface_model(model_name, model_type)
    elif model_source == "openai":
        pipe = load_openai_model(model_name, model_type)
    elif model_source == "other":
        pipe = load_other_model(model_name, model_type)
    else:
        raise ValueError(f"Model source {model_source} not supported")
    return pipe


def generate_image_ig_api(prompt, model_name):
    """Generate one image for `prompt` with the given model; returns a PIL image."""
    pipe = load_pipeline(model_name)
    result = pipe(prompt=prompt)
    return result
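

# Build the image cache: for every prompt, query all active models in parallel
# and write the resized results (plus the prompt text) into one output folder.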
# Short names (the middle field of each identifier) are used as output file names.
save_names = []
for name in IMAGE_GENERATION_MODELS:
    model_source, model_name, model_type = name.split("_")
    save_names.append(model_name)

for i, prompt in enumerate(Prompts):
    print("Saving images for prompt {}".format(i + 1))
    # Query all models concurrently; results keep the order of IMAGE_GENERATION_MODELS.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(generate_image_ig_api, prompt, model) for model in IMAGE_GENERATION_MODELS]
        results = [future.result() for future in futures]

    root_dir = '/rscratch/zhendong/lizhikai/ksort/ksort_image_cache/'
    # NOTE: the folder index is offset by 4, presumably to continue numbering from earlier runs.
    save_dir = os.path.join(root_dir, f'output-{i+4}')
    os.makedirs(save_dir, exist_ok=True)

    # Record the prompt alongside its images.
    with open(os.path.join(save_dir, "prompt.txt"), 'w', encoding='utf-8') as file:
        file.write(prompt)

    # Resize each result to 512x512 and save it under the generating model's name.
    for j, result in enumerate(results):
        result = result.resize((512, 512))
        file_path = os.path.join(save_dir, f'{save_names[j]}.jpg')
        result.save(file_path, format="JPEG")