# K-Sort-Arena / model/models/replicate_api_models.py
# (HuggingFace file-view header residue, commented out so the module parses:
#  "YangZhoumill's picture / pre-release / b257e01 / raw / history blame / 16.3 kB")
import replicate
from PIL import Image
import requests
import io
import os
import base64
# Maps a human-readable model name to its Replicate identifier
# ("owner/name" or pinned "owner/name:version-hash").
# NOTE(review): the constant name contains a lowercase "l" in "MODEl" —
# kept as-is because external modules may import it under this exact name.
Replicate_MODEl_NAME_MAP = {
    "SDXL": "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    "SD-v3.0": "stability-ai/stable-diffusion-3",
    "SD-v2.1": "stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4",
    "SD-v1.5": "stability-ai/stable-diffusion:b3d14e1cd1f9470bbb0bb68cac48e5f483e5be309551992cc33dc30654a82bb7",
    "SDXL-Lightning": "bytedance/sdxl-lightning-4step:5f24084160c9089501c1b3545d9be3c27883ae2239b6f412990e82d4a6210f8f",
    "Kandinsky-v2.0": "ai-forever/kandinsky-2:3c6374e7a9a17e01afe306a5218cc67de55b19ea536466d6ea2602cfecea40a9",
    "Kandinsky-v2.2": "ai-forever/kandinsky-2.2:ad9d7879fbffa2874e1d909d1d37d9bc682889cc65b31f7bb00d2362619f194a",
    "Proteus-v0.2": "lucataco/proteus-v0.2:06775cd262843edbde5abab958abdbb65a0a6b58ca301c9fd78fa55c775fc019",
    "Playground-v2.0": "playgroundai/playground-v2-1024px-aesthetic:42fe626e41cc811eaf02c94b892774839268ce1994ea778eba97103fe1ef51b8",
    "Playground-v2.5": "playgroundai/playground-v2.5-1024px-aesthetic:a45f82a1382bed5c7aeb861dac7c7d191b0fdf74d8d57c4a0e6ed7d4d0bf7d24",
    "Dreamshaper-xl-turbo": "lucataco/dreamshaper-xl-turbo:0a1710e0187b01a255302738ca0158ff02a22f4638679533e111082f9dd1b615",
    "SDXL-Deepcache": "lucataco/sdxl-deepcache:eaf678fb34006669e9a3c6dd5971e2279bf20ee0adeced464d7b6d95de16dc93",
    "Openjourney-v4": "prompthero/openjourney:ad59ca21177f9e217b9075e7300cf6e14f7e5b4505b87b9689dbd866e9768969",
    "LCM-v1.5": "fofr/latent-consistency-model:683d19dc312f7a9f0428b04429a9ccefd28dbf7785fef083ad5cf991b65f406f",
    "Realvisxl-v3.0": "fofr/realvisxl-v3:33279060bbbb8858700eb2146350a98d96ef334fcf817f37eb05915e1534aa1c",
    "Realvisxl-v2.0": "lucataco/realvisxl-v2.0:7d6a2f9c4754477b12c14ed2a58f89bb85128edcdd581d24ce58b6926029de08",
    "Pixart-Sigma": "cjwbw/pixart-sigma:5a54352c99d9fef467986bc8f3a20205e8712cbd3df1cbae4975d6254c902de1",
    "SSD-1b": "lucataco/ssd-1b:b19e3639452c59ce8295b82aba70a231404cb062f2eb580ea894b31e8ce5bbb6",
    "Open-Dalle-v1.1": "lucataco/open-dalle-v1.1:1c7d4c8dec39c7306df7794b28419078cb9d18b9213ab1c21fdc46a1deca0144",
    "Deepfloyd-IF": "andreasjansson/deepfloyd-if:fb84d659df149f4515c351e394d22222a94144aa1403870c36025c8b28846c8d",
    "Zeroscope-v2-xl": "anotherjesse/zeroscope-v2-xl:9f747673945c62801b13b84701c783929c0ee784e4748ec062204894dda1a351",
    # "Damo-Text-to-Video": "cjwbw/damo-text-to-video:1e205ea73084bd17a0a3b43396e49ba0d6bc2e754e9283b2df49fad2dcf95755",
    "Animate-Diff": "lucataco/animate-diff:beecf59c4aee8d81bf04f0381033dfa10dc16e845b4ae00d281e2fa377e48a9f",
    "OpenSora": "camenduru/open-sora:8099e5722ba3d5f408cd3e696e6df058137056268939337a3fbe3912e86e72ad",
    "LaVie": "cjwbw/lavie:0bca850c4928b6c30052541fa002f24cbb4b677259c461dd041d271ba9d3c517",
    "VideoCrafter2": "lucataco/video-crafter:7757c5775e962c618053e7df4343052a21075676d6234e8ede5fa67c9e43bce0",
    "Stable-Video-Diffusion": "sunfjun/stable-video-diffusion:d68b6e09eedbac7a49e3d8644999d93579c386a083768235cabca88796d70d82",
    "FLUX.1-schnell": "black-forest-labs/flux-schnell",
    "FLUX.1-pro": "black-forest-labs/flux-pro",
    "FLUX.1-dev": "black-forest-labs/flux-dev",
}
class ReplicateModel:
    """Callable wrapper around the Replicate API.

    Resolves a human-readable model name to a Replicate identifier via
    ``Replicate_MODEl_NAME_MAP`` and runs the model when the instance is
    called. text2image runs return a ``PIL.Image``; text2video runs return
    the URL (str) of the generated video.
    """

    def __init__(self, model_name, model_type):
        # model_name: a key of Replicate_MODEl_NAME_MAP
        # model_type: "text2image" or "text2video"
        self.model_name = model_name
        self.model_type = model_type

    def __call__(self, *args, **kwargs):
        """Run the wrapped model; ``prompt`` must be passed as a keyword.

        Returns:
            PIL.Image.Image for text2image models, or the generated video's
            URL (str) for text2video models.
        Raises:
            ValueError: for an unknown ``model_type`` or an unsupported
                text2video model name.
            AssertionError: when ``prompt`` is missing.
        """
        if self.model_type == "text2image":
            assert "prompt" in kwargs, "prompt is required for text2image model"
            output = replicate.run(
                Replicate_MODEl_NAME_MAP[self.model_name],
                input={
                    "width": 512,
                    "height": 512,
                    "prompt": kwargs["prompt"],
                },
            )
            # Openjourney yields its results from a generator; take the first.
            if 'Openjourney' in self.model_name:
                result_url = next(iter(output))
            elif isinstance(output, list):
                result_url = output[0]
            else:
                result_url = output
            print(self.model_name, result_url)
            # Download the rendered image; a timeout prevents an indefinite hang.
            response = requests.get(result_url, timeout=600)
            return Image.open(io.BytesIO(response.content))
        elif self.model_type == "text2video":
            # BUG FIX: the original assert message said "text2image" here.
            assert "prompt" in kwargs, "prompt is required for text2video model"
            payload = self._video_input(kwargs["prompt"])
            output = replicate.run(
                Replicate_MODEl_NAME_MAP[self.model_name],
                input=payload,
            )
            result_url = output[0] if isinstance(output, list) else output
            print(self.model_name)
            print(result_url)
            # Videos are returned as a URL; downloading is left to the caller.
            return result_url
        else:
            # BUG FIX: the original message wrongly offered "image2image".
            raise ValueError("model_type must be text2image or text2video")

    def _video_input(self, prompt):
        """Build the per-model ``input`` dict for a text2video run.

        Raises:
            ValueError: for a model name with no configuration (the original
                code fell through and crashed with a NameError on ``input``).
        """
        if self.model_name == "Zeroscope-v2-xl":
            return {
                "fps": 24,
                "width": 512,
                "height": 512,
                "prompt": prompt,
                "guidance_scale": 17.5,
                # "negative_prompt": "very blue, dust, noisy, washed out, ugly, distorted, broken",
                "num_frames": 48,
            }
        if self.model_name == "Damo-Text-to-Video":
            return {
                "fps": 8,
                "prompt": prompt,
                "num_frames": 16,
                "num_inference_steps": 50,
            }
        if self.model_name == "Animate-Diff":
            return {
                "path": "toonyou_beta3.safetensors",
                "seed": 255224557,
                "steps": 25,
                "prompt": prompt,
                "n_prompt": "badhandv4, easynegative, ng_deepnegative_v1_75t, verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth",
                "motion_module": "mm_sd_v14",
                "guidance_scale": 7.5,
            }
        if self.model_name == "OpenSora":
            return {
                "seed": 1234,
                "prompt": prompt,
            }
        if self.model_name == "LaVie":
            return {
                "width": 512,
                "height": 512,
                "prompt": prompt,
                "quality": 9,
                "video_fps": 8,
                "interpolation": False,
                "sample_method": "ddpm",
                "guidance_scale": 7,
                "super_resolution": False,
                "num_inference_steps": 50,
            }
        if self.model_name == "VideoCrafter2":
            return {
                "fps": 24,
                "seed": 64045,
                "steps": 40,
                "width": 512,
                "height": 512,
                "prompt": prompt,
            }
        if self.model_name == "Stable-Video-Diffusion":
            # Two-stage pipeline: render a still image with SD-v2.1 first,
            # then animate it with SVD conditioned on that image's URL.
            output = replicate.run(
                Replicate_MODEl_NAME_MAP["SD-v2.1"],
                input={
                    "width": 512,
                    "height": 512,
                    "prompt": prompt,
                },
            )
            image_url = output[0] if isinstance(output, list) else output
            print(image_url)
            return {
                "cond_aug": 0.02,
                "decoding_t": 14,
                "input_image": "{}".format(image_url),
                "video_length": "14_frames_with_svd",
                "sizing_strategy": "maintain_aspect_ratio",
                "motion_bucket_id": 127,
                "frames_per_second": 6,
            }
        raise ValueError(f"unsupported text2video model: {self.model_name}")
