Last commit not found
from io import BytesIO | |
import torch | |
from internals.data.dataAccessor import update_db | |
from internals.data.task import ModelType, Task, TaskType | |
from internals.pipelines.inpainter import InPainter | |
from internals.pipelines.object_remove import ObjectRemoval | |
from internals.pipelines.prompt_modifier import PromptModifier | |
from internals.pipelines.remove_background import RemoveBackground, RemoveBackgroundV2 | |
from internals.pipelines.replace_background import ReplaceBackground | |
from internals.pipelines.safety_checker import SafetyChecker | |
from internals.pipelines.upscaler import Upscaler | |
from internals.util.avatar import Avatar | |
from internals.util.cache import auto_clear_cuda_and_gc, clear_cuda | |
from internals.util.commons import construct_default_s3_url, upload_image, upload_images | |
from internals.util.config import ( | |
num_return_sequences, | |
set_configs_from_task, | |
set_root_dir, | |
) | |
from internals.util.failure_hander import FailureHandler | |
from internals.util.slack import Slack | |
# Throughput-oriented PyTorch backend flags: let cuDNN benchmark and pick
# convolution algorithms, and permit TF32 for CUDA matmuls.
torch.backends.cudnn.benchmark = torch.backends.cuda.matmul.allow_tf32 = True

# NOTE(review): not referenced by the handlers visible in this file — confirm
# it is still needed.
auto_mode: bool = False