def load_replicate_model(model_name, model_type):
    """Factory: build a ReplicateModel wrapper for the given model name and task type."""
    wrapper = ReplicateModel(model_name, model_type)
    return wrapper
if __name__ == "__main__":
    # One-off maintenance script (commented-out variants below were used for
    # other models / other cache folders). The active part copies Pika-v1.0
    # clips from one cache directory tree into another, renaming each
    # "output..." folder to "output-" + its last three characters.
    import replicate
    import time
    import concurrent.futures
    import os, shutil, re
    import requests
    from moviepy.editor import VideoFileClip
    # model_name = 'replicate_zeroscope-v2-xl_text2video'
    # model_name = 'replicate_Damo-Text-to-Video_text2video'
    # model_name = 'replicate_Animate-Diff_text2video'
    # model_name = 'replicate_open-sora_text2video'
    # model_name = 'replicate_lavie_text2video'
    # model_name = 'replicate_video-crafter_text2video'
    # model_name = 'replicate_stable-video-diffusion_text2video'
    # model_source, model_name, model_type = model_name.split("_")
    # pipe = load_replicate_model(model_name, model_type)
    # prompt = "Clown fish swimming in a coral reef, beautiful, 8k, perfect, award winning, national geographic"
    # result = pipe(prompt=prompt)
    # # File copy
    # NOTE(review): paths are hardcoded to a specific machine; the destination
    # directories are assumed to already exist (shutil.copy does not create them).
    source_folder = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
    destination_folder = '/mnt/data/lizhikai/ksort_video_cache/Advance/'
    special_char = 'output'
    for dirpath, dirnames, filenames in os.walk(source_folder):
        for dirname in dirnames:
            # e.g. "output123" -> "output-123" (assumes a 3-char suffix — TODO confirm)
            des_dirname = "output-"+dirname[-3:]
            print(des_dirname)
            if special_char in dirname:
                model_name = ["Pika-v1.0"]
                for name in model_name:
                    source_file_path = os.path.join(source_folder, os.path.join(dirname, name+".mp4"))
                    print(source_file_path)
                    destination_file_path = os.path.join(destination_folder, os.path.join(des_dirname, name+".mp4"))
                    print(destination_file_path)
                    shutil.copy(source_file_path, destination_file_path)
    # Video cropping
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1.0add/'
# special_char = 'output'
# num = 0
# for dirpath, dirnames, filenames in os.walk(root_dir):
# for dirname in dirnames:
    #     # If the folder name contains the special marker substring