# Process-wide singletons shared by every request handler below. They are
# constructed at import time (in this order); most of them are initialised
# afterwards by `model_fn` via their load/load_local/apply methods.
slack: Slack = Slack()  # used by predict_fn for error alerts
prompt_modifier: PromptModifier = PromptModifier(
    num_of_sequences=num_return_sequences
)
upscaler: Upscaler = Upscaler()
inpainter: InPainter = InPainter()
safety_checker: SafetyChecker = SafetyChecker()
object_removal: ObjectRemoval = ObjectRemoval()
# Not loaded directly; model_fn hands it to replace_background.load().
remove_background_v2: RemoveBackgroundV2 = RemoveBackgroundV2()
avatar: Avatar = Avatar()
replace_background: ReplaceBackground = ReplaceBackground()
def remove_bg(task: Task):
    """Remove the background from the task's image and upload the cut-out.

    A new ``RemoveBackground`` instance is constructed on every call (unlike
    the module-level ``remove_background_v2`` singleton).

    Returns:
        ``{"generated_image_url": <url>}`` where ``<url>`` is the default S3
        URL of the ``crecoAI/<taskId>_rmbg.png`` key the result was uploaded to.
    """
    remover = RemoveBackground()
    cutout = remover.remove(task.get_imageUrl())

    s3_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, s3_key)
    return dict(generated_image_url=construct_default_s3_url(s3_key))
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the results.

    The prompt is first passed through ``avatar.add_code_names``. When prompt
    engineering is enabled it is expanded by ``prompt_modifier.modify``;
    otherwise it is repeated ``num_return_sequences`` times. The negative
    prompt is always repeated ``num_return_sequences`` times.

    Returns:
        dict with ``"modified_prompts"`` (the prompt list given to the
        inpainter) and ``"generated_image_urls"`` (what ``upload_images``
        returned).
    """
    base_prompt = avatar.add_code_names(task.get_prompt())
    prompts = (
        prompt_modifier.modify(base_prompt)
        if task.is_prompt_engineering()
        else [base_prompt] * num_return_sequences
    )
    print({"prompts": prompts})

    images = inpainter.process(
        prompt=prompts,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )
    urls = upload_images(images, "_inpaint", task.get_taskId())

    # clear_cuda presumably releases cached GPU memory; done after upload.
    clear_cuda()
    return {"modified_prompts": prompts, "generated_image_urls": urls}
def remove_object(task: Task):
    """Erase the masked object from the task's image and upload the result.

    Only the first image returned by ``object_removal.process`` is uploaded.

    Returns:
        ``{"generated_image_urls": <url>}`` where ``<url>`` is the default S3
        URL of ``crecoAI/<taskId>_object_remove.png``, built with
        ``construct_default_s3_url`` exactly as ``remove_bg`` and
        ``upscale_image`` do. NOTE(review): the key is plural but holds a
        single URL string; kept as-is for response compatibility.
    """
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())
    images = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    upload_image(images[0], output_key)
    clear_cuda()
    # Derive the public URL from the key ourselves, like the other
    # single-image handlers, instead of relying on upload_image's return
    # value (which those handlers ignore).
    return {"generated_image_urls": construct_default_s3_url(output_key)}
def replace_bg(task: Task):
    """Run the background-replacement pipeline on the task's image and upload.

    The prompt is used as-is (it is not passed through
    ``avatar.add_code_names``). With prompt engineering enabled it is expanded
    by ``prompt_modifier.modify``; otherwise it is repeated
    ``num_return_sequences`` times, as is the negative prompt.

    Returns:
        dict with ``"modified_prompts"`` (prompt list sent to the pipeline),
        ``"generated_image_urls"`` (what ``upload_images`` returned) and
        ``"has_nsfw"`` (the flag returned by ``replace_background.replace``).
    """
    prompt = task.get_prompt()
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    images, has_nsfw = replace_background.replace(
        image=task.get_imageUrl(),
        prompt=prompt,
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
        steps=task.get_steps(),
        resize_dimension=task.get_resize_dimension(),
        product_scale_width=task.get_image_scale(),
    )
    generated_image_urls = upload_images(images, "_replace_bg", task.get_taskId())

    # Clear CUDA state after the pipeline run, as inpaint and remove_object do;
    # without it this handler was the only GPU pipeline that skipped cleanup.
    clear_cuda()
    return {
        "modified_prompts": prompt,
        "generated_image_urls": generated_image_urls,
        "has_nsfw": has_nsfw,
    }
def upscale_image(task: Task):
    """Upscale the task's image with either the anime or the "real" model.

    ``ModelType.ANIME`` selects ``upscaler.upscale_anime``; any other model
    type uses ``upscaler.upscale``. The upscaler output is wrapped in
    ``BytesIO`` before upload, so it is presumably raw encoded bytes
    (TODO confirm against Upscaler).

    Returns:
        ``{"generated_image_url": <url>}`` — the default S3 URL of
        ``crecoAI/<taskId>_upscale.png``.
    """
    s3_key = f"crecoAI/{task.get_taskId()}_upscale.png"

    use_anime = task.get_modelType() == ModelType.ANIME
    print("Using Anime model" if use_anime else "Using Real model")
    upscale = upscaler.upscale_anime if use_anime else upscaler.upscale

    upscaled = upscale(
        image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
    )
    upload_image(BytesIO(upscaled), s3_key)
    return {"generated_image_url": construct_default_s3_url(s3_key)}
def model_fn(model_dir):
    """Initialise all module-level pipelines (presumably the server's load hook).

    Args:
        model_dir: directory forwarded to ``avatar.load_local`` and
            ``object_removal.load``.

    Returns:
        None. The pipelines live in module globals; ``predict_fn`` does not
        use its ``pipe`` argument.
    """
    print("Logs: model loaded .... starts")
    set_root_dir(__file__)
    FailureHandler.register()

    # Executed strictly in this order. replace_background is given the
    # upscaler and remove_background_v2 instances, and the safety checker is
    # applied to the inpainter, so keep those steps after the loads they
    # presumably depend on.
    load_steps = (
        lambda: avatar.load_local(model_dir),
        lambda: prompt_modifier.load(),
        lambda: safety_checker.load(),
        lambda: object_removal.load(model_dir),
        lambda: upscaler.load(),
        lambda: inpainter.load(),
        lambda: replace_background.load(upscaler, remove_background_v2),
        lambda: safety_checker.apply(inpainter),
    )
    for step in load_steps:
        step()

    print("Logs: model loaded ....")
def predict_fn(data, pipe):
    """Handle one request: build a ``Task`` and dispatch it by task type.

    Args:
        data: request data, forwarded unchanged to ``Task``.
        pipe: unused; the pipelines are module-level globals.

    Returns:
        The selected handler's result dict, or ``None`` if anything inside the
        ``try`` block raised (the error is printed and sent to Slack).
    """
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Configure the environment for this task exactly once, inside the
        # try, so a configuration failure goes through the same Slack alert
        # path as handler failures instead of escaping unreported.
        set_configs_from_task(task)

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        elif task_type == TaskType.REPLACE_BG:
            return replace_bg(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None