# if special_char in dirname:
# num = num+1
# print(num)
# if num < 0:
# continue
# video_path = os.path.join(root_dir, (os.path.join(dirname, f"{dirname}.mp4")))
# out_video_path = os.path.join(root_dir, (os.path.join(dirname, f"Pika-v1.0.mp4")))
# print(video_path)
# print(out_video_path)
# video = VideoFileClip(video_path)
# width, height = video.size
# center_x, center_y = width // 2, height // 2
# new_width, new_height = 512, 512
# cropped_video = video.crop(x_center=center_x, y_center=center_y, width=min(width, height), height=min(width, height))
# resized_video = cropped_video.resize(newsize=(new_width, new_height))
# resized_video.write_videofile(out_video_path, codec='libx264', fps=video.fps)
# os.remove(video_path)
# file_path = '/home/lizhikai/webvid_prompt100.txt'
# str_list = []
# with open(file_path, 'r', encoding='utf-8') as file:
# for line in file:
# str_list.append(line.strip())
# if len(str_list) == 100:
# break
    # Generation code
# def generate_image_ig_api(prompt, model_name):
# model_source, model_name, model_type = model_name.split("_")
# pipe = load_replicate_model(model_name, model_type)
# result = pipe(prompt=prompt)
# return result
# model_names = ['replicate_Zeroscope-v2-xl_text2video',
# # 'replicate_Damo-Text-to-Video_text2video',
# 'replicate_Animate-Diff_text2video',
# 'replicate_OpenSora_text2video',
# 'replicate_LaVie_text2video',
# 'replicate_VideoCrafter2_text2video',
# 'replicate_Stable-Video-Diffusion_text2video',
# ]
# save_names = []
# for name in model_names:
# model_source, model_name, model_type = name.split("_")
# save_names.append(model_name)
    # # Walk the root directory and its subdirectories
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen3/'
# root_dir = '/mnt/data/lizhikai/ksort_video_cache/Runway-Gen2/'
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-Beta/'
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Pika-v1/'
# # root_dir = '/mnt/data/lizhikai/ksort_video_cache/Sora/'
# special_char = 'output'
# num = 0
# for dirpath, dirnames, filenames in os.walk(root_dir):
# for dirname in dirnames:
    #     # If the folder name contains the special marker substring
# if special_char in dirname:
# num = num+1
# print(num)
# if num < 0:
# continue
# str_list = []
# prompt_path = os.path.join(root_dir, (os.path.join(dirname, "prompt.txt")))
# print(prompt_path)
# with open(prompt_path, 'r', encoding='utf-8') as file:
# for line in file:
# str_list.append(line.strip())
# prompt = str_list[0]
# print(prompt)
# with concurrent.futures.ThreadPoolExecutor() as executor:
# futures = [executor.submit(generate_image_ig_api, prompt, model) for model in model_names]
# results = [future.result() for future in futures]
    # # Download the videos and save them
# repeat_num = 5
# for j, url in enumerate(results):
# while 1:
# time.sleep(1)
# response = requests.get(url, stream=True)
# if response.status_code == 200:
# file_path = os.path.join(os.path.join(root_dir, dirname), f'{save_names[j]}.mp4')
# with open(file_path, 'wb') as file:
# for chunk in response.iter_content(chunk_size=8192):
# file.write(chunk)
# print(f"视频 {j} 已保存到 {file_path}")
# break
# else:
# repeat_num = repeat_num - 1
# if repeat_num == 0:
# print(f"视频 {j} 保存失败")
# # raise ValueError("Video request failed.")
# continue