diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..ac29c10c09ff8b93f0bca58563635d41ced58657 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +doc/deepfake_progress_source.psd filter=lfs diff=lfs merge=lfs -text +doc/deepfake_progress.png filter=lfs diff=lfs merge=lfs -text +merger/gfx/help_merger_face_avatar_source.psd filter=lfs diff=lfs merge=lfs -text +merger/gfx/help_merger_masked_source.psd filter=lfs diff=lfs merge=lfs -text diff --git a/CODEGUIDELINES b/CODEGUIDELINES new file mode 100644 index 0000000000000000000000000000000000000000..0d40a02f8eaee1ef4e7c29462a42b23698ec5f34 --- /dev/null +++ b/CODEGUIDELINES @@ -0,0 +1,5 @@ +Please don't ruin the code and this good (as I think) architecture. + +Please follow the same logic and brevity/pithiness. + +Don't abstract the code into huge classes if you only win some lines of code in one place, because this can prevent programmers from understanding it quickly. \ No newline at end of file diff --git a/DFLIMG/DFLIMG.py b/DFLIMG/DFLIMG.py new file mode 100644 index 0000000000000000000000000000000000000000..e213920c92e52823e81d051088e853afec279431 --- /dev/null +++ b/DFLIMG/DFLIMG.py @@ -0,0 +1,12 @@ +from pathlib import Path + +from .DFLJPG import DFLJPG + +class DFLIMG(): + + @staticmethod + def load(filepath, loader_func=None): + if filepath.suffix == '.jpg': + return DFLJPG.load ( str(filepath), loader_func=loader_func ) + else: + return None diff --git a/DFLIMG/DFLJPG.py b/DFLIMG/DFLJPG.py new file mode 100644 index 0000000000000000000000000000000000000000..614cebcedcfe3e976aba03a57b1b14feca550908 --- /dev/null +++ b/DFLIMG/DFLJPG.py @@ -0,0 +1,324 @@ +import pickle +import struct +import traceback + +import cv2 +import numpy as np + +from core import imagelib +from core.cv2ex import * +from core.imagelib import SegIEPolys +from core.interact import interact as io +from core.structex import * +from facelib import FaceType + + +class DFLJPG(object): + def __init__(self, filename): + self.filename = filename + self.data = b"" + self.length = 0 + self.chunks = [] + self.dfl_dict = None + self.shape = None + self.img = None + + @staticmethod + def load_raw(filename, loader_func=None): + try: + if loader_func is not None: + data = loader_func(filename) + else: + with open(filename, "rb") as f: + data = f.read() + except: + raise FileNotFoundError(filename) + + try: + inst = DFLJPG(filename) + inst.data = data + inst.length = len(data) + inst_length = inst.length + chunks = [] + data_counter = 0 + while data_counter < inst_length: + chunk_m_l, chunk_m_h = struct.unpack ("BB", data[data_counter:data_counter+2]) + data_counter += 2 + + if chunk_m_l != 0xFF: + raise ValueError(f"No Valid JPG info in {filename}") + + chunk_name = None + chunk_size = None + chunk_data = None + chunk_ex_data = None + is_unk_chunk = False + + if chunk_m_h & 0xF0 == 0xD0: + n = chunk_m_h & 0x0F + + if n >= 0 and n <= 7: + chunk_name = "RST%d" % (n) + chunk_size = 0 + elif n == 0x8: + chunk_name = "SOI" + chunk_size = 0 + if len(chunks) != 0: + raise Exception("") + elif n == 0x9: + chunk_name = "EOI" + chunk_size = 0 + elif n == 0xA: + chunk_name = "SOS" + elif n == 0xB: + chunk_name = "DQT" + elif n == 0xD: + chunk_name = "DRI" + chunk_size = 2 + else: + is_unk_chunk = True + elif chunk_m_h & 0xF0 == 0xC0: + n = chunk_m_h & 0x0F + if n == 0: + chunk_name = "SOF0" + elif n == 2: + chunk_name = "SOF2" + elif n == 4: + chunk_name = "DHT" + else: + is_unk_chunk = True + elif chunk_m_h & 0xF0 == 0xE0: + n = chunk_m_h & 0x0F + chunk_name = "APP%d" % (n) + else: + is_unk_chunk = True + + #if is_unk_chunk: + # #raise ValueError(f"Unknown chunk {chunk_m_h} in {filename}") + # io.log_info(f"Unknown chunk {chunk_m_h} in {filename}") + + if chunk_size == None: #variable size + chunk_size, = struct.unpack (">H", data[data_counter:data_counter+2]) + chunk_size -= 2 + data_counter += 2 + + if chunk_size > 0: + chunk_data = data[data_counter:data_counter+chunk_size] + data_counter += chunk_size + + if chunk_name == "SOS": + c = data_counter + while c < inst_length and (data[c] != 0xFF or data[c+1] != 0xD9): + c += 1 + + chunk_ex_data = data[data_counter:c] + data_counter = c + + chunks.append ({'name' : chunk_name, + 'm_h' : chunk_m_h, + 'data' : chunk_data, + 'ex_data' : chunk_ex_data, + }) + inst.chunks = chunks + + return inst + except Exception as e: + raise Exception (f"Corrupted JPG file {filename} {e}") + + @staticmethod + def load(filename, loader_func=None): + try: + inst = DFLJPG.load_raw (filename, loader_func=loader_func) + inst.dfl_dict = {} + + for chunk in inst.chunks: + if chunk['name'] == 'APP0': + d, c = chunk['data'], 0 + c, id, _ = struct_unpack (d, c, "=4sB") + + if id == b"JFIF": + c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB") + else: + raise Exception("Unknown jpeg ID: %s" % (id) ) + elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2': + d, c = chunk['data'], 0 + c, precision, height, width = struct_unpack (d, c, ">BHH") + inst.shape = (height, width, 3) + + elif chunk['name'] == 'APP15': + if type(chunk['data']) == bytes: + inst.dfl_dict = pickle.loads(chunk['data']) + + return inst + except Exception as e: + io.log_err (f'Exception occured while DFLJPG.load : {traceback.format_exc()}') + return None + + def has_data(self): + return len(self.dfl_dict.keys()) != 0 + + def save(self): + try: + with open(self.filename, "wb") as f: + f.write ( self.dump() ) + except: + raise Exception( f'cannot save {self.filename}' ) + + def dump(self): + data = b"" + + dict_data = self.dfl_dict + + # Remove None keys + for key in list(dict_data.keys()): + if dict_data[key] is None: + dict_data.pop(key) + + for chunk in self.chunks: + if chunk['name'] == 'APP15': + self.chunks.remove(chunk) + break + + last_app_chunk = 0 + for i, chunk in enumerate (self.chunks): + if chunk['m_h'] & 0xF0 == 0xE0: + last_app_chunk = i + + dflchunk = {'name' : 'APP15', + 'm_h' : 0xEF, + 'data' : pickle.dumps(dict_data), + 'ex_data' : None, + } + self.chunks.insert (last_app_chunk+1, dflchunk) + + + for chunk in self.chunks: + data += struct.pack ("BB", 0xFF, chunk['m_h'] ) + chunk_data = chunk['data'] + if chunk_data is not None: + data += struct.pack (">H", len(chunk_data)+2 ) + data += chunk_data + + chunk_ex_data = chunk['ex_data'] + if chunk_ex_data is not None: + data += chunk_ex_data + + return data + + def get_img(self): + if self.img is None: + self.img = cv2_imread(self.filename) + return self.img + + def get_shape(self): + if self.shape is None: + img = self.get_img() + if img is not None: + self.shape = img.shape + return self.shape + + def get_height(self): + for chunk in self.chunks: + if type(chunk) == IHDR: + return chunk.height + return 0 + + def get_dict(self): + return self.dfl_dict + + def set_dict (self, dict_data=None): + self.dfl_dict = dict_data + + def get_face_type(self): return self.dfl_dict.get('face_type', FaceType.toString (FaceType.FULL) ) + def set_face_type(self, face_type): self.dfl_dict['face_type'] = face_type + + def get_landmarks(self): return np.array ( self.dfl_dict['landmarks'] ) + def set_landmarks(self, landmarks): self.dfl_dict['landmarks'] = landmarks + + def get_eyebrows_expand_mod(self): return self.dfl_dict.get ('eyebrows_expand_mod', 1.0) + def set_eyebrows_expand_mod(self, eyebrows_expand_mod): self.dfl_dict['eyebrows_expand_mod'] = eyebrows_expand_mod + + def get_source_filename(self): return self.dfl_dict.get ('source_filename', None) + def set_source_filename(self, source_filename): self.dfl_dict['source_filename'] = source_filename + + def get_source_rect(self): return self.dfl_dict.get ('source_rect', None) + def set_source_rect(self, source_rect): self.dfl_dict['source_rect'] = source_rect + + def get_source_landmarks(self): return np.array ( self.dfl_dict.get('source_landmarks', None) ) + def set_source_landmarks(self, source_landmarks): self.dfl_dict['source_landmarks'] = source_landmarks + + def get_image_to_face_mat(self): + mat = self.dfl_dict.get ('image_to_face_mat', None) + if mat is not None: + return np.array (mat) + return None + def set_image_to_face_mat(self, image_to_face_mat): self.dfl_dict['image_to_face_mat'] = image_to_face_mat + + def has_seg_ie_polys(self): + return self.dfl_dict.get('seg_ie_polys',None) is not None + + def get_seg_ie_polys(self): + d = self.dfl_dict.get('seg_ie_polys',None) + if d is not None: + d = SegIEPolys.load(d) + else: + d = SegIEPolys() + + return d + + def set_seg_ie_polys(self, seg_ie_polys): + if seg_ie_polys is not None: + if not isinstance(seg_ie_polys, SegIEPolys): + raise ValueError('seg_ie_polys should be instance of SegIEPolys') + + if seg_ie_polys.has_polys(): + seg_ie_polys = seg_ie_polys.dump() + else: + seg_ie_polys = None + + self.dfl_dict['seg_ie_polys'] = seg_ie_polys + + def has_xseg_mask(self): + return self.dfl_dict.get('xseg_mask',None) is not None + + def get_xseg_mask_compressed(self): + mask_buf = self.dfl_dict.get('xseg_mask',None) + if mask_buf is None: + return None + + return mask_buf + + def get_xseg_mask(self): + mask_buf = self.dfl_dict.get('xseg_mask',None) + if mask_buf is None: + return None + + img = cv2.imdecode(mask_buf, cv2.IMREAD_UNCHANGED) + if len(img.shape) == 2: + img = img[...,None] + + return img.astype(np.float32) / 255.0 + + + def set_xseg_mask(self, mask_a): + if mask_a is None: + self.dfl_dict['xseg_mask'] = None + return + + mask_a = imagelib.normalize_channels(mask_a, 1) + img_data = np.clip( mask_a*255, 0, 255 ).astype(np.uint8) + + data_max_len = 50000 + + ret, buf = cv2.imencode('.png', img_data) + + if not ret or len(buf) > data_max_len: + for jpeg_quality in range(100,-1,-1): + ret, buf = cv2.imencode( '.jpg', img_data, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] ) + if ret and len(buf) <= data_max_len: + break + + if not ret: + raise Exception("set_xseg_mask: unable to generate image data for set_xseg_mask") + + self.dfl_dict['xseg_mask'] = buf diff --git a/DFLIMG/__init__.py b/DFLIMG/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62dac28f6246005d76fe7f452f0af9195054d09a --- /dev/null +++ b/DFLIMG/__init__.py @@ -0,0 +1,2 @@ +from .DFLIMG import DFLIMG +from .DFLJPG import DFLJPG \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..94a9ed024d3859793618152ea559a168bbcbb5e2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md index 154df8298fab5ecf322016157858e08cd1bccbe1..9c7e467b69cd0da1c866fd50e9490a99308c8948 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,429 @@ ---- -license: apache-2.0 ---- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +# DeepFaceLab + + + + +https://arxiv.org/abs/2005.05535 + + +### the leading software for creating deepfakes + + + +
+ +

+ +![](doc/logo_tensorflow.png) +![](doc/logo_cuda.png) +![](doc/logo_directx.png) + +

+ +More than 95% of deepfake videos are created with DeepFaceLab. + +DeepFaceLab is used by such popular youtube channels as + +|![](doc/tiktok_icon.png) [deeptomcruise](https://www.tiktok.com/@deeptomcruise)|![](doc/tiktok_icon.png) [1facerussia](https://www.tiktok.com/@1facerussia)|![](doc/tiktok_icon.png) [arnoldschwarzneggar](https://www.tiktok.com/@arnoldschwarzneggar) +|---|---|---| + +|![](doc/tiktok_icon.png) [mariahcareyathome?](https://www.tiktok.com/@mariahcareyathome?)|![](doc/tiktok_icon.png) [diepnep](https://www.tiktok.com/@diepnep)|![](doc/tiktok_icon.png) [mr__heisenberg](https://www.tiktok.com/@mr__heisenberg)|![](doc/tiktok_icon.png) [deepcaprio](https://www.tiktok.com/@deepcaprio) +|---|---|---|---| + +|![](doc/youtube_icon.png) [VFXChris Ume](https://www.youtube.com/channel/UCGf4OlX_aTt8DlrgiH3jN3g/videos)|![](doc/youtube_icon.png) [Sham00k](https://www.youtube.com/channel/UCZXbWcv7fSZFTAZV4beckyw/videos)| +|---|---| + +|![](doc/youtube_icon.png) [Collider videos](https://www.youtube.com/watch?v=A91P2qtPT54&list=PLayt6616lBclvOprvrC8qKGCO-mAhPRux)|![](doc/youtube_icon.png) [iFake](https://www.youtube.com/channel/UCC0lK2Zo2BMXX-k1Ks0r7dg/videos)|![](doc/youtube_icon.png) [NextFace](https://www.youtube.com/channel/UCFh3gL0a8BS21g-DHvXZEeQ/videos)| +|---|---|---| + +|![](doc/youtube_icon.png) [Futuring Machine](https://www.youtube.com/channel/UCC5BbFxqLQgfnWPhprmQLVg)|![](doc/youtube_icon.png) [RepresentUS](https://www.youtube.com/channel/UCRzgK52MmetD9aG8pDOID3g)|![](doc/youtube_icon.png) [Corridor Crew](https://www.youtube.com/c/corridorcrew/videos)| +|---|---|---| + +|![](doc/youtube_icon.png) [DeepFaker](https://www.youtube.com/channel/UCkHecfDTcSazNZSKPEhtPVQ)|![](doc/youtube_icon.png) [DeepFakes in movie](https://www.youtube.com/c/DeepFakesinmovie/videos)| +|---|---| + +|![](doc/youtube_icon.png) [DeepFakeCreator](https://www.youtube.com/channel/UCkNFhcYNLQ5hr6A6lZ56mKA)|![](doc/youtube_icon.png) [Jarkan](https://www.youtube.com/user/Jarkancio/videos)| +|---|---| + +
+ +# What can I do using DeepFaceLab? + +
+ +## Replace the face + + + +
+ +## De-age the face + +
+ + + + + + + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=Ddx5B-84ebo + +
+ +## Replace the head + +
+ + + + + + + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=xr5FHd0AdlQ + +
+ + + + + + + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=RTjgkhMugVw + +
+ + + + + + + +
+ +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=R9f7WD0gKPo + +
+ +## Manipulate politicians lips +(voice replacement is not included!) +(also requires a skill in video editors such as *Adobe After Effects* or *Davinci Resolve*) + + + +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=IvY-Abd2FfM + + + +![](doc/youtube_icon.png) https://www.youtube.com/watch?v=ERQlaJ_czHU + +
+ +# Deepfake native resolution progress + +
+ + + +
+ + + +Unfortunately, there is no "make everything ok" button in DeepFaceLab. You should spend time studying the workflow and growing your skills. A skill in programs such as *AfterEffects* or *Davinci Resolve* is also desirable. + +
+ +## Mini tutorial + + + + + + + +
+ +## Releases + +
+Windows (magnet link) +Last release. Use torrent client to download.
+Windows (Mega.nz) +Contains new and prev releases.
+Google Colab (github) +by @chervonij . You can train fakes for free using Google Colab.
+Linux (github) +by @nagadit
+CentOS Linux (github) +May be outdated. By @elemantalcode
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +## Links + +
+ +### Guides and tutorials + +
+DeepFaceLab guide +Main guide
+Faceset creation guide +How to create the right faceset
+Google Colab guide +Guide how to train the fake on Google Colab
+Compositing +To achieve the highest quality, compose deepfake manually in video editors such as Davinci Resolve or Adobe AfterEffects
+Discussion and suggestions +
+ +### Supplementary material + +
+Ready to work facesets +Celebrity facesets made by community
+Pretrained models +Pretrained models made by community
+ +### Communication groups + +
+Discord +Official discord channel. English / Russian.
+Telegram group +Official telegram group. English / Russian. For anonymous communication. Don't forget to hide your phone number
+Русский форум +
+mrdeepfakes +the biggest NSFW English community
+reddit r/DeepFakesSFW/ +Post your deepfakes there !
+reddit r/RUdeepfakes/ +Постим русские дипфейки сюда !
+QQ群1095077489 +中文交流QQ群,商务合作找群主
+dfldata.xyz +中文交流论坛,免费软件教程、模型、人脸数据
+deepfaker.xyz +中文学习站(非官方)
+ +## Related works + +
+DeepFaceLive +Real-time face swap for PC streaming or video calls
+neuralchen/SimSwap +Swapping face using ONE single photo 一张图免训练换脸
+deepfakes/faceswap +Something that was before DeepFaceLab and still remains in the past
+ + + + + + + + + + + + + + + + + + + + + +
+ +## How I can help the project? + +
+ +### Sponsor deepfake research and DeepFaceLab development. + +
+Donate via Paypal +
+Donate via Yandex.Money +
+bitcoin:bc1qkhh7h0gwwhxgg6h6gpllfgstkd645fefrd5s6z +
+ +### Collect facesets + +
+ +You can collect faceset of any celebrity that can be used in DeepFaceLab and share it in the community +
+ +### Star this repo + +
+ +Register github account and push "Star" button. + +
+ + + + + + + + + + + + + + + + + + +
+ +## Meme zone + +
+ + + + + + + +
+ +
+ +## You don't need deepfake detector. You need to stop lying. + + + + + + + +V.I. Lenin +
+ +#deepfacelab #deepfakes #faceswap #face-swap #deep-learning #deeplearning #deep-neural-networks #deepface #deep-face-swap #fakeapp #fake-app #neural-networks #neural-nets #tensorflow #cuda #nvidia + +
diff --git a/XSegEditor/QCursorDB.py b/XSegEditor/QCursorDB.py new file mode 100644 index 0000000000000000000000000000000000000000..0909cba2022c96fbe30042cb05dda6b473595c24 --- /dev/null +++ b/XSegEditor/QCursorDB.py @@ -0,0 +1,10 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QCursorDB(): + @staticmethod + def initialize(cursor_path): + QCursorDB.cross_red = QCursor ( QPixmap ( str(cursor_path / 'cross_red.png') ) ) + QCursorDB.cross_green = QCursor ( QPixmap ( str(cursor_path / 'cross_green.png') ) ) + QCursorDB.cross_blue = QCursor ( QPixmap ( str(cursor_path / 'cross_blue.png') ) ) diff --git a/XSegEditor/QIconDB.py b/XSegEditor/QIconDB.py new file mode 100644 index 0000000000000000000000000000000000000000..1fd9e3e66e381e997b8e579da6e3029a45e635e3 --- /dev/null +++ b/XSegEditor/QIconDB.py @@ -0,0 +1,26 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + + +class QIconDB(): + @staticmethod + def initialize(icon_path): + QIconDB.app_icon = QIcon ( str(icon_path / 'app_icon.png') ) + QIconDB.delete_poly = QIcon ( str(icon_path / 'delete_poly.png') ) + QIconDB.undo_pt = QIcon ( str(icon_path / 'undo_pt.png') ) + QIconDB.redo_pt = QIcon ( str(icon_path / 'redo_pt.png') ) + QIconDB.poly_color_red = QIcon ( str(icon_path / 'poly_color_red.png') ) + QIconDB.poly_color_green = QIcon ( str(icon_path / 'poly_color_green.png') ) + QIconDB.poly_color_blue = QIcon ( str(icon_path / 'poly_color_blue.png') ) + QIconDB.poly_type_include = QIcon ( str(icon_path / 'poly_type_include.png') ) + QIconDB.poly_type_exclude = QIcon ( str(icon_path / 'poly_type_exclude.png') ) + QIconDB.left = QIcon ( str(icon_path / 'left.png') ) + QIconDB.right = QIcon ( str(icon_path / 'right.png') ) + QIconDB.trashcan = QIcon ( str(icon_path / 'trashcan.png') ) + QIconDB.pt_edit_mode = QIcon ( str(icon_path / 'pt_edit_mode.png') ) + QIconDB.view_lock_center = QIcon ( str(icon_path / 'view_lock_center.png') ) + QIconDB.view_baked = QIcon ( str(icon_path / 'view_baked.png') ) + QIconDB.view_xseg = QIcon ( str(icon_path / 'view_xseg.png') ) + QIconDB.view_xseg_overlay = QIcon ( str(icon_path / 'view_xseg_overlay.png') ) + \ No newline at end of file diff --git a/XSegEditor/QImageDB.py b/XSegEditor/QImageDB.py new file mode 100644 index 0000000000000000000000000000000000000000..45cad78d3b82a03088aa92afd9863cca6c0dd856 --- /dev/null +++ b/XSegEditor/QImageDB.py @@ -0,0 +1,8 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QImageDB(): + @staticmethod + def initialize(image_path): + QImageDB.intro = QImage ( str(image_path / 'intro.png') ) diff --git a/XSegEditor/QStringDB.py b/XSegEditor/QStringDB.py new file mode 100644 index 0000000000000000000000000000000000000000..b9100d255e0777fd46ac620b05c1f7e0f7dfed1a --- /dev/null +++ b/XSegEditor/QStringDB.py @@ -0,0 +1,102 @@ +from localization import system_language + + +class QStringDB(): + + @staticmethod + def initialize(): + lang = system_language + + if lang not in ['en','ru','zh']: + lang = 'en' + + QStringDB.btn_poly_color_red_tip = { 'en' : 'Poly color scheme red', + 'ru' : 'Красная цветовая схема полигонов', + 'zh' : '选区配色方案红色', + }[lang] + + QStringDB.btn_poly_color_green_tip = { 'en' : 'Poly color scheme green', + 'ru' : 'Зелёная цветовая схема полигонов', + 'zh' : '选区配色方案绿色', + }[lang] + + QStringDB.btn_poly_color_blue_tip = { 'en' : 'Poly color scheme blue', + 'ru' : 'Синяя цветовая схема полигонов', + 'zh' : '选区配色方案蓝色', + }[lang] + + QStringDB.btn_view_baked_mask_tip = { 'en' : 'View baked mask', + 'ru' : 'Посмотреть запечёную маску', + 'zh' : '查看遮罩通道', + }[lang] + + QStringDB.btn_view_xseg_mask_tip = { 'en' : 'View trained XSeg mask', + 'ru' : 'Посмотреть тренированную XSeg маску', + 'zh' : '查看导入后的XSeg遮罩', + }[lang] + + QStringDB.btn_view_xseg_overlay_mask_tip = { 'en' : 'View trained XSeg mask overlay face', + 'ru' : 'Посмотреть тренированную XSeg маску поверх лица', + 'zh' : '查看导入后的XSeg遮罩于脸上方', + }[lang] + + QStringDB.btn_poly_type_include_tip = { 'en' : 'Poly include mode', + 'ru' : 'Режим полигонов - включение', + 'zh' : '包含选区模式', + }[lang] + + QStringDB.btn_poly_type_exclude_tip = { 'en' : 'Poly exclude mode', + 'ru' : 'Режим полигонов - исключение', + 'zh' : '排除选区模式', + }[lang] + + QStringDB.btn_undo_pt_tip = { 'en' : 'Undo point', + 'ru' : 'Отменить точку', + 'zh' : '撤消点', + }[lang] + + QStringDB.btn_redo_pt_tip = { 'en' : 'Redo point', + 'ru' : 'Повторить точку', + 'zh' : '重做点', + }[lang] + + QStringDB.btn_delete_poly_tip = { 'en' : 'Delete poly', + 'ru' : 'Удалить полигон', + 'zh' : '删除选区', + }[lang] + + QStringDB.btn_pt_edit_mode_tip = { 'en' : 'Add/delete point mode ( HOLD CTRL )', + 'ru' : 'Режим добавления/удаления точек ( удерживайте CTRL )', + 'zh' : '点加/删除模式 ( 按住CTRL )', + }[lang] + + QStringDB.btn_view_lock_center_tip = { 'en' : 'Lock cursor at the center ( HOLD SHIFT )', + 'ru' : 'Заблокировать курсор в центре ( удерживайте SHIFT )', + 'zh' : '将光标锁定在中心 ( 按住SHIFT )', + }[lang] + + + QStringDB.btn_prev_image_tip = { 'en' : 'Save and Prev image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n', + 'ru' : 'Сохранить и предыдущее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n', + 'zh' : '保存并转到上一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', + }[lang] + QStringDB.btn_next_image_tip = { 'en' : 'Save and Next image\nHold SHIFT : accelerate\nHold CTRL : skip non masked\n', + 'ru' : 'Сохранить и следующее изображение\nУдерживать SHIFT : ускорить\nУдерживать CTRL : пропустить неразмеченные\n', + 'zh' : '保存并转到下一张图片\n按住SHIFT : 加快\n按住CTRL : 跳过未标记的\n', + }[lang] + + QStringDB.btn_delete_image_tip = { 'en' : 'Move to _trash and Next image\n', + 'ru' : 'Переместить в _trash и следующее изображение\n', + 'zh' : '移至_trash,转到下一张图片 ', + }[lang] + + QStringDB.loading_tip = {'en' : 'Loading', + 'ru' : 'Загрузка', + 'zh' : '正在载入', + }[lang] + + QStringDB.labeled_tip = {'en' : 'labeled', + 'ru' : 'размечено', + 'zh' : '标记的', + }[lang] + diff --git a/XSegEditor/XSegEditor - Copy.py b/XSegEditor/XSegEditor - Copy.py new file mode 100644 index 0000000000000000000000000000000000000000..f0e7b315c7e2fac6037945c4339b3e64bcc8af0a --- /dev/null +++ b/XSegEditor/XSegEditor - Copy.py @@ -0,0 +1,1522 @@ +import json +import multiprocessing +import os +import pickle +import sys +import tempfile +import time +import traceback +from enum import IntEnum +from types import SimpleNamespace as sn + +import cv2 +import numpy as np +import numpy.linalg as npla +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +from core import imagelib, pathex +from core.cv2ex import * +from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd +from core.qtex import * +from DFLIMG import * +from localization import StringsDB, system_language +from samplelib import PackedFaceset + +from .QCursorDB import QCursorDB +from .QIconDB import QIconDB +from .QStringDB import QStringDB +from .QImageDB import QImageDB + +class OpMode(IntEnum): + NONE = 0 + DRAW_PTS = 1 + EDIT_PTS = 2 + VIEW_BAKED = 3 + VIEW_XSEG_MASK = 4 + VIEW_XSEG_OVERLAY_MASK = 5 + +class PTEditMode(IntEnum): + MOVE = 0 + ADD_DEL = 1 + +class DragType(IntEnum): + NONE = 0 + IMAGE_LOOK = 1 + POLY_PT = 2 + +class ViewLock(IntEnum): + NONE = 0 + CENTER = 1 + +class QUIConfig(): + @staticmethod + def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64): + QUIConfig.icon_q_size = QSize(icon_size, icon_size) + QUIConfig.icon_spacer_q_size = QSize(icon_spacer_size, icon_spacer_size) + QUIConfig.preview_bar_icon_q_size = QSize(preview_bar_icon_size, preview_bar_icon_size) + +class ImagePreviewSequenceBar(QFrame): + def __init__(self, preview_images_count, icon_size): + super().__init__() + self.preview_images_count = preview_images_count = max(1, preview_images_count + (preview_images_count % 2 -1) ) + + self.icon_size = icon_size + + black_q_img = QImage(np.zeros( (icon_size,icon_size,3) ).data, icon_size, icon_size, 3*icon_size, QImage.Format_RGB888) + self.black_q_pixmap = QPixmap.fromImage(black_q_img) + + self.image_containers = [ QLabel() for i in range(preview_images_count)] + + main_frame_l_cont_hl = QGridLayout() + main_frame_l_cont_hl.setContentsMargins(0,0,0,0) + #main_frame_l_cont_hl.setSpacing(0) + + + + for i in range(len(self.image_containers)): + q_label = self.image_containers[i] + q_label.setScaledContents(True) + if i == preview_images_count//2: + q_label.setMinimumSize(icon_size+16, icon_size+16 ) + q_label.setMaximumSize(icon_size+16, icon_size+16 ) + else: + q_label.setMinimumSize(icon_size, icon_size ) + q_label.setMaximumSize(icon_size, icon_size ) + opacity_effect = QGraphicsOpacityEffect() + opacity_effect.setOpacity(0.5) + q_label.setGraphicsEffect(opacity_effect) + + q_label.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + + main_frame_l_cont_hl.addWidget (q_label, 0, i) + + self.setLayout(main_frame_l_cont_hl) + + self.prev_img_conts = self.image_containers[(preview_images_count//2) -1::-1] + self.next_img_conts = self.image_containers[preview_images_count//2:] + + self.update_images() + + def get_preview_images_count(self): + return self.preview_images_count + + def update_images(self, prev_imgs=None, next_imgs=None): + # Fix arrays + if prev_imgs is None: + prev_imgs = [] + prev_img_conts_len = len(self.prev_img_conts) + prev_q_imgs_len = len(prev_imgs) + if prev_q_imgs_len < prev_img_conts_len: + for i in range ( prev_img_conts_len - prev_q_imgs_len ): + prev_imgs.append(None) + elif prev_q_imgs_len > prev_img_conts_len: + prev_imgs = prev_imgs[:prev_img_conts_len] + + if next_imgs is None: + next_imgs = [] + next_img_conts_len = len(self.next_img_conts) + next_q_imgs_len = len(next_imgs) + if next_q_imgs_len < next_img_conts_len: + for i in range ( next_img_conts_len - next_q_imgs_len ): + next_imgs.append(None) + elif next_q_imgs_len > next_img_conts_len: + next_imgs = next_imgs[:next_img_conts_len] + + for i,img in enumerate(prev_imgs): + self.prev_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + + for i,img in enumerate(next_imgs): + self.next_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + +class ColorScheme(): + def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor): + self.poly_unselected_brush = QBrush(unselected_color) + self.poly_selected_brush = QBrush(selected_color) + + self.poly_outline_solid_pen = QPen(outline_color, outline_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) + self.poly_outline_dot_pen = QPen(outline_color, outline_width, Qt.DotLine, Qt.RoundCap, Qt.RoundJoin) + + self.pt_outline_pen = QPen(pt_outline_color) + self.cross_cursor = cross_cursor + +class CanvasConfig(): + + def __init__(self, + pt_radius=4, + pt_select_radius=8, + color_schemes=None, + **kwargs): + self.pt_radius = pt_radius + self.pt_select_radius = pt_select_radius + + if color_schemes is None: + color_schemes = [ + ColorScheme( QColor(192,0,0,alpha=0), QColor(192,0,0,alpha=72), QColor(192,0,0), 2, QColor(255,255,255), QCursorDB.cross_red ), + ColorScheme( QColor(0,192,0,alpha=0), QColor(0,192,0,alpha=72), QColor(0,192,0), 2, QColor(255,255,255), QCursorDB.cross_green ), + ColorScheme( QColor(0,0,192,alpha=0), QColor(0,0,192,alpha=72), QColor(0,0,192), 2, QColor(255,255,255), QCursorDB.cross_blue ), + ] + self.color_schemes = color_schemes + +class QCanvasControlsLeftBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_type_include = QToolButton() + self.btn_poly_type_include_act = QActionEx( QIconDB.poly_type_include, QStringDB.btn_poly_type_include_tip, shortcut='Q', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_include.setDefaultAction(self.btn_poly_type_include_act) + btn_poly_type_include.setIconSize(QUIConfig.icon_q_size) + + btn_poly_type_exclude = QToolButton() + self.btn_poly_type_exclude_act = QActionEx( QIconDB.poly_type_exclude, QStringDB.btn_poly_type_exclude_tip, shortcut='W', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_exclude.setDefaultAction(self.btn_poly_type_exclude_act) + btn_poly_type_exclude.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_type_act_grp = QActionGroup (self) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_include_act) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_exclude_act) + self.btn_poly_type_act_grp.setExclusive(True) + #============================================== + btn_undo_pt = QToolButton() + self.btn_undo_pt_act = QActionEx( QIconDB.undo_pt, QStringDB.btn_undo_pt_tip, shortcut='Ctrl+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_undo_pt.setDefaultAction(self.btn_undo_pt_act) + btn_undo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_redo_pt = QToolButton() + self.btn_redo_pt_act = QActionEx( QIconDB.redo_pt, QStringDB.btn_redo_pt_tip, shortcut='Ctrl+Shift+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_redo_pt.setDefaultAction(self.btn_redo_pt_act) + btn_redo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_delete_poly = QToolButton() + self.btn_delete_poly_act = QActionEx( QIconDB.delete_poly, QStringDB.btn_delete_poly_tip, shortcut='Delete', shortcut_in_tooltip=True) + btn_delete_poly.setDefaultAction(self.btn_delete_poly_act) + btn_delete_poly.setIconSize(QUIConfig.icon_q_size) + #============================================== + btn_pt_edit_mode = QToolButton() + self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True) + btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act) + btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size) + #============================================== + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_poly_type_include ) + controls_bar_frame2_l.addWidget ( btn_poly_type_exclude ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_frame3_l = QVBoxLayout() + controls_bar_frame3_l.addWidget ( btn_undo_pt ) + controls_bar_frame3_l.addWidget ( btn_redo_pt ) + controls_bar_frame3_l.addWidget ( btn_delete_poly ) + controls_bar_frame3 = QFrame() + controls_bar_frame3.setFrameShape(QFrame.StyledPanel) + controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame3.setLayout(controls_bar_frame3_l) + + controls_bar_frame4_l = QVBoxLayout() + controls_bar_frame4_l.addWidget ( btn_pt_edit_mode ) + controls_bar_frame4 = QFrame() + controls_bar_frame4.setFrameShape(QFrame.StyledPanel) + controls_bar_frame4.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame4.setLayout(controls_bar_frame4_l) + + btn_view_lock_center = QToolButton() + self.btn_view_lock_center_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, shortcut_in_tooltip=True, is_checkable=True) + btn_view_lock_center.setDefaultAction(self.btn_view_lock_center_act) + btn_view_lock_center.setIconSize(QUIConfig.icon_q_size) + + controls_bar_frame5_l = QVBoxLayout() + controls_bar_frame5_l.addWidget ( btn_view_lock_center ) + controls_bar_frame5 = QFrame() + controls_bar_frame5.setFrameShape(QFrame.StyledPanel) + controls_bar_frame5.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame5.setLayout(controls_bar_frame5_l) + + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame2) + controls_bar_l.addWidget(controls_bar_frame3) + controls_bar_l.addWidget(controls_bar_frame4) + controls_bar_l.addWidget(controls_bar_frame5) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasControlsRightBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_color_red = QToolButton() + self.btn_poly_color_red_act = QActionEx( QIconDB.poly_color_red, QStringDB.btn_poly_color_red_tip, shortcut='1', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_red.setDefaultAction(self.btn_poly_color_red_act) + btn_poly_color_red.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_green = QToolButton() + self.btn_poly_color_green_act = QActionEx( QIconDB.poly_color_green, QStringDB.btn_poly_color_green_tip, shortcut='2', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_green.setDefaultAction(self.btn_poly_color_green_act) + btn_poly_color_green.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_blue = QToolButton() + self.btn_poly_color_blue_act = QActionEx( QIconDB.poly_color_blue, QStringDB.btn_poly_color_blue_tip, shortcut='3', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_blue.setDefaultAction(self.btn_poly_color_blue_act) + btn_poly_color_blue.setIconSize(QUIConfig.icon_q_size) + + btn_view_baked_mask = QToolButton() + self.btn_view_baked_mask_act = QActionEx( QIconDB.view_baked, QStringDB.btn_view_baked_mask_tip, shortcut='4', shortcut_in_tooltip=True, is_checkable=True) + btn_view_baked_mask.setDefaultAction(self.btn_view_baked_mask_act) + btn_view_baked_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_mask = QToolButton() + self.btn_view_xseg_mask_act = QActionEx( QIconDB.view_xseg, QStringDB.btn_view_xseg_mask_tip, shortcut='5', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act) + btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_overlay_mask = QToolButton() + self.btn_view_xseg_overlay_mask_act = QActionEx( QIconDB.view_xseg_overlay, QStringDB.btn_view_xseg_overlay_mask_tip, shortcut='6', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_overlay_mask.setDefaultAction(self.btn_view_xseg_overlay_mask_act) + btn_view_xseg_overlay_mask.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_color_act_grp = QActionGroup (self) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_blue_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_baked_mask_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_overlay_mask_act) + self.btn_poly_color_act_grp.setExclusive(True) + #============================================== + + btn_xseg_to_poly = QToolButton() + self.btn_xseg_to_poly_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, shortcut_in_tooltip=True, is_checkable=False) + btn_xseg_to_poly.setDefaultAction(self.btn_xseg_to_poly_act) + btn_xseg_to_poly.setIconSize(QUIConfig.icon_q_size) + + controls_bar_frame1_l = QVBoxLayout() + controls_bar_frame1_l.addWidget ( btn_poly_color_red ) + controls_bar_frame1_l.addWidget ( btn_poly_color_green ) + controls_bar_frame1_l.addWidget ( btn_poly_color_blue ) + controls_bar_frame1_l.addWidget ( btn_view_baked_mask ) + controls_bar_frame1_l.addWidget ( btn_view_xseg_mask ) + controls_bar_frame1_l.addWidget ( btn_view_xseg_overlay_mask ) + controls_bar_frame1 = QFrame() + controls_bar_frame1.setFrameShape(QFrame.StyledPanel) + controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame1.setLayout(controls_bar_frame1_l) + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_xseg_to_poly ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame1) + controls_bar_l.addWidget(controls_bar_frame2) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasOperator(QWidget): + def __init__(self, cbar): + super().__init__() + self.cbar = cbar + + self.set_cbar_disabled() + + self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) ) + self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) ) + self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) ) + self.cbar.btn_view_baked_mask_act.toggled.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) ) + self.cbar.btn_view_xseg_mask_act.toggled.connect ( self.set_view_xseg_mask ) + self.cbar.btn_view_xseg_overlay_mask_act.toggled.connect ( self.set_view_xseg_overlay_mask ) + + self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) ) + self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) ) + + self.cbar.btn_undo_pt_act.triggered.connect ( lambda : self.action_undo_pt() ) + self.cbar.btn_redo_pt_act.triggered.connect ( lambda : self.action_redo_pt() ) + + self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() ) + + self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) ) + self.cbar.btn_view_lock_center_act.toggled.connect ( lambda is_checked: self.set_view_lock( ViewLock.CENTER if is_checked else ViewLock.NONE ) ) + + self.cbar.btn_xseg_to_poly_act.triggered.connect ( lambda : self.action_xseg_to_poly() ) + + + self.mouse_in_widget = False + + QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent ) + QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent ) + + self.qp = QPainter() + self.initialized = False + self.last_state = None + + def initialize(self, img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ): + self.img = img + q_img = self.q_img = QImage_from_np(img) + self.img_pixmap = QPixmap.fromImage(q_img) + + self.xseg_mask_in = imagelib.normalize_channels(xseg_mask, 1) + self.xseg_mask_pixmap = None + self.xseg_overlay_mask_pixmap = None + + if xseg_mask is not None: + h,w,c = img.shape + xseg_mask = cv2.resize(xseg_mask, (w,h), cv2.INTER_CUBIC) + xseg_mask = imagelib.normalize_channels(xseg_mask, 1) + xseg_img = img.astype(np.float32)/255.0 + xseg_overlay_mask = xseg_img*(1-xseg_mask)*0.5 + xseg_img*xseg_mask + xseg_overlay_mask = np.clip(xseg_overlay_mask*255, 0, 255).astype(np.uint8) + xseg_mask = np.clip(xseg_mask*255, 0, 255).astype(np.uint8) + self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask)) + self.xseg_overlay_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_overlay_mask)) + + self.img_size = QSize_to_np (self.img_pixmap.size()) + + self.img_look_pt = img_look_pt + self.view_scale = view_scale + + if ie_polys is None: + ie_polys = SegIEPolys() + self.ie_polys = ie_polys + + if canvas_config is None: + canvas_config = CanvasConfig() + self.canvas_config = canvas_config + + # UI init + self.set_cbar_disabled() + self.cbar.btn_poly_color_act_grp.setDisabled(False) + self.cbar.btn_poly_type_act_grp.setDisabled(False) + + # Initial vars + self.current_cursor = None + self.mouse_hull_poly = None + self.mouse_wire_poly = None + self.drag_type = DragType.NONE + self.mouse_cli_pt = np.zeros((2,), np.float32 ) + + # Initial state + self.set_op_mode(OpMode.NONE) + self.set_color_scheme_id(1) + self.set_poly_include_type(SegIEPolyType.INCLUDE) + self.set_pt_edit_mode(PTEditMode.MOVE) + self.set_view_lock(ViewLock.NONE) + + # Apply last state + if self.last_state is not None: + self.set_color_scheme_id(self.last_state.color_scheme_id) + if self.last_state.op_mode is not None: + self.set_op_mode(self.last_state.op_mode) + + self.initialized = True + + self.setMouseTracking(True) + self.update_cursor() + self.update() + + + def finalize(self): + if self.initialized: + if self.op_mode == OpMode.DRAW_PTS: + self.set_op_mode(OpMode.EDIT_PTS) + + self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK, OpMode.VIEW_XSEG_OVERLAY_MASK] else None, + color_scheme_id = self.color_scheme_id, + ) + + self.img_pixmap = None + self.update_cursor(is_finalize=True) + self.setMouseTracking(False) + self.setFocusPolicy(Qt.NoFocus) + self.set_cbar_disabled() + self.initialized = False + self.update() + + # ==================================================================================== + # ==================================================================================== + # ====================================== GETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + + def is_initialized(self): + return self.initialized + + def get_ie_polys(self): + return self.ie_polys + + def get_cli_center_pt(self): + return np.round(QSize_to_np(self.size())/2.0) + + def get_img_look_pt(self): + img_look_pt = self.img_look_pt + if img_look_pt is None: + img_look_pt = self.img_size / 2 + return img_look_pt + + def get_view_scale(self): + view_scale = self.view_scale + if view_scale is None: + # Calc as scale to fit + min_cli_size = np.min(QSize_to_np(self.size())) + max_img_size = np.max(self.img_size) + view_scale = min_cli_size / max_img_size + + return view_scale + + def get_current_color_scheme(self): + return self.canvas_config.color_schemes[self.color_scheme_id] + + def get_poly_pt_id_under_pt(self, poly, cli_pt): + w = np.argwhere ( npla.norm ( cli_pt - self.img_to_cli_pt( poly.get_pts() ), axis=1 ) <= self.canvas_config.pt_select_radius ) + return None if len(w) == 0 else w[-1][0] + + def get_poly_edge_id_pt_under_pt(self, poly, cli_pt): + cli_pts = self.img_to_cli_pt(poly.get_pts()) + if len(cli_pts) >= 3: + edge_dists, projs = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + edge_id = np.argmin(edge_dists) + dist = edge_dists[edge_id] + pt = projs[edge_id] + if dist <= self.canvas_config.pt_select_radius: + return edge_id, pt + return None, None + + def get_poly_by_pt_near_wire(self, cli_pt): + pt_select_radius = self.canvas_config.pt_select_radius + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + cli_pts = self.img_to_cli_pt(pts) + + edge_dists, _ = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + + if np.min(edge_dists) <= pt_select_radius or \ + any( npla.norm ( cli_pt - cli_pts, axis=1 ) <= pt_select_radius ): + return poly + return None + + def get_poly_by_pt_in_hull(self, cli_pos): + img_pos = self.cli_to_img_pt(cli_pos) + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + if cv2.pointPolygonTest( pts, tuple(img_pos), False) >= 0: + return poly + + return None + + def img_to_cli_pt(self, p): + return (p - self.get_img_look_pt()) * self.get_view_scale() + self.get_cli_center_pt()# QSize_to_np(self.size())/2.0 + + def cli_to_img_pt(self, p): + return (p - self.get_cli_center_pt() ) / self.get_view_scale() + self.get_img_look_pt() + + def img_to_cli_rect(self, rect): + tl = QPoint_to_np(rect.topLeft()) + xy = self.img_to_cli_pt(tl) + xy2 = self.img_to_cli_pt(tl + QSize_to_np(rect.size()) ) - xy + return QRect ( *xy.astype(np.int), *xy2.astype(np.int) ) + + # ==================================================================================== + # ==================================================================================== + # ====================================== SETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + def set_op_mode(self, op_mode, op_poly=None): + if not hasattr(self,'op_mode'): + self.op_mode = None + self.op_poly = None + + if self.op_mode != op_mode: + # Finalize prev mode + if self.op_mode == OpMode.NONE: + self.cbar.btn_poly_type_act_grp.setDisabled(True) + elif self.op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + # Reset view_lock when exit from DRAW_PTS + self.set_view_lock(ViewLock.NONE) + # Remove unfinished poly + if self.op_poly.get_pts_count() < 3: + self.ie_polys.remove_poly(self.op_poly) + + elif self.op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_delete_poly_act.setDisabled(True) + # Reset pt_edit_move when exit from EDIT_PTS + self.set_pt_edit_mode(PTEditMode.MOVE) + elif self.op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(False) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(False) + self.cbar.btn_xseg_to_poly_act.setDisabled(True) + elif self.op_mode == OpMode.VIEW_XSEG_OVERLAY_MASK: + self.cbar.btn_view_xseg_overlay_mask_act.setChecked(False) + self.cbar.btn_xseg_to_poly_act.setDisabled(True) + self.op_mode = op_mode + + # Initialize new mode + if op_mode == OpMode.NONE: + self.cbar.btn_poly_type_act_grp.setDisabled(False) + elif op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(False) + self.cbar.btn_redo_pt_act.setDisabled(False) + self.cbar.btn_view_lock_center_act.setDisabled(False) + elif op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(False) + self.cbar.btn_delete_poly_act.setDisabled(False) + elif op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(True ) + n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0 + h,w,c = n.shape + mask = np.zeros( (h,w,1), dtype=np.float32 ) + self.ie_polys.overlay_mask(mask) + n = (mask*255).astype(np.uint8) + self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n)) + elif op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(True) + if self.xseg_mask_in is not None: + self.cbar.btn_xseg_to_poly_act.setDisabled(False) + elif op_mode == OpMode.VIEW_XSEG_OVERLAY_MASK: + self.cbar.btn_view_xseg_overlay_mask_act.setChecked(True) + if self.xseg_mask_in is not None: + self.cbar.btn_xseg_to_poly_act.setDisabled(False) + + if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + self.mouse_op_poly_pt_id = None + self.mouse_op_poly_edge_id = None + self.mouse_op_poly_edge_id_pt = None + + self.op_poly = op_poly + if op_poly is not None: + self.update_mouse_info() + + self.update_cursor() + self.update() + + def set_pt_edit_mode(self, pt_edit_mode): + if not hasattr(self, 'pt_edit_mode') or self.pt_edit_mode != pt_edit_mode: + self.pt_edit_mode = pt_edit_mode + self.update_cursor() + self.update() + self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL ) + + def set_view_lock(self, view_lock): + if not hasattr(self, 'view_lock') or self.view_lock != view_lock: + if hasattr(self, 'view_lock') and self.view_lock != view_lock: + if view_lock == ViewLock.CENTER: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + + self.view_lock = view_lock + self.update() + self.cbar.btn_view_lock_center_act.setChecked( self.view_lock == ViewLock.CENTER ) + + def set_cbar_disabled(self): + self.cbar.btn_delete_poly_act.setDisabled(True) + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + self.cbar.btn_poly_color_act_grp.setDisabled(True) + self.cbar.btn_poly_type_act_grp.setDisabled(True) + self.cbar.btn_xseg_to_poly_act.setDisabled(True) + + def set_color_scheme_id(self, id): + if self.op_mode == OpMode.VIEW_BAKED: + self.set_op_mode(OpMode.NONE) + + if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id: + self.color_scheme_id = id + self.update_cursor() + self.update() + + if self.color_scheme_id == 0: + self.cbar.btn_poly_color_red_act.setChecked( True ) + elif self.color_scheme_id == 1: + self.cbar.btn_poly_color_green_act.setChecked( True ) + elif self.color_scheme_id == 2: + self.cbar.btn_poly_color_blue_act.setChecked( True ) + + def set_poly_include_type(self, poly_include_type): + if not hasattr(self, 'poly_include_type' ) or \ + ( self.poly_include_type != poly_include_type and \ + self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ): + self.poly_include_type = poly_include_type + self.update() + + self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE) + self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE) + + def set_view_xseg_mask(self, is_checked): + if is_checked: + self.set_op_mode(OpMode.VIEW_XSEG_MASK) + else: + self.set_op_mode(OpMode.NONE) + + self.cbar.btn_view_xseg_mask_act.setChecked(is_checked ) + + def set_view_xseg_overlay_mask(self, is_checked): + if is_checked: + self.set_op_mode(OpMode.VIEW_XSEG_OVERLAY_MASK) + else: + self.set_op_mode(OpMode.NONE) + + self.cbar.btn_view_xseg_overlay_mask_act.setChecked(is_checked ) + + # ==================================================================================== + # ==================================================================================== + # ====================================== METHODS ===================================== + # ==================================================================================== + # ==================================================================================== + + def update_cursor(self, is_finalize=False): + if not self.initialized: + return + + if not self.mouse_in_widget or is_finalize: + if self.current_cursor is not None: + QApplication.restoreOverrideCursor() + self.current_cursor = None + else: + color_cc = self.get_current_color_scheme().cross_cursor + nc = Qt.ArrowCursor + + if self.drag_type == DragType.IMAGE_LOOK: + nc = Qt.ClosedHandCursor + else: + + if self.op_mode == OpMode.NONE: + nc = color_cc + if self.mouse_wire_poly is not None: + nc = Qt.PointingHandCursor + + elif self.op_mode == OpMode.DRAW_PTS: + nc = color_cc + elif self.op_mode == OpMode.EDIT_PTS: + nc = Qt.ArrowCursor + + if self.mouse_op_poly_pt_id is not None: + nc = Qt.PointingHandCursor + + if self.pt_edit_mode == PTEditMode.ADD_DEL: + + if self.mouse_op_poly_edge_id is not None and \ + self.mouse_op_poly_pt_id is None: + nc = color_cc + if self.current_cursor != nc: + if self.current_cursor is None: + QApplication.setOverrideCursor(nc) + else: + QApplication.changeOverrideCursor(nc) + self.current_cursor = nc + + def update_mouse_info(self, mouse_cli_pt=None): + """ + Update selected polys/edges/points by given mouse position + """ + if mouse_cli_pt is not None: + self.mouse_cli_pt = mouse_cli_pt.astype(np.float32) + + self.mouse_img_pt = self.cli_to_img_pt(self.mouse_cli_pt) + + new_mouse_hull_poly = self.get_poly_by_pt_in_hull(self.mouse_cli_pt) + + if self.mouse_hull_poly != new_mouse_hull_poly: + self.mouse_hull_poly = new_mouse_hull_poly + self.update_cursor() + self.update() + + new_mouse_wire_poly = self.get_poly_by_pt_near_wire(self.mouse_cli_pt) + + if self.mouse_wire_poly != new_mouse_wire_poly: + self.mouse_wire_poly = new_mouse_wire_poly + self.update_cursor() + self.update() + + if self.op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + new_mouse_op_poly_pt_id = self.get_poly_pt_id_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_pt_id != new_mouse_op_poly_pt_id: + self.mouse_op_poly_pt_id = new_mouse_op_poly_pt_id + self.update_cursor() + self.update() + + new_mouse_op_poly_edge_id,\ + new_mouse_op_poly_edge_id_pt = self.get_poly_edge_id_pt_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_edge_id != new_mouse_op_poly_edge_id: + self.mouse_op_poly_edge_id = new_mouse_op_poly_edge_id + self.update_cursor() + self.update() + + if (self.mouse_op_poly_edge_id_pt.__class__ != new_mouse_op_poly_edge_id_pt.__class__) or \ + (isinstance(self.mouse_op_poly_edge_id_pt, np.ndarray) and \ + all(self.mouse_op_poly_edge_id_pt != new_mouse_op_poly_edge_id_pt)): + + self.mouse_op_poly_edge_id_pt = new_mouse_op_poly_edge_id_pt + self.update_cursor() + self.update() + + + def action_undo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + if self.op_poly.undo() == 0: + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + def action_redo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + self.op_poly.redo() + self.update() + + def action_delete_poly(self): + if self.op_mode == OpMode.EDIT_PTS and \ + self.drag_type == DragType.NONE and \ + self.pt_edit_mode == PTEditMode.MOVE: + # Delete current poly + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + + def action_xseg_to_poly(self): + + cnts = cv2.findContours( (self.xseg_mask_in*255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_KCOS) + # Sort by countour area + cnts = sorted(cnts[0], key = cv2.contourArea, reverse = True) + if len(cnts) != 0: + h,w,c = self.img.shape + mh,mw,mc = self.xseg_mask_in.shape + + dh = h / mh + dw = w / mw + + new_poly = self.ie_polys.add_poly(SegIEPolyType.INCLUDE) + for pt in cnts[0].squeeze(): + new_poly.add_pt( pt[0]*dw, pt[1]*dh ) + + self.set_op_mode(OpMode.EDIT_PTS, op_poly=new_poly) + + + # ==================================================================================== + # ==================================================================================== + # ================================== OVERRIDE QT METHODS ============================= + # ==================================================================================== + # ==================================================================================== + def on_keyPressEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def on_keyReleaseEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def enterEvent(self, ev): + super().enterEvent(ev) + self.mouse_in_widget = True + self.update_cursor() + + def leaveEvent(self, ev): + super().leaveEvent(ev) + self.mouse_in_widget = False + self.update_cursor() + + def mousePressEvent(self, ev): + super().mousePressEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.NONE: + # Clicking in NO OPERATION mode + if self.mouse_wire_poly is not None: + # Click on wire on any poly -> switch to EDIT_MODE + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.mouse_wire_poly) + else: + # Click on empty space -> create new poly with one point + new_poly = self.ie_polys.add_poly(self.poly_include_type) + self.ie_polys.sort() + new_poly.add_pt(*self.mouse_img_pt) + self.set_op_mode(OpMode.DRAW_PTS, op_poly=new_poly ) + + elif self.op_mode == OpMode.DRAW_PTS: + # Clicking in DRAW_PTS mode + if len(self.op_poly.get_pts()) >= 3 and self.mouse_op_poly_pt_id == 0: + # Click on first point -> close poly and switch to edit mode + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.op_poly) + else: + # Click on empty space -> add point to current poly + self.op_poly.add_pt(*self.mouse_img_pt) + self.update() + + elif self.op_mode == OpMode.EDIT_PTS: + # Clicking in EDIT_PTS mode + + if self.mouse_op_poly_pt_id is not None: + # Click on point of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # with mode -> delete point + self.op_poly.remove_pt(self.mouse_op_poly_pt_id) + if self.op_poly.get_pts_count() < 3: + # not enough points -> remove poly + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + elif self.drag_type == DragType.NONE: + # otherwise -> start drag + self.drag_type = DragType.POLY_PT + self.drag_cli_pt = self.mouse_cli_pt + self.drag_poly_pt_id = self.mouse_op_poly_pt_id + self.drag_poly_pt = self.op_poly.get_pts()[ self.drag_poly_pt_id ] + elif self.mouse_op_poly_edge_id is not None: + # Click on edge of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # with mode -> insert new point + edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt) + self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt) + self.update() + else: + # Otherwise do nothing + pass + else: + # other cases -> unselect poly + self.set_op_mode(OpMode.NONE) + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.NONE: + # Start image drag + self.drag_type = DragType.IMAGE_LOOK + self.drag_cli_pt = self.mouse_cli_pt + self.drag_img_look_pt = self.get_img_look_pt() + self.update_cursor() + + + def mouseReleaseEvent(self, ev): + super().mouseReleaseEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + self.drag_type = DragType.NONE + self.update() + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.IMAGE_LOOK: + self.drag_type = DragType.NONE + self.update_cursor() + self.update() + + def mouseMoveEvent(self, ev): + super().mouseMoveEvent(ev) + if not self.initialized: + return + + prev_mouse_cli_pt = self.mouse_cli_pt + self.update_mouse_info(QPoint_to_np(ev.pos())) + + if self.view_lock == ViewLock.CENTER: + if npla.norm(self.mouse_cli_pt - prev_mouse_cli_pt) >= 1: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + + self.update() + + if self.drag_type == DragType.IMAGE_LOOK: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.img_look_pt = self.drag_img_look_pt - delta_pt + self.update() + + if self.op_mode == OpMode.DRAW_PTS: + self.update() + elif self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt) + self.update() + + def wheelEvent(self, ev): + super().wheelEvent(ev) + + if not self.initialized: + return + + mods = int(ev.modifiers()) + delta = ev.angleDelta() + + cli_pt = QPoint_to_np(ev.pos()) + + if self.drag_type == DragType.NONE: + sign = np.sign( delta.y() ) + prev_img_pos = self.cli_to_img_pt (cli_pt) + delta_scale = sign*0.2 + sign * self.get_view_scale() / 10.0 + self.view_scale = np.clip(self.get_view_scale() + delta_scale, 1.0, 20.0) + new_img_pos = self.cli_to_img_pt (cli_pt) + if sign > 0: + self.img_look_pt = self.get_img_look_pt() + (prev_img_pos-new_img_pos)#*1.5 + else: + QCursor.setPos ( self.mapToGlobal(QPoint_from_np(self.img_to_cli_pt(prev_img_pos))) ) + self.update() + + def paintEvent(self, event): + super().paintEvent(event) + if not self.initialized: + return + + qp = self.qp + qp.begin(self) + qp.setRenderHint(QPainter.Antialiasing) + qp.setRenderHint(QPainter.HighQualityAntialiasing) + qp.setRenderHint(QPainter.SmoothPixmapTransform) + + src_rect = QRect(0, 0, *self.img_size) + dst_rect = self.img_to_cli_rect( src_rect ) + + if self.op_mode == OpMode.VIEW_BAKED: + qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + if self.xseg_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect) + elif self.op_mode == OpMode.VIEW_XSEG_OVERLAY_MASK: + if self.xseg_overlay_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_overlay_mask_pixmap, src_rect) + else: + if self.img_pixmap is not None: + qp.drawPixmap(dst_rect, self.img_pixmap, src_rect) + + polys = self.ie_polys.get_polys() + polys_len = len(polys) + + color_scheme = self.get_current_color_scheme() + + pt_rad = self.canvas_config.pt_radius + pt_rad_x2 = pt_rad*2 + + pt_select_radius = self.canvas_config.pt_select_radius + + op_mode = self.op_mode + op_poly = self.op_poly + + for i,poly in enumerate(polys): + + selected_pt_path = QPainterPath() + poly_line_path = QPainterPath() + pts_line_path = QPainterPath() + + pt_remove_cli_pt = None + poly_pts = poly.get_pts() + for pt_id, img_pt in enumerate(poly_pts): + cli_pt = self.img_to_cli_pt(img_pt) + q_cli_pt = QPoint_from_np(cli_pt) + + if pt_id == 0: + poly_line_path.moveTo(q_cli_pt) + else: + poly_line_path.lineTo(q_cli_pt) + + + if poly == op_poly: + if self.op_mode == OpMode.DRAW_PTS or \ + (self.op_mode == OpMode.EDIT_PTS and \ + (self.pt_edit_mode == PTEditMode.MOVE) or \ + (self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id == pt_id) \ + ): + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([pt_rad,0])) ) + + if (self.op_mode == OpMode.EDIT_PTS and \ + self.pt_edit_mode == PTEditMode.ADD_DEL and \ + self.mouse_op_poly_pt_id == pt_id): + pt_remove_cli_pt = cli_pt + + if self.op_mode == OpMode.DRAW_PTS and \ + len(op_poly.get_pts()) >= 3 and pt_id == 0 and self.mouse_op_poly_pt_id == pt_id: + # Circle around poly point + selected_pt_path.addEllipse(q_cli_pt, pt_rad_x2, pt_rad_x2) + + + if poly == op_poly: + if op_mode == OpMode.DRAW_PTS: + # Line from last point to mouse + poly_line_path.lineTo( QPoint_from_np(self.mouse_cli_pt) ) + + if self.mouse_op_poly_pt_id is not None: + pass + + if self.mouse_op_poly_edge_id_pt is not None: + if self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id is None: + # Ready to insert point on edge + m_cli_pt = self.mouse_op_poly_edge_id_pt + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([pt_rad,0])) ) + + if len(poly_pts) >= 2: + # Closing poly line + poly_line_path.lineTo( QPoint_from_np(self.img_to_cli_pt(poly_pts[0])) ) + + # Draw calls + qp.setPen(color_scheme.pt_outline_pen) + qp.setBrush(QBrush()) + qp.drawPath(selected_pt_path) + + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + qp.drawPath(pts_line_path) + + if poly.get_type() == SegIEPolyType.INCLUDE: + qp.setPen(color_scheme.poly_outline_solid_pen) + else: + qp.setPen(color_scheme.poly_outline_dot_pen) + + qp.setBrush(color_scheme.poly_unselected_brush) + if op_mode == OpMode.NONE: + if poly == self.mouse_wire_poly: + qp.setBrush(color_scheme.poly_selected_brush) + #else: + # if poly == op_poly: + # qp.setBrush(color_scheme.poly_selected_brush) + + qp.drawPath(poly_line_path) + + if pt_remove_cli_pt is not None: + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,-pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,pt_rad_x2])) ) + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,-pt_rad_x2])) ) + + qp.end() + +class QCanvas(QFrame): + def __init__(self): + super().__init__() + + self.canvas_control_left_bar = QCanvasControlsLeftBar() + self.canvas_control_right_bar = QCanvasControlsRightBar() + + cbar = sn( btn_poly_color_red_act = self.canvas_control_right_bar.btn_poly_color_red_act, + btn_poly_color_green_act = self.canvas_control_right_bar.btn_poly_color_green_act, + btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act, + btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act, + btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act, + btn_view_xseg_overlay_mask_act = self.canvas_control_right_bar.btn_view_xseg_overlay_mask_act, + btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp, + btn_xseg_to_poly_act = self.canvas_control_right_bar.btn_xseg_to_poly_act, + + btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act, + btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act, + btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp, + btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act, + btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act, + btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act, + btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act, + btn_view_lock_center_act = self.canvas_control_left_bar.btn_view_lock_center_act, ) + + self.op = QCanvasOperator(cbar) + self.l = QHBoxLayout() + self.l.setContentsMargins(0,0,0,0) + self.l.addWidget(self.canvas_control_left_bar) + self.l.addWidget(self.op) + self.l.addWidget(self.canvas_control_right_bar) + self.setLayout(self.l) + +class LoaderQSubprocessor(QSubprocessor): + def __init__(self, image_paths, q_label, q_progressbar, on_finish_func ): + + self.image_paths = image_paths + self.image_paths_len = len(image_paths) + self.idxs = [*range(self.image_paths_len)] + + self.filtered_image_paths = self.image_paths.copy() + + self.image_paths_has_ie_polys = { image_path : False for image_path in self.image_paths } + + self.q_label = q_label + self.q_progressbar = q_progressbar + self.q_progressbar.setRange(0, self.image_paths_len) + self.q_progressbar.setValue(0) + self.q_progressbar.update() + self.on_finish_func = on_finish_func + self.done_count = 0 + super().__init__('LoaderQSubprocessor', LoaderQSubprocessor.Cli, 60) + + def get_data(self, host_dict): + if len (self.idxs) > 0: + idx = self.idxs.pop(0) + image_path = self.image_paths[idx] + self.q_label.setText(f'{QStringDB.loading_tip}... {image_path.name}') + + return idx, image_path + + return None + + def on_clients_finalized(self): + self.on_finish_func([x for x in self.filtered_image_paths if x is not None], self.image_paths_has_ie_polys) + + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data[0]) + + def on_result (self, host_dict, data, result): + idx, has_dflimg, has_ie_polys = result + + if not has_dflimg: + self.filtered_image_paths[idx] = None + self.image_paths_has_ie_polys[self.image_paths[idx]] = has_ie_polys + + self.done_count += 1 + if self.q_progressbar is not None: + self.q_progressbar.setValue(self.done_count) + + class Cli(QSubprocessor.Cli): + def process_data(self, data): + idx, filename = data + dflimg = DFLIMG.load(filename) + if dflimg is not None and dflimg.has_data(): + ie_polys = dflimg.get_seg_ie_polys() + + return idx, True, ie_polys.has_polys() + return idx, False, False + +class MainWindow(QXMainWindow): + + def __init__(self, input_dirpath, cfg_root_path): + self.loading_frame = None + self.help_frame = None + + super().__init__() + + self.input_dirpath = input_dirpath + self.cfg_root_path = cfg_root_path + + self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat' + self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {} + + self.cached_images = {} + self.cached_has_ie_polys = {} + + self.initialize_ui() + + # Loader + self.loading_frame = QFrame(self.main_canvas_frame) + self.loading_frame.setAutoFillBackground(True) + self.loading_frame.setFrameShape(QFrame.StyledPanel) + self.loader_label = QLabel() + self.loader_progress_bar = QProgressBar() + + intro_image = QLabel() + intro_image.setPixmap( QPixmap.fromImage(QImageDB.intro) ) + + intro_image_frame_l = QVBoxLayout() + intro_image_frame_l.addWidget(intro_image, alignment=Qt.AlignCenter) + intro_image_frame = QFrame() + intro_image_frame.setSizePolicy (QSizePolicy.Expanding, QSizePolicy.Expanding) + intro_image_frame.setLayout(intro_image_frame_l) + + loading_frame_l = QVBoxLayout() + loading_frame_l.addWidget (intro_image_frame) + loading_frame_l.addWidget (self.loader_label) + loading_frame_l.addWidget (self.loader_progress_bar) + self.loading_frame.setLayout(loading_frame_l) + + self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True), + q_label=self.loader_label, + q_progressbar=self.loader_progress_bar, + on_finish_func=self.on_loader_finish ) + + + def on_loader_finish(self, image_paths, image_paths_has_ie_polys): + self.image_paths_done = [] + self.image_paths = image_paths + self.image_paths_has_ie_polys = image_paths_has_ie_polys + self.set_has_ie_polys_count ( len([ 1 for x in self.image_paths_has_ie_polys if self.image_paths_has_ie_polys[x] == True]) ) + self.loading_frame.hide() + self.loading_frame = None + + self.process_next_image(first_initialization=True) + + def closeEvent(self, ev): + self.cfg_dict['geometry'] = self.saveGeometry().data() + self.cfg_path.write_bytes( pickle.dumps(self.cfg_dict) ) + + + def update_cached_images (self, count=5): + d = self.cached_images + + for image_path in self.image_paths_done[:-count]+self.image_paths[count:]: + if image_path in d: + del d[image_path] + + for image_path in self.image_paths[:count]+self.image_paths_done[-count:]: + if image_path not in d: + img = cv2_imread(image_path) + if img is not None: + d[image_path] = img + + def load_image(self, image_path): + try: + img = self.cached_images.get(image_path, None) + if img is None: + img = cv2_imread(image_path) + self.cached_images[image_path] = img + if img is None: + io.log_err(f'Unable to load {image_path}') + except: + img = None + + return img + + def update_preview_bar(self): + count = self.image_bar.get_preview_images_count() + d = self.cached_images + prev_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ] + next_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ] + self.image_bar.update_images(prev_imgs, next_imgs) + + + def canvas_initialize(self, image_path, only_has_polys=False): + if only_has_polys and not self.image_paths_has_ie_polys[image_path]: + return False + + dflimg = DFLIMG.load(image_path) + if not dflimg or not dflimg.has_data(): + return False + + ie_polys = dflimg.get_seg_ie_polys() + xseg_mask = dflimg.get_xseg_mask() + img = self.load_image(image_path) + if img is None: + return False + + self.canvas.op.initialize ( img, ie_polys=ie_polys, xseg_mask=xseg_mask ) + + self.filename_label.setText(f"{image_path.name}") + + return True + + def canvas_finalize(self, image_path): + self.canvas.op.finalize() + + if image_path.exists(): + dflimg = DFLIMG.load(image_path) + ie_polys = dflimg.get_seg_ie_polys() + new_ie_polys = self.canvas.op.get_ie_polys() + + if not new_ie_polys.identical(ie_polys): + new_has_ie_polys = new_ie_polys.has_polys() + self.set_has_ie_polys_count ( self.get_has_ie_polys_count() + (1 if new_has_ie_polys else -1) ) + self.image_paths_has_ie_polys[image_path] = new_has_ie_polys + dflimg.set_seg_ie_polys( new_ie_polys ) + dflimg.save() + + self.filename_label.setText(f"") + + def process_prev_image(self): + key_mods = QApplication.keyboardModifiers() + step = 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths_done) != 0: + self.image_paths.insert (0, self.image_paths_done.pop(-1)) + else: + break + if len(self.image_paths) == 0: + break + + ret = self.canvas_initialize(self.image_paths[0], len(self.image_paths_done) != 0 and only_has_polys) + + if ret or len(self.image_paths_done) == 0: + break + + self.update_cached_images() + self.update_preview_bar() + + def process_next_image(self, first_initialization=False): + key_mods = QApplication.keyboardModifiers() + + step = 0 if first_initialization else 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = False if first_initialization else key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths) != 0: + self.image_paths_done.append(self.image_paths.pop(0)) + else: + break + if len(self.image_paths) == 0: + break + if self.canvas_initialize(self.image_paths[0], only_has_polys): + break + + self.update_cached_images() + self.update_preview_bar() + + def initialize_ui(self): + + self.canvas = QCanvas() + + image_bar = self.image_bar = ImagePreviewSequenceBar(preview_images_count=9, icon_size=QUIConfig.preview_bar_icon_q_size.width()) + image_bar.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + + + btn_prev_image = QXIconButton(QIconDB.left, QStringDB.btn_prev_image_tip, shortcut='A', click_func=self.process_prev_image) + btn_prev_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image) + btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + + preview_image_bar_frame_l = QHBoxLayout() + preview_image_bar_frame_l.setContentsMargins(0,0,0,0) + preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter) + preview_image_bar_frame_l.addWidget ( image_bar) + preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter) + + preview_image_bar_frame = QFrame() + preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + preview_image_bar_frame.setLayout(preview_image_bar_frame_l) + + preview_image_bar_l = QHBoxLayout() + preview_image_bar_l.addWidget (preview_image_bar_frame) + + preview_image_bar = QFrame() + preview_image_bar.setFrameShape(QFrame.StyledPanel) + preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed ) + preview_image_bar.setLayout(preview_image_bar_l) + + label_font = QFont('Courier New') + self.filename_label = QLabel() + self.filename_label.setFont(label_font) + + self.has_ie_polys_count_label = QLabel() + + status_frame_l = QHBoxLayout() + status_frame_l.setContentsMargins(0,0,0,0) + status_frame_l.addWidget ( QLabel(), alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.filename_label, alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.has_ie_polys_count_label, alignment=Qt.AlignCenter) + status_frame = QFrame() + status_frame.setLayout(status_frame_l) + + main_canvas_l = QVBoxLayout() + main_canvas_l.setContentsMargins(0,0,0,0) + main_canvas_l.addWidget (self.canvas) + main_canvas_l.addWidget (status_frame) + main_canvas_l.addWidget (preview_image_bar) + + self.main_canvas_frame = QFrame() + self.main_canvas_frame.setLayout(main_canvas_l) + + self.main_l = QHBoxLayout() + self.main_l.setContentsMargins(0,0,0,0) + self.main_l.addWidget (self.main_canvas_frame) + + self.setLayout(self.main_l) + + geometry = self.cfg_dict.get('geometry', None) + if geometry is not None: + self.restoreGeometry(geometry) + else: + self.move( QPoint(0,0)) + + def get_has_ie_polys_count(self): + return self.has_ie_polys_count + + def set_has_ie_polys_count(self, c): + self.has_ie_polys_count = c + self.has_ie_polys_count_label.setText(f"{c} {QStringDB.labeled_tip}") + + def resizeEvent(self, ev): + if self.loading_frame is not None: + self.loading_frame.resize( ev.size() ) + if self.help_frame is not None: + self.help_frame.resize( ev.size() ) + +def start(input_dirpath): + """ + returns exit_code + """ + io.log_info("Running XSeg editor.") + + if PackedFaceset.path_contains(input_dirpath): + io.log_info (f'\n{input_dirpath} contains packed faceset! Unpack it first.\n') + return 1 + + root_path = Path(__file__).parent + cfg_root_path = Path(tempfile.gettempdir()) + + QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True) + QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True) + + app = QApplication([]) + app.setApplicationName("XSegEditor") + app.setStyle('Fusion') + + QFontDatabase.addApplicationFont( str(root_path / 'gfx' / 'fonts' / 'NotoSans-Medium.ttf') ) + + app.setFont( QFont('NotoSans')) + + QUIConfig.initialize() + QStringDB.initialize() + + QIconDB.initialize( root_path / 'gfx' / 'icons' ) + QCursorDB.initialize( root_path / 'gfx' / 'cursors' ) + QImageDB.initialize( root_path / 'gfx' / 'images' ) + + app.setWindowIcon(QIconDB.app_icon) + app.setPalette( QDarkPalette() ) + + win = MainWindow( input_dirpath=input_dirpath, cfg_root_path=cfg_root_path) + + win.show() + win.raise_() + + app.exec_() + return 0 diff --git a/XSegEditor/XSegEditor.py b/XSegEditor/XSegEditor.py new file mode 100644 index 0000000000000000000000000000000000000000..affc9f6a57cf80d25445f2fb33920ee635f81673 --- /dev/null +++ b/XSegEditor/XSegEditor.py @@ -0,0 +1,1494 @@ +import json +import multiprocessing +import os +import pickle +import sys +import tempfile +import time +import traceback +from enum import IntEnum +from types import SimpleNamespace as sn + +import cv2 +import numpy as np +import numpy.linalg as npla +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +from core import imagelib, pathex +from core.cv2ex import * +from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd +from core.qtex import * +from DFLIMG import * +from localization import StringsDB, system_language +from samplelib import PackedFaceset + +from .QCursorDB import QCursorDB +from .QIconDB import QIconDB +from .QStringDB import QStringDB +from .QImageDB import QImageDB + +class OpMode(IntEnum): + NONE = 0 + DRAW_PTS = 1 + EDIT_PTS = 2 + VIEW_BAKED = 3 + VIEW_XSEG_MASK = 4 + +class PTEditMode(IntEnum): + MOVE = 0 + ADD_DEL = 1 + +class DragType(IntEnum): + NONE = 0 + IMAGE_LOOK = 1 + POLY_PT = 2 + +class ViewLock(IntEnum): + NONE = 0 + CENTER = 1 + +class QUIConfig(): + @staticmethod + def initialize(icon_size = 48, icon_spacer_size=16, preview_bar_icon_size=64): + QUIConfig.icon_q_size = QSize(icon_size, icon_size) + QUIConfig.icon_spacer_q_size = QSize(icon_spacer_size, icon_spacer_size) + QUIConfig.preview_bar_icon_q_size = QSize(preview_bar_icon_size, preview_bar_icon_size) + +class ImagePreviewSequenceBar(QFrame): + def __init__(self, preview_images_count, icon_size): + super().__init__() + self.preview_images_count = preview_images_count = max(1, preview_images_count + (preview_images_count % 2 -1) ) + + self.icon_size = icon_size + + black_q_img = QImage(np.zeros( (icon_size,icon_size,3) ).data, icon_size, icon_size, 3*icon_size, QImage.Format_RGB888) + self.black_q_pixmap = QPixmap.fromImage(black_q_img) + + self.image_containers = [ QLabel() for i in range(preview_images_count)] + + main_frame_l_cont_hl = QGridLayout() + main_frame_l_cont_hl.setContentsMargins(0,0,0,0) + #main_frame_l_cont_hl.setSpacing(0) + + + + for i in range(len(self.image_containers)): + q_label = self.image_containers[i] + q_label.setScaledContents(True) + if i == preview_images_count//2: + q_label.setMinimumSize(icon_size+16, icon_size+16 ) + q_label.setMaximumSize(icon_size+16, icon_size+16 ) + else: + q_label.setMinimumSize(icon_size, icon_size ) + q_label.setMaximumSize(icon_size, icon_size ) + opacity_effect = QGraphicsOpacityEffect() + opacity_effect.setOpacity(0.5) + q_label.setGraphicsEffect(opacity_effect) + + q_label.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + + main_frame_l_cont_hl.addWidget (q_label, 0, i) + + self.setLayout(main_frame_l_cont_hl) + + self.prev_img_conts = self.image_containers[(preview_images_count//2) -1::-1] + self.next_img_conts = self.image_containers[preview_images_count//2:] + + self.update_images() + + def get_preview_images_count(self): + return self.preview_images_count + + def update_images(self, prev_imgs=None, next_imgs=None): + # Fix arrays + if prev_imgs is None: + prev_imgs = [] + prev_img_conts_len = len(self.prev_img_conts) + prev_q_imgs_len = len(prev_imgs) + if prev_q_imgs_len < prev_img_conts_len: + for i in range ( prev_img_conts_len - prev_q_imgs_len ): + prev_imgs.append(None) + elif prev_q_imgs_len > prev_img_conts_len: + prev_imgs = prev_imgs[:prev_img_conts_len] + + if next_imgs is None: + next_imgs = [] + next_img_conts_len = len(self.next_img_conts) + next_q_imgs_len = len(next_imgs) + if next_q_imgs_len < next_img_conts_len: + for i in range ( next_img_conts_len - next_q_imgs_len ): + next_imgs.append(None) + elif next_q_imgs_len > next_img_conts_len: + next_imgs = next_imgs[:next_img_conts_len] + + for i,img in enumerate(prev_imgs): + self.prev_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + + for i,img in enumerate(next_imgs): + self.next_img_conts[i].setPixmap( QPixmap.fromImage( QImage_from_np(img) ) if img is not None else self.black_q_pixmap ) + +class ColorScheme(): + def __init__(self, unselected_color, selected_color, outline_color, outline_width, pt_outline_color, cross_cursor): + self.poly_unselected_brush = QBrush(unselected_color) + self.poly_selected_brush = QBrush(selected_color) + + self.poly_outline_solid_pen = QPen(outline_color, outline_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) + self.poly_outline_dot_pen = QPen(outline_color, outline_width, Qt.DotLine, Qt.RoundCap, Qt.RoundJoin) + + self.pt_outline_pen = QPen(pt_outline_color) + self.cross_cursor = cross_cursor + +class CanvasConfig(): + + def __init__(self, + pt_radius=4, + pt_select_radius=8, + color_schemes=None, + **kwargs): + self.pt_radius = pt_radius + self.pt_select_radius = pt_select_radius + + if color_schemes is None: + color_schemes = [ + ColorScheme( QColor(192,0,0,alpha=0), QColor(192,0,0,alpha=72), QColor(192,0,0), 2, QColor(255,255,255), QCursorDB.cross_red ), + ColorScheme( QColor(0,192,0,alpha=0), QColor(0,192,0,alpha=72), QColor(0,192,0), 2, QColor(255,255,255), QCursorDB.cross_green ), + ColorScheme( QColor(0,0,192,alpha=0), QColor(0,0,192,alpha=72), QColor(0,0,192), 2, QColor(255,255,255), QCursorDB.cross_blue ), + ] + self.color_schemes = color_schemes + +class QCanvasControlsLeftBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_type_include = QToolButton() + self.btn_poly_type_include_act = QActionEx( QIconDB.poly_type_include, QStringDB.btn_poly_type_include_tip, shortcut='Q', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_include.setDefaultAction(self.btn_poly_type_include_act) + btn_poly_type_include.setIconSize(QUIConfig.icon_q_size) + + btn_poly_type_exclude = QToolButton() + self.btn_poly_type_exclude_act = QActionEx( QIconDB.poly_type_exclude, QStringDB.btn_poly_type_exclude_tip, shortcut='W', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_type_exclude.setDefaultAction(self.btn_poly_type_exclude_act) + btn_poly_type_exclude.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_type_act_grp = QActionGroup (self) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_include_act) + self.btn_poly_type_act_grp.addAction(self.btn_poly_type_exclude_act) + self.btn_poly_type_act_grp.setExclusive(True) + #============================================== + btn_undo_pt = QToolButton() + self.btn_undo_pt_act = QActionEx( QIconDB.undo_pt, QStringDB.btn_undo_pt_tip, shortcut='Ctrl+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_undo_pt.setDefaultAction(self.btn_undo_pt_act) + btn_undo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_redo_pt = QToolButton() + self.btn_redo_pt_act = QActionEx( QIconDB.redo_pt, QStringDB.btn_redo_pt_tip, shortcut='Ctrl+Shift+Z', shortcut_in_tooltip=True, is_auto_repeat=True) + btn_redo_pt.setDefaultAction(self.btn_redo_pt_act) + btn_redo_pt.setIconSize(QUIConfig.icon_q_size) + + btn_delete_poly = QToolButton() + self.btn_delete_poly_act = QActionEx( QIconDB.delete_poly, QStringDB.btn_delete_poly_tip, shortcut='Delete', shortcut_in_tooltip=True) + btn_delete_poly.setDefaultAction(self.btn_delete_poly_act) + btn_delete_poly.setIconSize(QUIConfig.icon_q_size) + #============================================== + btn_pt_edit_mode = QToolButton() + self.btn_pt_edit_mode_act = QActionEx( QIconDB.pt_edit_mode, QStringDB.btn_pt_edit_mode_tip, shortcut_in_tooltip=True, is_checkable=True) + btn_pt_edit_mode.setDefaultAction(self.btn_pt_edit_mode_act) + btn_pt_edit_mode.setIconSize(QUIConfig.icon_q_size) + #============================================== + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_poly_type_include ) + controls_bar_frame2_l.addWidget ( btn_poly_type_exclude ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_frame3_l = QVBoxLayout() + controls_bar_frame3_l.addWidget ( btn_undo_pt ) + controls_bar_frame3_l.addWidget ( btn_redo_pt ) + controls_bar_frame3_l.addWidget ( btn_delete_poly ) + controls_bar_frame3 = QFrame() + controls_bar_frame3.setFrameShape(QFrame.StyledPanel) + controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame3.setLayout(controls_bar_frame3_l) + + controls_bar_frame4_l = QVBoxLayout() + controls_bar_frame4_l.addWidget ( btn_pt_edit_mode ) + controls_bar_frame4 = QFrame() + controls_bar_frame4.setFrameShape(QFrame.StyledPanel) + controls_bar_frame4.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame4.setLayout(controls_bar_frame4_l) + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame2) + controls_bar_l.addWidget(controls_bar_frame3) + controls_bar_l.addWidget(controls_bar_frame4) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasControlsRightBar(QFrame): + + def __init__(self): + super().__init__() + #============================================== + btn_poly_color_red = QToolButton() + self.btn_poly_color_red_act = QActionEx( QIconDB.poly_color_red, QStringDB.btn_poly_color_red_tip, shortcut='1', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_red.setDefaultAction(self.btn_poly_color_red_act) + btn_poly_color_red.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_green = QToolButton() + self.btn_poly_color_green_act = QActionEx( QIconDB.poly_color_green, QStringDB.btn_poly_color_green_tip, shortcut='2', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_green.setDefaultAction(self.btn_poly_color_green_act) + btn_poly_color_green.setIconSize(QUIConfig.icon_q_size) + + btn_poly_color_blue = QToolButton() + self.btn_poly_color_blue_act = QActionEx( QIconDB.poly_color_blue, QStringDB.btn_poly_color_blue_tip, shortcut='3', shortcut_in_tooltip=True, is_checkable=True) + btn_poly_color_blue.setDefaultAction(self.btn_poly_color_blue_act) + btn_poly_color_blue.setIconSize(QUIConfig.icon_q_size) + + btn_view_baked_mask = QToolButton() + self.btn_view_baked_mask_act = QActionEx( QIconDB.view_baked, QStringDB.btn_view_baked_mask_tip, shortcut='4', shortcut_in_tooltip=True, is_checkable=True) + btn_view_baked_mask.setDefaultAction(self.btn_view_baked_mask_act) + btn_view_baked_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_mask = QToolButton() + self.btn_view_xseg_mask_act = QActionEx( QIconDB.view_xseg, QStringDB.btn_view_xseg_mask_tip, shortcut='5', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act) + btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size) + + btn_view_xseg_overlay_mask = QToolButton() + self.btn_view_xseg_overlay_mask_act = QActionEx( QIconDB.view_xseg_overlay, QStringDB.btn_view_xseg_overlay_mask_tip, shortcut='`', shortcut_in_tooltip=True, is_checkable=True) + btn_view_xseg_overlay_mask.setDefaultAction(self.btn_view_xseg_overlay_mask_act) + btn_view_xseg_overlay_mask.setIconSize(QUIConfig.icon_q_size) + + self.btn_poly_color_act_grp = QActionGroup (self) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act) + self.btn_poly_color_act_grp.addAction(self.btn_poly_color_blue_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_baked_mask_act) + self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act) + self.btn_poly_color_act_grp.setExclusive(True) + #============================================== + btn_view_lock_center = QToolButton() + self.btn_view_lock_center_act = QActionEx( QIconDB.view_lock_center, QStringDB.btn_view_lock_center_tip, shortcut_in_tooltip=True, is_checkable=True) + btn_view_lock_center.setDefaultAction(self.btn_view_lock_center_act) + btn_view_lock_center.setIconSize(QUIConfig.icon_q_size) + + controls_bar_frame2_l = QVBoxLayout() + controls_bar_frame2_l.addWidget ( btn_view_xseg_overlay_mask ) + controls_bar_frame2 = QFrame() + controls_bar_frame2.setFrameShape(QFrame.StyledPanel) + controls_bar_frame2.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame2.setLayout(controls_bar_frame2_l) + + controls_bar_frame1_l = QVBoxLayout() + controls_bar_frame1_l.addWidget ( btn_poly_color_red ) + controls_bar_frame1_l.addWidget ( btn_poly_color_green ) + controls_bar_frame1_l.addWidget ( btn_poly_color_blue ) + controls_bar_frame1_l.addWidget ( btn_view_baked_mask ) + controls_bar_frame1_l.addWidget ( btn_view_xseg_mask ) + controls_bar_frame1 = QFrame() + controls_bar_frame1.setFrameShape(QFrame.StyledPanel) + controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame1.setLayout(controls_bar_frame1_l) + + controls_bar_frame3_l = QVBoxLayout() + controls_bar_frame3_l.addWidget ( btn_view_lock_center ) + controls_bar_frame3 = QFrame() + controls_bar_frame3.setFrameShape(QFrame.StyledPanel) + controls_bar_frame3.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed) + controls_bar_frame3.setLayout(controls_bar_frame3_l) + + controls_bar_l = QVBoxLayout() + controls_bar_l.setContentsMargins(0,0,0,0) + controls_bar_l.addWidget(controls_bar_frame2) + controls_bar_l.addWidget(controls_bar_frame1) + controls_bar_l.addWidget(controls_bar_frame3) + + self.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Expanding ) + self.setLayout(controls_bar_l) + +class QCanvasOperator(QWidget): + def __init__(self, cbar): + super().__init__() + self.cbar = cbar + + self.set_cbar_disabled() + + self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) ) + self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) ) + self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) ) + self.cbar.btn_view_baked_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) ) + self.cbar.btn_view_xseg_mask_act.triggered.connect ( lambda : self.set_op_mode(OpMode.VIEW_XSEG_MASK) ) + + self.cbar.btn_view_xseg_overlay_mask_act.toggled.connect ( lambda is_checked: self.update() ) + + self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) ) + self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) ) + + self.cbar.btn_undo_pt_act.triggered.connect ( lambda : self.action_undo_pt() ) + self.cbar.btn_redo_pt_act.triggered.connect ( lambda : self.action_redo_pt() ) + + self.cbar.btn_delete_poly_act.triggered.connect ( lambda : self.action_delete_poly() ) + + self.cbar.btn_pt_edit_mode_act.toggled.connect ( lambda is_checked: self.set_pt_edit_mode( PTEditMode.ADD_DEL if is_checked else PTEditMode.MOVE ) ) + self.cbar.btn_view_lock_center_act.toggled.connect ( lambda is_checked: self.set_view_lock( ViewLock.CENTER if is_checked else ViewLock.NONE ) ) + + self.mouse_in_widget = False + + QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent ) + QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent ) + + self.qp = QPainter() + self.initialized = False + self.last_state = None + + def initialize(self, img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ): + q_img = self.q_img = QImage_from_np(img) + self.img_pixmap = QPixmap.fromImage(q_img) + + self.xseg_mask_pixmap = None + self.xseg_overlay_mask_pixmap = None + if xseg_mask is not None: + h,w,c = img.shape + xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) + xseg_mask = imagelib.normalize_channels(xseg_mask, 1) + xseg_img = img.astype(np.float32)/255.0 + xseg_overlay_mask = xseg_img*(1-xseg_mask)*0.5 + xseg_img*xseg_mask + xseg_overlay_mask = np.clip(xseg_overlay_mask*255, 0, 255).astype(np.uint8) + xseg_mask = np.clip(xseg_mask*255, 0, 255).astype(np.uint8) + self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask)) + self.xseg_overlay_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_overlay_mask)) + + self.img_size = QSize_to_np (self.img_pixmap.size()) + + self.img_look_pt = img_look_pt + self.view_scale = view_scale + + if ie_polys is None: + ie_polys = SegIEPolys() + self.ie_polys = ie_polys + + if canvas_config is None: + canvas_config = CanvasConfig() + self.canvas_config = canvas_config + + # UI init + self.set_cbar_disabled() + self.cbar.btn_poly_color_act_grp.setDisabled(False) + self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(False) + self.cbar.btn_poly_type_act_grp.setDisabled(False) + + # Initial vars + self.current_cursor = None + self.mouse_hull_poly = None + self.mouse_wire_poly = None + self.drag_type = DragType.NONE + self.mouse_cli_pt = np.zeros((2,), np.float32 ) + + # Initial state + self.set_op_mode(OpMode.NONE) + self.set_color_scheme_id(1) + self.set_poly_include_type(SegIEPolyType.INCLUDE) + self.set_pt_edit_mode(PTEditMode.MOVE) + self.set_view_lock(ViewLock.NONE) + + # Apply last state + if self.last_state is not None: + self.set_color_scheme_id(self.last_state.color_scheme_id) + if self.last_state.op_mode is not None: + self.set_op_mode(self.last_state.op_mode) + + self.initialized = True + + self.setMouseTracking(True) + self.update_cursor() + self.update() + + + def finalize(self): + if self.initialized: + if self.op_mode == OpMode.DRAW_PTS: + self.set_op_mode(OpMode.EDIT_PTS) + + self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None, + color_scheme_id = self.color_scheme_id) + + self.img_pixmap = None + self.update_cursor(is_finalize=True) + self.setMouseTracking(False) + self.setFocusPolicy(Qt.NoFocus) + self.set_cbar_disabled() + self.initialized = False + self.update() + + # ==================================================================================== + # ==================================================================================== + # ====================================== GETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + def is_initialized(self): + return self.initialized + + def get_ie_polys(self): + return self.ie_polys + + def get_cli_center_pt(self): + return np.round(QSize_to_np(self.size())/2.0) + + def get_img_look_pt(self): + img_look_pt = self.img_look_pt + if img_look_pt is None: + img_look_pt = self.img_size / 2 + return img_look_pt + + def get_view_scale(self): + view_scale = self.view_scale + if view_scale is None: + # Calc as scale to fit + min_cli_size = np.min(QSize_to_np(self.size())) + max_img_size = np.max(self.img_size) + view_scale = min_cli_size / max_img_size + + return view_scale + + def get_current_color_scheme(self): + return self.canvas_config.color_schemes[self.color_scheme_id] + + def get_poly_pt_id_under_pt(self, poly, cli_pt): + w = np.argwhere ( npla.norm ( cli_pt - self.img_to_cli_pt( poly.get_pts() ), axis=1 ) <= self.canvas_config.pt_select_radius ) + return None if len(w) == 0 else w[-1][0] + + def get_poly_edge_id_pt_under_pt(self, poly, cli_pt): + cli_pts = self.img_to_cli_pt(poly.get_pts()) + if len(cli_pts) >= 3: + edge_dists, projs = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + edge_id = np.argmin(edge_dists) + dist = edge_dists[edge_id] + pt = projs[edge_id] + if dist <= self.canvas_config.pt_select_radius: + return edge_id, pt + return None, None + + def get_poly_by_pt_near_wire(self, cli_pt): + pt_select_radius = self.canvas_config.pt_select_radius + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + cli_pts = self.img_to_cli_pt(pts) + + edge_dists, _ = sd.dist_to_edges(cli_pts, cli_pt, is_closed=True) + + if np.min(edge_dists) <= pt_select_radius or \ + any( npla.norm ( cli_pt - cli_pts, axis=1 ) <= pt_select_radius ): + return poly + return None + + def get_poly_by_pt_in_hull(self, cli_pos): + img_pos = self.cli_to_img_pt(cli_pos) + + for poly in reversed(self.ie_polys.get_polys()): + pts = poly.get_pts() + if len(pts) >= 3: + if cv2.pointPolygonTest( pts, tuple(img_pos), False) >= 0: + return poly + + return None + + def img_to_cli_pt(self, p): + return (p - self.get_img_look_pt()) * self.get_view_scale() + self.get_cli_center_pt()# QSize_to_np(self.size())/2.0 + + def cli_to_img_pt(self, p): + return (p - self.get_cli_center_pt() ) / self.get_view_scale() + self.get_img_look_pt() + + def img_to_cli_rect(self, rect): + tl = QPoint_to_np(rect.topLeft()) + xy = self.img_to_cli_pt(tl) + xy2 = self.img_to_cli_pt(tl + QSize_to_np(rect.size()) ) - xy + return QRect ( *xy.astype(np.int), *xy2.astype(np.int) ) + + # ==================================================================================== + # ==================================================================================== + # ====================================== SETTERS ===================================== + # ==================================================================================== + # ==================================================================================== + def set_op_mode(self, op_mode, op_poly=None): + if not hasattr(self,'op_mode'): + self.op_mode = None + self.op_poly = None + + if self.op_mode != op_mode: + # Finalize prev mode + if self.op_mode == OpMode.NONE: + self.cbar.btn_poly_type_act_grp.setDisabled(True) + elif self.op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + # Reset view_lock when exit from DRAW_PTS + self.set_view_lock(ViewLock.NONE) + # Remove unfinished poly + if self.op_poly.get_pts_count() < 3: + self.ie_polys.remove_poly(self.op_poly) + + elif self.op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_delete_poly_act.setDisabled(True) + # Reset pt_edit_move when exit from EDIT_PTS + self.set_pt_edit_mode(PTEditMode.MOVE) + elif self.op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(False) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(False) + + self.op_mode = op_mode + + # Initialize new mode + if op_mode == OpMode.NONE: + self.cbar.btn_poly_type_act_grp.setDisabled(False) + elif op_mode == OpMode.DRAW_PTS: + self.cbar.btn_undo_pt_act.setDisabled(False) + self.cbar.btn_redo_pt_act.setDisabled(False) + self.cbar.btn_view_lock_center_act.setDisabled(False) + elif op_mode == OpMode.EDIT_PTS: + self.cbar.btn_pt_edit_mode_act.setDisabled(False) + self.cbar.btn_delete_poly_act.setDisabled(False) + elif op_mode == OpMode.VIEW_BAKED: + self.cbar.btn_view_baked_mask_act.setChecked(True ) + n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0 + h,w,c = n.shape + mask = np.zeros( (h,w,1), dtype=np.float32 ) + self.ie_polys.overlay_mask(mask) + n = (mask*255).astype(np.uint8) + self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n)) + elif op_mode == OpMode.VIEW_XSEG_MASK: + self.cbar.btn_view_xseg_mask_act.setChecked(True) + + if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + self.mouse_op_poly_pt_id = None + self.mouse_op_poly_edge_id = None + self.mouse_op_poly_edge_id_pt = None + + self.op_poly = op_poly + if op_poly is not None: + self.update_mouse_info() + + self.update_cursor() + self.update() + + def set_pt_edit_mode(self, pt_edit_mode): + if not hasattr(self, 'pt_edit_mode') or self.pt_edit_mode != pt_edit_mode: + self.pt_edit_mode = pt_edit_mode + self.update_cursor() + self.update() + self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL ) + + def set_view_lock(self, view_lock): + if not hasattr(self, 'view_lock') or self.view_lock != view_lock: + if hasattr(self, 'view_lock') and self.view_lock != view_lock: + if view_lock == ViewLock.CENTER: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + + self.view_lock = view_lock + self.update() + self.cbar.btn_view_lock_center_act.setChecked( self.view_lock == ViewLock.CENTER ) + + def set_cbar_disabled(self): + self.cbar.btn_delete_poly_act.setDisabled(True) + self.cbar.btn_undo_pt_act.setDisabled(True) + self.cbar.btn_redo_pt_act.setDisabled(True) + self.cbar.btn_pt_edit_mode_act.setDisabled(True) + self.cbar.btn_view_lock_center_act.setDisabled(True) + self.cbar.btn_poly_color_act_grp.setDisabled(True) + self.cbar.btn_view_xseg_overlay_mask_act.setDisabled(True) + self.cbar.btn_poly_type_act_grp.setDisabled(True) + + + def set_color_scheme_id(self, id): + if self.op_mode == OpMode.VIEW_BAKED or self.op_mode == OpMode.VIEW_XSEG_MASK: + self.set_op_mode(OpMode.NONE) + + if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id: + self.color_scheme_id = id + self.update_cursor() + self.update() + + if self.color_scheme_id == 0: + self.cbar.btn_poly_color_red_act.setChecked( True ) + elif self.color_scheme_id == 1: + self.cbar.btn_poly_color_green_act.setChecked( True ) + elif self.color_scheme_id == 2: + self.cbar.btn_poly_color_blue_act.setChecked( True ) + + def set_poly_include_type(self, poly_include_type): + if not hasattr(self, 'poly_include_type' ) or \ + ( self.poly_include_type != poly_include_type and \ + self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ): + self.poly_include_type = poly_include_type + self.update() + self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE) + self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE) + + # ==================================================================================== + # ==================================================================================== + # ====================================== METHODS ===================================== + # ==================================================================================== + # ==================================================================================== + + def update_cursor(self, is_finalize=False): + if not self.initialized: + return + + if not self.mouse_in_widget or is_finalize: + if self.current_cursor is not None: + QApplication.restoreOverrideCursor() + self.current_cursor = None + else: + color_cc = self.get_current_color_scheme().cross_cursor + nc = Qt.ArrowCursor + + if self.drag_type == DragType.IMAGE_LOOK: + nc = Qt.ClosedHandCursor + else: + + if self.op_mode == OpMode.NONE: + nc = color_cc + if self.mouse_wire_poly is not None: + nc = Qt.PointingHandCursor + + elif self.op_mode == OpMode.DRAW_PTS: + nc = color_cc + elif self.op_mode == OpMode.EDIT_PTS: + nc = Qt.ArrowCursor + + if self.mouse_op_poly_pt_id is not None: + nc = Qt.PointingHandCursor + + if self.pt_edit_mode == PTEditMode.ADD_DEL: + + if self.mouse_op_poly_edge_id is not None and \ + self.mouse_op_poly_pt_id is None: + nc = color_cc + if self.current_cursor != nc: + if self.current_cursor is None: + QApplication.setOverrideCursor(nc) + else: + QApplication.changeOverrideCursor(nc) + self.current_cursor = nc + + def update_mouse_info(self, mouse_cli_pt=None): + """ + Update selected polys/edges/points by given mouse position + """ + if mouse_cli_pt is not None: + self.mouse_cli_pt = mouse_cli_pt.astype(np.float32) + + self.mouse_img_pt = self.cli_to_img_pt(self.mouse_cli_pt) + + new_mouse_hull_poly = self.get_poly_by_pt_in_hull(self.mouse_cli_pt) + + if self.mouse_hull_poly != new_mouse_hull_poly: + self.mouse_hull_poly = new_mouse_hull_poly + self.update_cursor() + self.update() + + new_mouse_wire_poly = self.get_poly_by_pt_near_wire(self.mouse_cli_pt) + + if self.mouse_wire_poly != new_mouse_wire_poly: + self.mouse_wire_poly = new_mouse_wire_poly + self.update_cursor() + self.update() + + if self.op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]: + new_mouse_op_poly_pt_id = self.get_poly_pt_id_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_pt_id != new_mouse_op_poly_pt_id: + self.mouse_op_poly_pt_id = new_mouse_op_poly_pt_id + self.update_cursor() + self.update() + + new_mouse_op_poly_edge_id,\ + new_mouse_op_poly_edge_id_pt = self.get_poly_edge_id_pt_under_pt (self.op_poly, self.mouse_cli_pt) + if self.mouse_op_poly_edge_id != new_mouse_op_poly_edge_id: + self.mouse_op_poly_edge_id = new_mouse_op_poly_edge_id + self.update_cursor() + self.update() + + if (self.mouse_op_poly_edge_id_pt.__class__ != new_mouse_op_poly_edge_id_pt.__class__) or \ + (isinstance(self.mouse_op_poly_edge_id_pt, np.ndarray) and \ + all(self.mouse_op_poly_edge_id_pt != new_mouse_op_poly_edge_id_pt)): + + self.mouse_op_poly_edge_id_pt = new_mouse_op_poly_edge_id_pt + self.update_cursor() + self.update() + + + def action_undo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + if self.op_poly.undo() == 0: + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + def action_redo_pt(self): + if self.drag_type == DragType.NONE: + if self.op_mode == OpMode.DRAW_PTS: + self.op_poly.redo() + self.update() + + def action_delete_poly(self): + if self.op_mode == OpMode.EDIT_PTS and \ + self.drag_type == DragType.NONE and \ + self.pt_edit_mode == PTEditMode.MOVE: + # Delete current poly + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + + # ==================================================================================== + # ==================================================================================== + # ================================== OVERRIDE QT METHODS ============================= + # ==================================================================================== + # ==================================================================================== + def on_keyPressEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def on_keyReleaseEvent(self, ev): + if not self.initialized: + return + key = ev.key() + key_mods = int(ev.modifiers()) + if self.op_mode == OpMode.DRAW_PTS: + self.set_view_lock(ViewLock.CENTER if key_mods == Qt.ShiftModifier else ViewLock.NONE ) + elif self.op_mode == OpMode.EDIT_PTS: + self.set_pt_edit_mode(PTEditMode.ADD_DEL if key_mods == Qt.ControlModifier else PTEditMode.MOVE ) + + def enterEvent(self, ev): + super().enterEvent(ev) + self.mouse_in_widget = True + self.update_cursor() + + def leaveEvent(self, ev): + super().leaveEvent(ev) + self.mouse_in_widget = False + self.update_cursor() + + def mousePressEvent(self, ev): + super().mousePressEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.NONE: + # Clicking in NO OPERATION mode + if self.mouse_wire_poly is not None: + # Click on wire on any poly -> switch to EDIT_MODE + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.mouse_wire_poly) + else: + # Click on empty space -> create new poly with one point + new_poly = self.ie_polys.add_poly(self.poly_include_type) + self.ie_polys.sort() + new_poly.add_pt(*self.mouse_img_pt) + self.set_op_mode(OpMode.DRAW_PTS, op_poly=new_poly ) + + elif self.op_mode == OpMode.DRAW_PTS: + # Clicking in DRAW_PTS mode + if len(self.op_poly.get_pts()) >= 3 and self.mouse_op_poly_pt_id == 0: + # Click on first point -> close poly and switch to edit mode + self.set_op_mode(OpMode.EDIT_PTS, op_poly=self.op_poly) + else: + # Click on empty space -> add point to current poly + self.op_poly.add_pt(*self.mouse_img_pt) + self.update() + + elif self.op_mode == OpMode.EDIT_PTS: + # Clicking in EDIT_PTS mode + + if self.mouse_op_poly_pt_id is not None: + # Click on point of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # in mode 'delete point' + self.op_poly.remove_pt(self.mouse_op_poly_pt_id) + if self.op_poly.get_pts_count() < 3: + # not enough points after delete -> remove poly + self.ie_polys.remove_poly (self.op_poly) + self.set_op_mode(OpMode.NONE) + self.update() + + elif self.drag_type == DragType.NONE: + # otherwise -> start drag + self.drag_type = DragType.POLY_PT + self.drag_cli_pt = self.mouse_cli_pt + self.drag_poly_pt_id = self.mouse_op_poly_pt_id + self.drag_poly_pt = self.op_poly.get_pts()[ self.drag_poly_pt_id ] + elif self.mouse_op_poly_edge_id is not None: + # Click on edge of op_poly + if self.pt_edit_mode == PTEditMode.ADD_DEL: + # in mode 'insert new point' + edge_img_pt = self.cli_to_img_pt(self.mouse_op_poly_edge_id_pt) + self.op_poly.insert_pt (self.mouse_op_poly_edge_id+1, edge_img_pt) + self.update() + else: + # Otherwise do nothing + pass + else: + # other cases -> unselect poly + self.set_op_mode(OpMode.NONE) + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.NONE: + # Start image drag + self.drag_type = DragType.IMAGE_LOOK + self.drag_cli_pt = self.mouse_cli_pt + self.drag_img_look_pt = self.get_img_look_pt() + self.update_cursor() + + + def mouseReleaseEvent(self, ev): + super().mouseReleaseEvent(ev) + if not self.initialized: + return + + self.update_mouse_info(QPoint_to_np(ev.pos())) + + btn = ev.button() + + if btn == Qt.LeftButton: + if self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + self.drag_type = DragType.NONE + self.update() + + elif btn == Qt.MiddleButton: + if self.drag_type == DragType.IMAGE_LOOK: + self.drag_type = DragType.NONE + self.update_cursor() + self.update() + + def mouseMoveEvent(self, ev): + super().mouseMoveEvent(ev) + if not self.initialized: + return + + prev_mouse_cli_pt = self.mouse_cli_pt + self.update_mouse_info(QPoint_to_np(ev.pos())) + + if self.view_lock == ViewLock.CENTER: + if npla.norm(self.mouse_cli_pt - prev_mouse_cli_pt) >= 1: + self.img_look_pt = self.mouse_img_pt + QCursor.setPos ( self.mapToGlobal( QPoint_from_np(self.img_to_cli_pt(self.img_look_pt)) )) + self.update() + + if self.drag_type == DragType.IMAGE_LOOK: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.img_look_pt = self.drag_img_look_pt - delta_pt + self.update() + + if self.op_mode == OpMode.DRAW_PTS: + self.update() + elif self.op_mode == OpMode.EDIT_PTS: + if self.drag_type == DragType.POLY_PT: + delta_pt = self.cli_to_img_pt(self.mouse_cli_pt) - self.cli_to_img_pt(self.drag_cli_pt) + self.op_poly.set_point(self.drag_poly_pt_id, self.drag_poly_pt + delta_pt) + self.update() + + def wheelEvent(self, ev): + super().wheelEvent(ev) + + if not self.initialized: + return + + mods = int(ev.modifiers()) + delta = ev.angleDelta() + + cli_pt = QPoint_to_np(ev.pos()) + + if self.drag_type == DragType.NONE: + sign = np.sign( delta.y() ) + prev_img_pos = self.cli_to_img_pt (cli_pt) + delta_scale = sign*0.2 + sign * self.get_view_scale() / 10.0 + self.view_scale = np.clip(self.get_view_scale() + delta_scale, 1.0, 20.0) + new_img_pos = self.cli_to_img_pt (cli_pt) + if sign > 0: + self.img_look_pt = self.get_img_look_pt() + (prev_img_pos-new_img_pos)#*1.5 + else: + QCursor.setPos ( self.mapToGlobal(QPoint_from_np(self.img_to_cli_pt(prev_img_pos))) ) + self.update() + + def paintEvent(self, event): + super().paintEvent(event) + if not self.initialized: + return + + qp = self.qp + qp.begin(self) + qp.setRenderHint(QPainter.Antialiasing) + qp.setRenderHint(QPainter.HighQualityAntialiasing) + qp.setRenderHint(QPainter.SmoothPixmapTransform) + + src_rect = QRect(0, 0, *self.img_size) + dst_rect = self.img_to_cli_rect( src_rect ) + + if self.op_mode == OpMode.VIEW_BAKED: + qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect) + elif self.op_mode == OpMode.VIEW_XSEG_MASK: + if self.xseg_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect) + else: + if self.cbar.btn_view_xseg_overlay_mask_act.isChecked() and \ + self.xseg_overlay_mask_pixmap is not None: + qp.drawPixmap(dst_rect, self.xseg_overlay_mask_pixmap, src_rect) + elif self.img_pixmap is not None: + qp.drawPixmap(dst_rect, self.img_pixmap, src_rect) + + polys = self.ie_polys.get_polys() + polys_len = len(polys) + + color_scheme = self.get_current_color_scheme() + + pt_rad = self.canvas_config.pt_radius + pt_rad_x2 = pt_rad*2 + + pt_select_radius = self.canvas_config.pt_select_radius + + op_mode = self.op_mode + op_poly = self.op_poly + + for i,poly in enumerate(polys): + + selected_pt_path = QPainterPath() + poly_line_path = QPainterPath() + pts_line_path = QPainterPath() + + pt_remove_cli_pt = None + poly_pts = poly.get_pts() + for pt_id, img_pt in enumerate(poly_pts): + cli_pt = self.img_to_cli_pt(img_pt) + q_cli_pt = QPoint_from_np(cli_pt) + + if pt_id == 0: + poly_line_path.moveTo(q_cli_pt) + else: + poly_line_path.lineTo(q_cli_pt) + + + if poly == op_poly: + if self.op_mode == OpMode.DRAW_PTS or \ + (self.op_mode == OpMode.EDIT_PTS and \ + (self.pt_edit_mode == PTEditMode.MOVE) or \ + (self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id == pt_id) \ + ): + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(cli_pt + np.float32([pt_rad,0])) ) + + if (self.op_mode == OpMode.EDIT_PTS and \ + self.pt_edit_mode == PTEditMode.ADD_DEL and \ + self.mouse_op_poly_pt_id == pt_id): + pt_remove_cli_pt = cli_pt + + if self.op_mode == OpMode.DRAW_PTS and \ + len(op_poly.get_pts()) >= 3 and pt_id == 0 and self.mouse_op_poly_pt_id == pt_id: + # Circle around poly point + selected_pt_path.addEllipse(q_cli_pt, pt_rad_x2, pt_rad_x2) + + + if poly == op_poly: + if op_mode == OpMode.DRAW_PTS: + # Line from last point to mouse + poly_line_path.lineTo( QPoint_from_np(self.mouse_cli_pt) ) + + if self.mouse_op_poly_pt_id is not None: + pass + + if self.mouse_op_poly_edge_id_pt is not None: + if self.pt_edit_mode == PTEditMode.ADD_DEL and self.mouse_op_poly_pt_id is None: + # Ready to insert point on edge + m_cli_pt = self.mouse_op_poly_edge_id_pt + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([0,-pt_rad])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([0,pt_rad])) ) + pts_line_path.moveTo( QPoint_from_np(m_cli_pt + np.float32([-pt_rad,0])) ) + pts_line_path.lineTo( QPoint_from_np(m_cli_pt + np.float32([pt_rad,0])) ) + + if len(poly_pts) >= 2: + # Closing poly line + poly_line_path.lineTo( QPoint_from_np(self.img_to_cli_pt(poly_pts[0])) ) + + # Draw calls + qp.setPen(color_scheme.pt_outline_pen) + qp.setBrush(QBrush()) + qp.drawPath(selected_pt_path) + + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + qp.drawPath(pts_line_path) + + if poly.get_type() == SegIEPolyType.INCLUDE: + qp.setPen(color_scheme.poly_outline_solid_pen) + else: + qp.setPen(color_scheme.poly_outline_dot_pen) + + qp.setBrush(color_scheme.poly_unselected_brush) + if op_mode == OpMode.NONE: + if poly == self.mouse_wire_poly: + qp.setBrush(color_scheme.poly_selected_brush) + #else: + # if poly == op_poly: + # qp.setBrush(color_scheme.poly_selected_brush) + + qp.drawPath(poly_line_path) + + if pt_remove_cli_pt is not None: + qp.setPen(color_scheme.poly_outline_solid_pen) + qp.setBrush(QBrush()) + + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,-pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,pt_rad_x2])) ) + qp.drawLine( *(pt_remove_cli_pt + np.float32([-pt_rad_x2,pt_rad_x2])), *(pt_remove_cli_pt + np.float32([pt_rad_x2,-pt_rad_x2])) ) + + qp.end() + +class QCanvas(QFrame): + def __init__(self): + super().__init__() + + self.canvas_control_left_bar = QCanvasControlsLeftBar() + self.canvas_control_right_bar = QCanvasControlsRightBar() + + cbar = sn( btn_poly_color_red_act = self.canvas_control_right_bar.btn_poly_color_red_act, + btn_poly_color_green_act = self.canvas_control_right_bar.btn_poly_color_green_act, + btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act, + btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act, + btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act, + btn_view_xseg_overlay_mask_act = self.canvas_control_right_bar.btn_view_xseg_overlay_mask_act, + btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp, + btn_view_lock_center_act = self.canvas_control_right_bar.btn_view_lock_center_act, + + btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act, + btn_poly_type_exclude_act = self.canvas_control_left_bar.btn_poly_type_exclude_act, + btn_poly_type_act_grp = self.canvas_control_left_bar.btn_poly_type_act_grp, + btn_undo_pt_act = self.canvas_control_left_bar.btn_undo_pt_act, + btn_redo_pt_act = self.canvas_control_left_bar.btn_redo_pt_act, + btn_delete_poly_act = self.canvas_control_left_bar.btn_delete_poly_act, + btn_pt_edit_mode_act = self.canvas_control_left_bar.btn_pt_edit_mode_act ) + + self.op = QCanvasOperator(cbar) + self.l = QHBoxLayout() + self.l.setContentsMargins(0,0,0,0) + self.l.addWidget(self.canvas_control_left_bar) + self.l.addWidget(self.op) + self.l.addWidget(self.canvas_control_right_bar) + self.setLayout(self.l) + +class LoaderQSubprocessor(QSubprocessor): + def __init__(self, image_paths, q_label, q_progressbar, on_finish_func ): + + self.image_paths = image_paths + self.image_paths_len = len(image_paths) + self.idxs = [*range(self.image_paths_len)] + + self.filtered_image_paths = self.image_paths.copy() + + self.image_paths_has_ie_polys = { image_path : False for image_path in self.image_paths } + + self.q_label = q_label + self.q_progressbar = q_progressbar + self.q_progressbar.setRange(0, self.image_paths_len) + self.q_progressbar.setValue(0) + self.q_progressbar.update() + self.on_finish_func = on_finish_func + self.done_count = 0 + super().__init__('LoaderQSubprocessor', LoaderQSubprocessor.Cli, 60) + + def get_data(self, host_dict): + if len (self.idxs) > 0: + idx = self.idxs.pop(0) + image_path = self.image_paths[idx] + self.q_label.setText(f'{QStringDB.loading_tip}... {image_path.name}') + + return idx, image_path + + return None + + def on_clients_finalized(self): + self.on_finish_func([x for x in self.filtered_image_paths if x is not None], self.image_paths_has_ie_polys) + + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data[0]) + + def on_result (self, host_dict, data, result): + idx, has_dflimg, has_ie_polys = result + + if not has_dflimg: + self.filtered_image_paths[idx] = None + self.image_paths_has_ie_polys[self.image_paths[idx]] = has_ie_polys + + self.done_count += 1 + if self.q_progressbar is not None: + self.q_progressbar.setValue(self.done_count) + + class Cli(QSubprocessor.Cli): + def process_data(self, data): + idx, filename = data + dflimg = DFLIMG.load(filename) + if dflimg is not None and dflimg.has_data(): + ie_polys = dflimg.get_seg_ie_polys() + + return idx, True, ie_polys.has_polys() + return idx, False, False + +class MainWindow(QXMainWindow): + + def __init__(self, input_dirpath, cfg_root_path): + self.loading_frame = None + self.help_frame = None + + super().__init__() + + self.input_dirpath = input_dirpath + self.trash_dirpath = input_dirpath.parent / (input_dirpath.name + '_trash') + self.cfg_root_path = cfg_root_path + + self.cfg_path = cfg_root_path / 'MainWindow_cfg.dat' + self.cfg_dict = pickle.loads(self.cfg_path.read_bytes()) if self.cfg_path.exists() else {} + + self.cached_images = {} + self.cached_has_ie_polys = {} + + self.initialize_ui() + + # Loader + self.loading_frame = QFrame(self.main_canvas_frame) + self.loading_frame.setAutoFillBackground(True) + self.loading_frame.setFrameShape(QFrame.StyledPanel) + self.loader_label = QLabel() + self.loader_progress_bar = QProgressBar() + + intro_image = QLabel() + intro_image.setPixmap( QPixmap.fromImage(QImageDB.intro) ) + + intro_image_frame_l = QVBoxLayout() + intro_image_frame_l.addWidget(intro_image, alignment=Qt.AlignCenter) + intro_image_frame = QFrame() + intro_image_frame.setSizePolicy (QSizePolicy.Expanding, QSizePolicy.Expanding) + intro_image_frame.setLayout(intro_image_frame_l) + + loading_frame_l = QVBoxLayout() + loading_frame_l.addWidget (intro_image_frame) + loading_frame_l.addWidget (self.loader_label) + loading_frame_l.addWidget (self.loader_progress_bar) + self.loading_frame.setLayout(loading_frame_l) + + self.loader_subprocessor = LoaderQSubprocessor( image_paths=pathex.get_image_paths(input_dirpath, return_Path_class=True), + q_label=self.loader_label, + q_progressbar=self.loader_progress_bar, + on_finish_func=self.on_loader_finish ) + + + def on_loader_finish(self, image_paths, image_paths_has_ie_polys): + self.image_paths_done = [] + self.image_paths = image_paths + self.image_paths_has_ie_polys = image_paths_has_ie_polys + self.set_has_ie_polys_count ( len([ 1 for x in self.image_paths_has_ie_polys if self.image_paths_has_ie_polys[x] == True]) ) + self.loading_frame.hide() + self.loading_frame = None + + self.process_next_image(first_initialization=True) + + def closeEvent(self, ev): + self.cfg_dict['geometry'] = self.saveGeometry().data() + self.cfg_path.write_bytes( pickle.dumps(self.cfg_dict) ) + + + def update_cached_images (self, count=5): + d = self.cached_images + + for image_path in self.image_paths_done[:-count]+self.image_paths[count:]: + if image_path in d: + del d[image_path] + + for image_path in self.image_paths[:count]+self.image_paths_done[-count:]: + if image_path not in d: + img = cv2_imread(image_path) + if img is not None: + d[image_path] = img + + def load_image(self, image_path): + try: + img = self.cached_images.get(image_path, None) + if img is None: + img = cv2_imread(image_path) + self.cached_images[image_path] = img + if img is None: + io.log_err(f'Unable to load {image_path}') + except: + img = None + + return img + + def update_preview_bar(self): + count = self.image_bar.get_preview_images_count() + d = self.cached_images + prev_imgs = [ d.get(image_path, None) for image_path in self.image_paths_done[-1:-count:-1] ] + next_imgs = [ d.get(image_path, None) for image_path in self.image_paths[:count] ] + self.image_bar.update_images(prev_imgs, next_imgs) + + + def canvas_initialize(self, image_path, only_has_polys=False): + if only_has_polys and not self.image_paths_has_ie_polys[image_path]: + return False + + dflimg = DFLIMG.load(image_path) + if not dflimg or not dflimg.has_data(): + return False + + ie_polys = dflimg.get_seg_ie_polys() + xseg_mask = dflimg.get_xseg_mask() + img = self.load_image(image_path) + if img is None: + return False + + self.canvas.op.initialize ( img, ie_polys=ie_polys, xseg_mask=xseg_mask ) + + self.filename_label.setText(f"{image_path.name}") + + return True + + def canvas_finalize(self, image_path): + self.canvas.op.finalize() + + if image_path.exists(): + dflimg = DFLIMG.load(image_path) + ie_polys = dflimg.get_seg_ie_polys() + new_ie_polys = self.canvas.op.get_ie_polys() + + if not new_ie_polys.identical(ie_polys): + prev_has_polys = self.image_paths_has_ie_polys[image_path] + self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys() + new_has_polys = self.image_paths_has_ie_polys[image_path] + + if not prev_has_polys and new_has_polys: + self.set_has_ie_polys_count ( self.get_has_ie_polys_count() +1) + elif prev_has_polys and not new_has_polys: + self.set_has_ie_polys_count ( self.get_has_ie_polys_count() -1) + + dflimg.set_seg_ie_polys( new_ie_polys ) + dflimg.save() + + self.filename_label.setText(f"") + + def process_prev_image(self): + key_mods = QApplication.keyboardModifiers() + step = 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths_done) != 0: + self.image_paths.insert (0, self.image_paths_done.pop(-1)) + else: + break + if len(self.image_paths) == 0: + break + + ret = self.canvas_initialize(self.image_paths[0], len(self.image_paths_done) != 0 and only_has_polys) + + if ret or len(self.image_paths_done) == 0: + break + + self.update_cached_images() + self.update_preview_bar() + + def process_next_image(self, first_initialization=False): + key_mods = QApplication.keyboardModifiers() + + step = 0 if first_initialization else 5 if key_mods == Qt.ShiftModifier else 1 + only_has_polys = False if first_initialization else key_mods == Qt.ControlModifier + + if self.canvas.op.is_initialized(): + self.canvas_finalize(self.image_paths[0]) + + while True: + for _ in range(step): + if len(self.image_paths) != 0: + self.image_paths_done.append(self.image_paths.pop(0)) + else: + break + if len(self.image_paths) == 0: + break + if self.canvas_initialize(self.image_paths[0], only_has_polys): + break + + self.update_cached_images() + self.update_preview_bar() + + def trash_current_image(self): + self.process_next_image() + + img_path = self.image_paths_done.pop(-1) + img_path = Path(img_path) + self.trash_dirpath.mkdir(parents=True, exist_ok=True) + img_path.rename( self.trash_dirpath / img_path.name ) + + self.update_cached_images() + self.update_preview_bar() + + def initialize_ui(self): + + self.canvas = QCanvas() + + image_bar = self.image_bar = ImagePreviewSequenceBar(preview_images_count=9, icon_size=QUIConfig.preview_bar_icon_q_size.width()) + image_bar.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + + + btn_prev_image = QXIconButton(QIconDB.left, QStringDB.btn_prev_image_tip, shortcut='A', click_func=self.process_prev_image) + btn_prev_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + btn_next_image = QXIconButton(QIconDB.right, QStringDB.btn_next_image_tip, shortcut='D', click_func=self.process_next_image) + btn_next_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + btn_delete_image = QXIconButton(QIconDB.trashcan, QStringDB.btn_delete_image_tip, shortcut='X', click_func=self.trash_current_image) + btn_delete_image.setIconSize(QUIConfig.preview_bar_icon_q_size) + + pad_image = QWidget() + pad_image.setFixedSize(QUIConfig.preview_bar_icon_q_size) + + preview_image_bar_frame_l = QHBoxLayout() + preview_image_bar_frame_l.setContentsMargins(0,0,0,0) + preview_image_bar_frame_l.addWidget ( pad_image, alignment=Qt.AlignCenter) + preview_image_bar_frame_l.addWidget ( btn_prev_image, alignment=Qt.AlignCenter) + preview_image_bar_frame_l.addWidget ( image_bar) + preview_image_bar_frame_l.addWidget ( btn_next_image, alignment=Qt.AlignCenter) + #preview_image_bar_frame_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter) + + preview_image_bar_frame = QFrame() + preview_image_bar_frame.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + preview_image_bar_frame.setLayout(preview_image_bar_frame_l) + + preview_image_bar_frame2_l = QHBoxLayout() + preview_image_bar_frame2_l.setContentsMargins(0,0,0,0) + preview_image_bar_frame2_l.addWidget ( btn_delete_image, alignment=Qt.AlignCenter) + + preview_image_bar_frame2 = QFrame() + preview_image_bar_frame2.setSizePolicy ( QSizePolicy.Fixed, QSizePolicy.Fixed ) + preview_image_bar_frame2.setLayout(preview_image_bar_frame2_l) + + preview_image_bar_l = QHBoxLayout() + preview_image_bar_l.addWidget (preview_image_bar_frame, alignment=Qt.AlignCenter) + preview_image_bar_l.addWidget (preview_image_bar_frame2) + + preview_image_bar = QFrame() + preview_image_bar.setFrameShape(QFrame.StyledPanel) + preview_image_bar.setSizePolicy ( QSizePolicy.Expanding, QSizePolicy.Fixed ) + preview_image_bar.setLayout(preview_image_bar_l) + + label_font = QFont('Courier New') + self.filename_label = QLabel() + self.filename_label.setFont(label_font) + + self.has_ie_polys_count_label = QLabel() + + status_frame_l = QHBoxLayout() + status_frame_l.setContentsMargins(0,0,0,0) + status_frame_l.addWidget ( QLabel(), alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.filename_label, alignment=Qt.AlignCenter) + status_frame_l.addWidget (self.has_ie_polys_count_label, alignment=Qt.AlignCenter) + status_frame = QFrame() + status_frame.setLayout(status_frame_l) + + main_canvas_l = QVBoxLayout() + main_canvas_l.setContentsMargins(0,0,0,0) + main_canvas_l.addWidget (self.canvas) + main_canvas_l.addWidget (status_frame) + main_canvas_l.addWidget (preview_image_bar) + + self.main_canvas_frame = QFrame() + self.main_canvas_frame.setLayout(main_canvas_l) + + self.main_l = QHBoxLayout() + self.main_l.setContentsMargins(0,0,0,0) + self.main_l.addWidget (self.main_canvas_frame) + + self.setLayout(self.main_l) + + geometry = self.cfg_dict.get('geometry', None) + if geometry is not None: + self.restoreGeometry(geometry) + else: + self.move( QPoint(0,0)) + + def get_has_ie_polys_count(self): + return self.has_ie_polys_count + + def set_has_ie_polys_count(self, c): + self.has_ie_polys_count = c + self.has_ie_polys_count_label.setText(f"{c} {QStringDB.labeled_tip}") + + def resizeEvent(self, ev): + if self.loading_frame is not None: + self.loading_frame.resize( ev.size() ) + if self.help_frame is not None: + self.help_frame.resize( ev.size() ) + +def start(input_dirpath): + """ + returns exit_code + """ + io.log_info("Running XSeg editor.") + + if PackedFaceset.path_contains(input_dirpath): + io.log_info (f'\n{input_dirpath} contains packed faceset! Unpack it first.\n') + return 1 + + root_path = Path(__file__).parent + cfg_root_path = Path(tempfile.gettempdir()) + + QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True) + QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True) + + app = QApplication([]) + app.setApplicationName("XSegEditor") + app.setStyle('Fusion') + + QFontDatabase.addApplicationFont( str(root_path / 'gfx' / 'fonts' / 'NotoSans-Medium.ttf') ) + + app.setFont( QFont('NotoSans')) + + QUIConfig.initialize() + QStringDB.initialize() + + QIconDB.initialize( root_path / 'gfx' / 'icons' ) + QCursorDB.initialize( root_path / 'gfx' / 'cursors' ) + QImageDB.initialize( root_path / 'gfx' / 'images' ) + + app.setWindowIcon(QIconDB.app_icon) + app.setPalette( QDarkPalette() ) + + win = MainWindow( input_dirpath=input_dirpath, cfg_root_path=cfg_root_path) + + win.show() + win.raise_() + + app.exec_() + return 0 diff --git a/XSegEditor/gfx/cursors/cross_blue.png b/XSegEditor/gfx/cursors/cross_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..89152194a802b59df2e25ad602b1753d355af8c5 Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_blue.png differ diff --git a/XSegEditor/gfx/cursors/cross_green.png b/XSegEditor/gfx/cursors/cross_green.png new file mode 100644 index 0000000000000000000000000000000000000000..3ce16f03ba8c707f996513e435684d4a2669e5be Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_green.png differ diff --git a/XSegEditor/gfx/cursors/cross_red.png b/XSegEditor/gfx/cursors/cross_red.png new file mode 100644 index 0000000000000000000000000000000000000000..bb851acddfcc4571003404b6c82cb4e401ea2b4b Binary files /dev/null and b/XSegEditor/gfx/cursors/cross_red.png differ diff --git a/XSegEditor/gfx/fonts/NotoSans-Medium.ttf b/XSegEditor/gfx/fonts/NotoSans-Medium.ttf new file mode 100644 index 0000000000000000000000000000000000000000..25050f76b267a99ce10b820ed128cecd351f4f63 Binary files /dev/null and b/XSegEditor/gfx/fonts/NotoSans-Medium.ttf differ diff --git a/XSegEditor/gfx/icons/app_icon.png b/XSegEditor/gfx/icons/app_icon.png new file mode 100644 index 0000000000000000000000000000000000000000..16bc03e08e1f543f8b481a656b8ee42b23e88693 Binary files /dev/null and b/XSegEditor/gfx/icons/app_icon.png differ diff --git a/XSegEditor/gfx/icons/delete_poly.png b/XSegEditor/gfx/icons/delete_poly.png new file mode 100644 index 0000000000000000000000000000000000000000..afd57d198741d2f7d046a7cbdf63016c5430ad78 Binary files /dev/null and b/XSegEditor/gfx/icons/delete_poly.png differ diff --git a/XSegEditor/gfx/icons/down.png b/XSegEditor/gfx/icons/down.png new file mode 100644 index 0000000000000000000000000000000000000000..873b71953e66f64ff6c27fee04ed4e864d93c75d Binary files /dev/null and b/XSegEditor/gfx/icons/down.png differ diff --git a/XSegEditor/gfx/icons/left.png b/XSegEditor/gfx/icons/left.png new file mode 100644 index 0000000000000000000000000000000000000000..2118be68fb5d624e852ef4b4649a92d30a9ab109 Binary files /dev/null and b/XSegEditor/gfx/icons/left.png differ diff --git a/XSegEditor/gfx/icons/poly_color.psd b/XSegEditor/gfx/icons/poly_color.psd new file mode 100644 index 0000000000000000000000000000000000000000..9a94957f41648d2ad37b5562a0119459309bddab Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color.psd differ diff --git a/XSegEditor/gfx/icons/poly_color_blue.png b/XSegEditor/gfx/icons/poly_color_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..80b5222e54e01c6f4506dd2c87c42addf9d08655 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_blue.png differ diff --git a/XSegEditor/gfx/icons/poly_color_green.png b/XSegEditor/gfx/icons/poly_color_green.png new file mode 100644 index 0000000000000000000000000000000000000000..2db1fbb9872bbe8a36cab003d3b99c40b4bd371e Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_green.png differ diff --git a/XSegEditor/gfx/icons/poly_color_red.png b/XSegEditor/gfx/icons/poly_color_red.png new file mode 100644 index 0000000000000000000000000000000000000000..d04efff638acf8967cfc4630e72d7ca8c64ff86a Binary files /dev/null and b/XSegEditor/gfx/icons/poly_color_red.png differ diff --git a/XSegEditor/gfx/icons/poly_type_exclude.png b/XSegEditor/gfx/icons/poly_type_exclude.png new file mode 100644 index 0000000000000000000000000000000000000000..8e36bc3e74a9c19cc5daddad05c776182753e2a3 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_exclude.png differ diff --git a/XSegEditor/gfx/icons/poly_type_include.png b/XSegEditor/gfx/icons/poly_type_include.png new file mode 100644 index 0000000000000000000000000000000000000000..5f16c1566b5134ce4aab1f3eb35756b5de8c85e1 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_include.png differ diff --git a/XSegEditor/gfx/icons/poly_type_source.psd b/XSegEditor/gfx/icons/poly_type_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..50943d014ec731f8b5c8ad6a1265f1b138289ae2 Binary files /dev/null and b/XSegEditor/gfx/icons/poly_type_source.psd differ diff --git a/XSegEditor/gfx/icons/pt_edit_mode.png b/XSegEditor/gfx/icons/pt_edit_mode.png new file mode 100644 index 0000000000000000000000000000000000000000..d385fc2fd274fe9212c4720c342c49300dec9f72 Binary files /dev/null and b/XSegEditor/gfx/icons/pt_edit_mode.png differ diff --git a/XSegEditor/gfx/icons/pt_edit_mode_source.psd b/XSegEditor/gfx/icons/pt_edit_mode_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..f73e31075669bf46c4d325615ad94e902891c3e0 Binary files /dev/null and b/XSegEditor/gfx/icons/pt_edit_mode_source.psd differ diff --git a/XSegEditor/gfx/icons/redo_pt.png b/XSegEditor/gfx/icons/redo_pt.png new file mode 100644 index 0000000000000000000000000000000000000000..aa73329d9ff8d9848b37347e94a79a012e889600 Binary files /dev/null and b/XSegEditor/gfx/icons/redo_pt.png differ diff --git a/XSegEditor/gfx/icons/redo_pt_source.psd b/XSegEditor/gfx/icons/redo_pt_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..2771f77286c4314eee2422b7e143879a89074d47 Binary files /dev/null and b/XSegEditor/gfx/icons/redo_pt_source.psd differ diff --git a/XSegEditor/gfx/icons/right.png b/XSegEditor/gfx/icons/right.png new file mode 100644 index 0000000000000000000000000000000000000000..b4ef22031b05bfc6fa390d936046a2494ecf635a Binary files /dev/null and b/XSegEditor/gfx/icons/right.png differ diff --git a/XSegEditor/gfx/icons/trashcan.png b/XSegEditor/gfx/icons/trashcan.png new file mode 100644 index 0000000000000000000000000000000000000000..a31285b12b79836107a336ccc0c7eb4e67c01fbc Binary files /dev/null and b/XSegEditor/gfx/icons/trashcan.png differ diff --git a/XSegEditor/gfx/icons/undo_pt.png b/XSegEditor/gfx/icons/undo_pt.png new file mode 100644 index 0000000000000000000000000000000000000000..7a4464c14c728f05c6ae708d4d78a2114e6485e8 Binary files /dev/null and b/XSegEditor/gfx/icons/undo_pt.png differ diff --git a/XSegEditor/gfx/icons/undo_pt_source.psd b/XSegEditor/gfx/icons/undo_pt_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..98b9d1a71a76a13f340d9a0037d67667a6d404b4 Binary files /dev/null and b/XSegEditor/gfx/icons/undo_pt_source.psd differ diff --git a/XSegEditor/gfx/icons/up.png b/XSegEditor/gfx/icons/up.png new file mode 100644 index 0000000000000000000000000000000000000000..f3368b641aa47807feb3f39ec7f686601b1231a3 Binary files /dev/null and b/XSegEditor/gfx/icons/up.png differ diff --git a/XSegEditor/gfx/icons/view_baked.png b/XSegEditor/gfx/icons/view_baked.png new file mode 100644 index 0000000000000000000000000000000000000000..3e321422405e8aed6a2db7e7091f1bafc585e82e Binary files /dev/null and b/XSegEditor/gfx/icons/view_baked.png differ diff --git a/XSegEditor/gfx/icons/view_lock_center.png b/XSegEditor/gfx/icons/view_lock_center.png new file mode 100644 index 0000000000000000000000000000000000000000..2a10408cca9f994f5d8ae6383199b52966178f82 Binary files /dev/null and b/XSegEditor/gfx/icons/view_lock_center.png differ diff --git a/XSegEditor/gfx/icons/view_xseg.png b/XSegEditor/gfx/icons/view_xseg.png new file mode 100644 index 0000000000000000000000000000000000000000..7328d2c40cb716e477447bde41715d4b902ec4c8 Binary files /dev/null and b/XSegEditor/gfx/icons/view_xseg.png differ diff --git a/XSegEditor/gfx/icons/view_xseg_overlay.png b/XSegEditor/gfx/icons/view_xseg_overlay.png new file mode 100644 index 0000000000000000000000000000000000000000..d188285f82c6bd736e7771d6f9f81fe3319f00db Binary files /dev/null and b/XSegEditor/gfx/icons/view_xseg_overlay.png differ diff --git a/XSegEditor/gfx/images/intro.png b/XSegEditor/gfx/images/intro.png new file mode 100644 index 0000000000000000000000000000000000000000..7f4d43f6472ae1a9de22d5052b54f0cd672fae2d Binary files /dev/null and b/XSegEditor/gfx/images/intro.png differ diff --git a/XSegEditor/gfx/images/intro_source.psd b/XSegEditor/gfx/images/intro_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..bb1cc901d1a7d7d76e6df9a6755a2fa492a6709c Binary files /dev/null and b/XSegEditor/gfx/images/intro_source.psd differ diff --git a/_config.yml b/_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..97517153e900f8d9b5da0ba5631b8d668e844e81 --- /dev/null +++ b/_config.yml @@ -0,0 +1,9 @@ +theme: jekyll-theme-cayman +plugins: + - jekyll-relative-links +relative_links: + enabled: true + collections: true + +include: + - README.md \ No newline at end of file diff --git a/core/cv2ex.py b/core/cv2ex.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5d73c4ff7271def3d3e875d83f842f7eb2b707 --- /dev/null +++ b/core/cv2ex.py @@ -0,0 +1,40 @@ +import cv2 +import numpy as np +from pathlib import Path +from core.interact import interact as io +from core import imagelib +import traceback + +def cv2_imread(filename, flags=cv2.IMREAD_UNCHANGED, loader_func=None, verbose=True): + """ + allows to open non-english characters path + """ + try: + if loader_func is not None: + bytes = bytearray(loader_func(filename)) + else: + with open(filename, "rb") as stream: + bytes = bytearray(stream.read()) + numpyarray = np.asarray(bytes, dtype=np.uint8) + return cv2.imdecode(numpyarray, flags) + except: + if verbose: + io.log_err(f"Exception occured in cv2_imread : {traceback.format_exc()}") + return None + +def cv2_imwrite(filename, img, *args): + ret, buf = cv2.imencode( Path(filename).suffix, img, *args) + if ret == True: + try: + with open(filename, "wb") as stream: + stream.write( buf ) + except: + pass + +def cv2_resize(x, *args, **kwargs): + h,w,c = x.shape + x = cv2.resize(x, *args, **kwargs) + + x = imagelib.normalize_channels(x, c) + return x + \ No newline at end of file diff --git a/core/imagelib/SegIEPolys.py b/core/imagelib/SegIEPolys.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4c3d29e724da56a3e151d22c2bbb95b421ce90 --- /dev/null +++ b/core/imagelib/SegIEPolys.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 +from enum import IntEnum + + +class SegIEPolyType(IntEnum): + EXCLUDE = 0 + INCLUDE = 1 + + + +class SegIEPoly(): + def __init__(self, type=None, pts=None, **kwargs): + self.type = type + + if pts is None: + pts = np.empty( (0,2), dtype=np.float32 ) + else: + pts = np.float32(pts) + self.pts = pts + self.n_max = self.n = len(pts) + + def dump(self): + return {'type': int(self.type), + 'pts' : self.get_pts(), + } + + def identical(self, b): + if self.n != b.n: + return False + return (self.pts[0:self.n] == b.pts[0:b.n]).all() + + def get_type(self): + return self.type + + def add_pt(self, x, y): + self.pts = np.append(self.pts[0:self.n], [ ( float(x), float(y) ) ], axis=0).astype(np.float32) + self.n_max = self.n = self.n + 1 + + def undo(self): + self.n = max(0, self.n-1) + return self.n + + def redo(self): + self.n = min(len(self.pts), self.n+1) + return self.n + + def redo_clip(self): + self.pts = self.pts[0:self.n] + self.n_max = self.n + + def insert_pt(self, n, pt): + if n < 0 or n > self.n: + raise ValueError("insert_pt out of range") + self.pts = np.concatenate( (self.pts[0:n], pt[None,...].astype(np.float32), self.pts[n:]), axis=0) + self.n_max = self.n = self.n+1 + + def remove_pt(self, n): + if n < 0 or n >= self.n: + raise ValueError("remove_pt out of range") + self.pts = np.concatenate( (self.pts[0:n], self.pts[n+1:]), axis=0) + self.n_max = self.n = self.n-1 + + def get_last_point(self): + return self.pts[self.n-1].copy() + + def get_pts(self): + return self.pts[0:self.n].copy() + + def get_pts_count(self): + return self.n + + def set_point(self, id, pt): + self.pts[id] = pt + + def set_points(self, pts): + self.pts = np.array(pts) + self.n_max = self.n = len(pts) + + def mult_points(self, val): + self.pts *= val + + + +class SegIEPolys(): + def __init__(self): + self.polys = [] + + def identical(self, b): + polys_len = len(self.polys) + o_polys_len = len(b.polys) + if polys_len != o_polys_len: + return False + + return all ([ a_poly.identical(b_poly) for a_poly, b_poly in zip(self.polys, b.polys) ]) + + def add_poly(self, ie_poly_type): + poly = SegIEPoly(ie_poly_type) + self.polys.append (poly) + return poly + + def remove_poly(self, poly): + if poly in self.polys: + self.polys.remove(poly) + + def has_polys(self): + return len(self.polys) != 0 + + def get_poly(self, id): + return self.polys[id] + + def get_polys(self): + return self.polys + + def get_pts_count(self): + return sum([poly.get_pts_count() for poly in self.polys]) + + def sort(self): + poly_by_type = { SegIEPolyType.EXCLUDE : [], SegIEPolyType.INCLUDE : [] } + + for poly in self.polys: + poly_by_type[poly.type].append(poly) + + self.polys = poly_by_type[SegIEPolyType.INCLUDE] + poly_by_type[SegIEPolyType.EXCLUDE] + + def __iter__(self): + for poly in self.polys: + yield poly + + def overlay_mask(self, mask): + h,w,c = mask.shape + white = (1,)*c + black = (0,)*c + for poly in self.polys: + pts = poly.get_pts().astype(np.int32) + if len(pts) != 0: + cv2.fillPoly(mask, [pts], white if poly.type == SegIEPolyType.INCLUDE else black ) + + def dump(self): + return {'polys' : [ poly.dump() for poly in self.polys ] } + + def mult_points(self, val): + for poly in self.polys: + poly.mult_points(val) + + @staticmethod + def load(data=None): + ie_polys = SegIEPolys() + if data is not None: + if isinstance(data, list): + # Backward comp + ie_polys.polys = [ SegIEPoly(type=type, pts=pts) for (type, pts) in data ] + elif isinstance(data, dict): + ie_polys.polys = [ SegIEPoly(**poly_cfg) for poly_cfg in data['polys'] ] + + ie_polys.sort() + + return ie_polys \ No newline at end of file diff --git a/core/imagelib/__init__.py b/core/imagelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..11234a5db7d5e39e3ee11208cb67ef035a0b4277 --- /dev/null +++ b/core/imagelib/__init__.py @@ -0,0 +1,32 @@ +from .estimate_sharpness import estimate_sharpness + +from .equalize_and_stack_square import equalize_and_stack_square + +from .text import get_text_image, get_draw_text_lines + +from .draw import draw_polygon, draw_rect + +from .morph import morph_by_points + +from .warp import gen_warp_params, warp_by_params + +from .reduce_colors import reduce_colors + +from .color_transfer import color_transfer, color_transfer_mix, color_transfer_sot, color_transfer_mkl, color_transfer_idt, color_hist_match, reinhard_color_transfer, linear_color_transfer + +from .common import random_crop, normalize_channels, cut_odd_image, overlay_alpha_image + +from .SegIEPolys import * + +from .blursharpen import LinearMotionBlur, blursharpen + +from .filters import apply_random_rgb_levels, \ + apply_random_overlay_triangle, \ + apply_random_hsv_shift, \ + apply_random_sharpen, \ + apply_random_motion_blur, \ + apply_random_gaussian_blur, \ + apply_random_nearest_resize, \ + apply_random_bilinear_resize, \ + apply_random_jpeg_compress, \ + apply_random_relight diff --git a/core/imagelib/blursharpen.py b/core/imagelib/blursharpen.py new file mode 100644 index 0000000000000000000000000000000000000000..51745119f7066d6ba4e57dcbc435119b28b03983 --- /dev/null +++ b/core/imagelib/blursharpen.py @@ -0,0 +1,38 @@ +import cv2 +import numpy as np + +def LinearMotionBlur(image, size, angle): + k = np.zeros((size, size), dtype=np.float32) + k[ (size-1)// 2 , :] = np.ones(size, dtype=np.float32) + k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) ) + k = k * ( 1.0 / np.sum(k) ) + return cv2.filter2D(image, -1, k) + +def blursharpen (img, sharpen_mode=0, kernel_size=3, amount=100): + if kernel_size % 2 == 0: + kernel_size += 1 + if amount > 0: + if sharpen_mode == 1: #box + kernel = np.zeros( (kernel_size, kernel_size), dtype=np.float32) + kernel[ kernel_size//2, kernel_size//2] = 1.0 + box_filter = np.ones( (kernel_size, kernel_size), dtype=np.float32) / (kernel_size**2) + kernel = kernel + (kernel - box_filter) * amount + return cv2.filter2D(img, -1, kernel) + elif sharpen_mode == 2: #gaussian + blur = cv2.GaussianBlur(img, (kernel_size, kernel_size) , 0) + img = cv2.addWeighted(img, 1.0 + (0.5 * amount), blur, -(0.5 * amount), 0) + return img + elif amount < 0: + n = -amount + while n > 0: + + img_blur = cv2.medianBlur(img, 5) + if int(n / 10) != 0: + img = img_blur + else: + pass_power = (n % 10) / 10.0 + img = img*(1.0-pass_power)+img_blur*pass_power + n = max(n-10,0) + + return img + return img \ No newline at end of file diff --git a/core/imagelib/color_transfer.py b/core/imagelib/color_transfer.py new file mode 100644 index 0000000000000000000000000000000000000000..d269de2d36ed597873aabe43d36927a6a45c95d3 --- /dev/null +++ b/core/imagelib/color_transfer.py @@ -0,0 +1,336 @@ +import cv2 +import numexpr as ne +import numpy as np +import scipy as sp +from numpy import linalg as npla + + +def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0): + """ + Color Transform via Sliced Optimal Transfer + ported by @iperov from https://github.com/dcoeurjo/OTColorTransfer + + src - any float range any channel image + dst - any float range any channel image, same shape as src + steps - number of solver steps + batch_size - solver batch size + reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 + reg_sigmaV - sigmaV of filter + + return value - clip it manually + """ + if not np.issubdtype(src.dtype, np.floating): + raise ValueError("src value must be float") + if not np.issubdtype(trg.dtype, np.floating): + raise ValueError("trg value must be float") + + if len(src.shape) != 3: + raise ValueError("src shape must have rank 3 (h,w,c)") + + if src.shape != trg.shape: + raise ValueError("src and trg shapes must be equal") + + src_dtype = src.dtype + h,w,c = src.shape + new_src = src.copy() + + advect = np.empty ( (h*w,c), dtype=src_dtype ) + for step in range (steps): + advect.fill(0) + for batch in range (batch_size): + dir = np.random.normal(size=c).astype(src_dtype) + dir /= npla.norm(dir) + + projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w)) + projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w)) + + idSource = np.argsort (projsource) + idTarget = np.argsort (projtarget) + + a = projtarget[idTarget]-projsource[idSource] + for i_c in range(c): + advect[idSource,i_c] += a * dir[i_c] + new_src += advect.reshape( (h,w,c) ) / batch_size + + if reg_sigmaXY != 0.0: + src_diff = new_src-src + src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY ) + if len(src_diff_filt.shape) == 2: + src_diff_filt = src_diff_filt[...,None] + new_src = src + src_diff_filt + return new_src + +def color_transfer_mkl(x0, x1): + eps = np.finfo(float).eps + + h,w,c = x0.shape + h1,w1,c1 = x1.shape + + x0 = x0.reshape ( (h*w,c) ) + x1 = x1.reshape ( (h1*w1,c1) ) + + a = np.cov(x0.T) + b = np.cov(x1.T) + + Da2, Ua = np.linalg.eig(a) + Da = np.diag(np.sqrt(Da2.clip(eps, None))) + + C = np.dot(np.dot(np.dot(np.dot(Da, Ua.T), b), Ua), Da) + + Dc2, Uc = np.linalg.eig(C) + Dc = np.diag(np.sqrt(Dc2.clip(eps, None))) + + Da_inv = np.diag(1./(np.diag(Da))) + + t = np.dot(np.dot(np.dot(np.dot(np.dot(np.dot(Ua, Da_inv), Uc), Dc), Uc.T), Da_inv), Ua.T) + + mx0 = np.mean(x0, axis=0) + mx1 = np.mean(x1, axis=0) + + result = np.dot(x0-mx0, t) + mx1 + return np.clip ( result.reshape ( (h,w,c) ).astype(x0.dtype), 0, 1) + +def color_transfer_idt(i0, i1, bins=256, n_rot=20): + import scipy.stats + + relaxation = 1 / n_rot + h,w,c = i0.shape + h1,w1,c1 = i1.shape + + i0 = i0.reshape ( (h*w,c) ) + i1 = i1.reshape ( (h1*w1,c1) ) + + n_dims = c + + d0 = i0.T + d1 = i1.T + + for i in range(n_rot): + + r = sp.stats.special_ortho_group.rvs(n_dims).astype(np.float32) + + d0r = np.dot(r, d0) + d1r = np.dot(r, d1) + d_r = np.empty_like(d0) + + for j in range(n_dims): + + lo = min(d0r[j].min(), d1r[j].min()) + hi = max(d0r[j].max(), d1r[j].max()) + + p0r, edges = np.histogram(d0r[j], bins=bins, range=[lo, hi]) + p1r, _ = np.histogram(d1r[j], bins=bins, range=[lo, hi]) + + cp0r = p0r.cumsum().astype(np.float32) + cp0r /= cp0r[-1] + + cp1r = p1r.cumsum().astype(np.float32) + cp1r /= cp1r[-1] + + f = np.interp(cp0r, cp1r, edges[1:]) + + d_r[j] = np.interp(d0r[j], edges[1:], f, left=0, right=bins) + + d0 = relaxation * np.linalg.solve(r, (d_r - d0r)) + d0 + + return np.clip ( d0.T.reshape ( (h,w,c) ).astype(i0.dtype) , 0, 1) + +def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray: + """ + Transfer color using rct method. + + target np.ndarray H W 3C (BGR) np.float32 + source np.ndarray H W 3C (BGR) np.float32 + + target_mask(None) np.ndarray H W 1C np.float32 + source_mask(None) np.ndarray H W 1C np.float32 + + mask_cutoff(0.5) float + + masks are used to limit the space where color statistics will be computed to adjust the target + + reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf + """ + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB) + + source_input = source + if source_mask is not None: + source_input = source_input.copy() + source_input[source_mask[...,0] < mask_cutoff] = [0,0,0] + + target_input = target + if target_mask is not None: + target_input = target_input.copy() + target_input[target_mask[...,0] < mask_cutoff] = [0,0,0] + + target_l_mean, target_l_std, target_a_mean, target_a_std, target_b_mean, target_b_std, \ + = target_input[...,0].mean(), target_input[...,0].std(), target_input[...,1].mean(), target_input[...,1].std(), target_input[...,2].mean(), target_input[...,2].std() + + source_l_mean, source_l_std, source_a_mean, source_a_std, source_b_mean, source_b_std, \ + = source_input[...,0].mean(), source_input[...,0].std(), source_input[...,1].mean(), source_input[...,1].std(), source_input[...,2].mean(), source_input[...,2].std() + + # not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor + target_l = target[...,0] + target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean') + + target_a = target[...,1] + target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean') + + target_b = target[...,2] + target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean') + + np.clip(target_l, 0, 100, out=target_l) + np.clip(target_a, -127, 127, out=target_a) + np.clip(target_b, -127, 127, out=target_b) + + return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR) + + +def linear_color_transfer(target_img, source_img, mode='pca', eps=1e-5): + ''' + Matches the colour distribution of the target image to that of the source image + using a linear transform. + Images are expected to be of form (w,h,c) and float in [0,1]. + Modes are chol, pca or sym for different choices of basis. + ''' + mu_t = target_img.mean(0).mean(0) + t = target_img - mu_t + t = t.transpose(2,0,1).reshape( t.shape[-1],-1) + Ct = t.dot(t.T) / t.shape[1] + eps * np.eye(t.shape[0]) + mu_s = source_img.mean(0).mean(0) + s = source_img - mu_s + s = s.transpose(2,0,1).reshape( s.shape[-1],-1) + Cs = s.dot(s.T) / s.shape[1] + eps * np.eye(s.shape[0]) + if mode == 'chol': + chol_t = np.linalg.cholesky(Ct) + chol_s = np.linalg.cholesky(Cs) + ts = chol_s.dot(np.linalg.inv(chol_t)).dot(t) + if mode == 'pca': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + eva_s, eve_s = np.linalg.eigh(Cs) + Qs = eve_s.dot(np.sqrt(np.diag(eva_s))).dot(eve_s.T) + ts = Qs.dot(np.linalg.inv(Qt)).dot(t) + if mode == 'sym': + eva_t, eve_t = np.linalg.eigh(Ct) + Qt = eve_t.dot(np.sqrt(np.diag(eva_t))).dot(eve_t.T) + Qt_Cs_Qt = Qt.dot(Cs).dot(Qt) + eva_QtCsQt, eve_QtCsQt = np.linalg.eigh(Qt_Cs_Qt) + QtCsQt = eve_QtCsQt.dot(np.sqrt(np.diag(eva_QtCsQt))).dot(eve_QtCsQt.T) + ts = np.linalg.inv(Qt).dot(QtCsQt).dot(np.linalg.inv(Qt)).dot(t) + matched_img = ts.reshape(*target_img.transpose(2,0,1).shape).transpose(1,2,0) + matched_img += mu_s + matched_img[matched_img>1] = 1 + matched_img[matched_img<0] = 0 + return np.clip(matched_img.astype(source_img.dtype), 0, 1) + +def lab_image_stats(image): + # compute the mean and standard deviation of each channel + (l, a, b) = cv2.split(image) + (lMean, lStd) = (l.mean(), l.std()) + (aMean, aStd) = (a.mean(), a.std()) + (bMean, bStd) = (b.mean(), b.std()) + + # return the color statistics + return (lMean, lStd, aMean, aStd, bMean, bStd) + +def _scale_array(arr, clip=True): + if clip: + return np.clip(arr, 0, 255) + + mn = arr.min() + mx = arr.max() + scale_range = (max([mn, 0]), min([mx, 255])) + + if mn < scale_range[0] or mx > scale_range[1]: + return (scale_range[1] - scale_range[0]) * (arr - mn) / (mx - mn) + scale_range[0] + + return arr + +def channel_hist_match(source, template, hist_match_threshold=255, mask=None): + # Code borrowed from: + # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x + masked_source = source + masked_template = template + + if mask is not None: + masked_source = source * mask + masked_template = template * mask + + oldshape = source.shape + source = source.ravel() + template = template.ravel() + masked_source = masked_source.ravel() + masked_template = masked_template.ravel() + s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, + return_counts=True) + t_values, t_counts = np.unique(template, return_counts=True) + + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1] + t_quantiles = np.cumsum(t_counts).astype(np.float64) + t_quantiles = 255 * t_quantiles / t_quantiles[-1] + interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) + + return interp_t_values[bin_idx].reshape(oldshape) + +def color_hist_match(src_im, tar_im, hist_match_threshold=255): + h,w,c = src_im.shape + matched_R = channel_hist_match(src_im[:,:,0], tar_im[:,:,0], hist_match_threshold, None) + matched_G = channel_hist_match(src_im[:,:,1], tar_im[:,:,1], hist_match_threshold, None) + matched_B = channel_hist_match(src_im[:,:,2], tar_im[:,:,2], hist_match_threshold, None) + + to_stack = (matched_R, matched_G, matched_B) + for i in range(3, c): + to_stack += ( src_im[:,:,i],) + + + matched = np.stack(to_stack, axis=-1).astype(src_im.dtype) + return matched + +def color_transfer_mix(img_src,img_trg): + img_src = np.clip(img_src*255.0, 0, 255).astype(np.uint8) + img_trg = np.clip(img_trg*255.0, 0, 255).astype(np.uint8) + + img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB) + img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2LAB) + + rct_light = np.clip ( linear_color_transfer(img_src_lab[...,0:1].astype(np.float32)/255.0, + img_trg_lab[...,0:1].astype(np.float32)/255.0 )[...,0]*255.0, + 0, 255).astype(np.uint8) + + img_src_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_src_lab = cv2.cvtColor(img_src_lab, cv2.COLOR_LAB2BGR) + + img_trg_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_trg_lab = cv2.cvtColor(img_trg_lab, cv2.COLOR_LAB2BGR) + + img_rct = color_transfer_sot( img_src_lab.astype(np.float32), img_trg_lab.astype(np.float32) ) + img_rct = np.clip(img_rct, 0, 255).astype(np.uint8) + + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_BGR2LAB) + img_rct[...,0] = rct_light + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_LAB2BGR) + + + return (img_rct / 255.0).astype(np.float32) + +def color_transfer(ct_mode, img_src, img_trg): + """ + color transfer for [0,1] float32 inputs + """ + if ct_mode == 'lct': + out = linear_color_transfer (img_src, img_trg) + elif ct_mode == 'rct': + out = reinhard_color_transfer(img_src, img_trg) + elif ct_mode == 'mkl': + out = color_transfer_mkl (img_src, img_trg) + elif ct_mode == 'idt': + out = color_transfer_idt (img_src, img_trg) + elif ct_mode == 'sot': + out = color_transfer_sot (img_src, img_trg) + out = np.clip( out, 0.0, 1.0) + else: + raise ValueError(f"unknown ct_mode {ct_mode}") + return out diff --git a/core/imagelib/common.py b/core/imagelib/common.py new file mode 100644 index 0000000000000000000000000000000000000000..4219d7d37753af3457992bf19e3e8fdae7ff0283 --- /dev/null +++ b/core/imagelib/common.py @@ -0,0 +1,58 @@ +import numpy as np + +def random_crop(img, w, h): + height, width = img.shape[:2] + + h_rnd = height - h + w_rnd = width - w + + y = np.random.randint(0, h_rnd) if h_rnd > 0 else 0 + x = np.random.randint(0, w_rnd) if w_rnd > 0 else 0 + + return img[y:y+height, x:x+width] + +def normalize_channels(img, target_channels): + img_shape_len = len(img.shape) + if img_shape_len == 2: + h, w = img.shape + c = 0 + elif img_shape_len == 3: + h, w, c = img.shape + else: + raise ValueError("normalize: incorrect image dimensions.") + + if c == 0 and target_channels > 0: + img = img[...,np.newaxis] + c = 1 + + if c == 1 and target_channels > 1: + img = np.repeat (img, target_channels, -1) + c = target_channels + + if c > target_channels: + img = img[...,0:target_channels] + c = target_channels + + return img + +def cut_odd_image(img): + h, w, c = img.shape + wm, hm = w % 2, h % 2 + if wm + hm != 0: + img = img[0:h-hm,0:w-wm,:] + return img + +def overlay_alpha_image(img_target, img_source, xy_offset=(0,0) ): + (h,w,c) = img_source.shape + if c != 4: + raise ValueError("overlay_alpha_image, img_source must have 4 channels") + + x1, x2 = xy_offset[0], xy_offset[0] + w + y1, y2 = xy_offset[1], xy_offset[1] + h + + alpha_s = img_source[:, :, 3] / 255.0 + alpha_l = 1.0 - alpha_s + + for c in range(0, 3): + img_target[y1:y2, x1:x2, c] = (alpha_s * img_source[:, :, c] + + alpha_l * img_target[y1:y2, x1:x2, c]) \ No newline at end of file diff --git a/core/imagelib/draw.py b/core/imagelib/draw.py new file mode 100644 index 0000000000000000000000000000000000000000..3de1191735bc8135c06178afde810f50d95da077 --- /dev/null +++ b/core/imagelib/draw.py @@ -0,0 +1,13 @@ +import numpy as np +import cv2 + +def draw_polygon (image, points, color, thickness = 1): + points_len = len(points) + for i in range (0, points_len): + p0 = tuple( points[i] ) + p1 = tuple( points[ (i+1) % points_len] ) + cv2.line (image, p0, p1, color, thickness=thickness) + +def draw_rect(image, rect, color, thickness=1): + l,t,r,b = rect + draw_polygon (image, [ (l,t), (r,t), (r,b), (l,b ) ], color, thickness) diff --git a/core/imagelib/equalize_and_stack_square.py b/core/imagelib/equalize_and_stack_square.py new file mode 100644 index 0000000000000000000000000000000000000000..31c435a0714c0525fa6dc6bb84a685fc8102396d --- /dev/null +++ b/core/imagelib/equalize_and_stack_square.py @@ -0,0 +1,45 @@ +import numpy as np +import cv2 + +def equalize_and_stack_square (images, axis=1): + max_c = max ([ 1 if len(image.shape) == 2 else image.shape[2] for image in images ] ) + + target_wh = 99999 + for i,image in enumerate(images): + if len(image.shape) == 2: + h,w = image.shape + c = 1 + else: + h,w,c = image.shape + + if h < target_wh: + target_wh = h + + if w < target_wh: + target_wh = w + + for i,image in enumerate(images): + if len(image.shape) == 2: + h,w = image.shape + c = 1 + else: + h,w,c = image.shape + + if c < max_c: + if c == 1: + if len(image.shape) == 2: + image = np.expand_dims ( image, -1 ) + image = np.concatenate ( (image,)*max_c, -1 ) + elif c == 2: #GA + image = np.expand_dims ( image[...,0], -1 ) + image = np.concatenate ( (image,)*max_c, -1 ) + else: + image = np.concatenate ( (image, np.ones((h,w,max_c - c))), -1 ) + + if h != target_wh or w != target_wh: + image = cv2.resize ( image, (target_wh, target_wh) ) + h,w,c = image.shape + + images[i] = image + + return np.concatenate ( images, axis = 1 ) \ No newline at end of file diff --git a/core/imagelib/estimate_sharpness.py b/core/imagelib/estimate_sharpness.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b3e2dce92cc55cf7bccea633f548db0557d40d --- /dev/null +++ b/core/imagelib/estimate_sharpness.py @@ -0,0 +1,278 @@ +""" +Copyright (c) 2009-2010 Arizona Board of Regents. All Rights Reserved. + Contact: Lina Karam (karam@asu.edu) and Niranjan Narvekar (nnarveka@asu.edu) + Image, Video, and Usabilty (IVU) Lab, http://ivulab.asu.edu , Arizona State University + This copyright statement may not be removed from any file containing it or from modifications to these files. + This copyright notice must also be included in any file or product that is derived from the source files. + + Redistribution and use of this code in source and binary forms, with or without modification, are permitted provided that the + following conditions are met: + - Redistribution's of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + - Redistribution's in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the distribution. + - The Image, Video, and Usability Laboratory (IVU Lab, http://ivulab.asu.edu) is acknowledged in any publication that + reports research results using this code, copies of this code, or modifications of this code. + The code and our papers are to be cited in the bibliography as: + +N. D. Narvekar and L. J. Karam, "CPBD Sharpness Metric Software", http://ivulab.asu.edu/Quality/CPBD + +N. D. Narvekar and L. J. Karam, "A No-Reference Image Blur Metric Based on the Cumulative +Probability of Blur Detection (CPBD)," accepted and to appear in the IEEE Transactions on Image Processing, 2011. + +N. D. Narvekar and L. J. Karam, "An Improved No-Reference Sharpness Metric Based on the Probability of Blur Detection," International Workshop on Video Processing and Quality Metrics for Consumer Electronics (VPQM), January 2010, http://www.vpqm.org (pdf) + +N. D. Narvekar and L. J. Karam, "A No Reference Perceptual Quality Metric based on Cumulative Probability of Blur Detection," First International Workshop on the Quality of Multimedia Experience (QoMEX), pp. 87-91, July 2009. + + DISCLAIMER: + This software is provided by the copyright holders and contributors "as is" and any express or implied warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. In no event shall the Arizona Board of Regents, Arizona State University, IVU Lab members, authors or contributors be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute +goods or services; loss of use, data, or profits; or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this software, even if advised of the possibility of such damage. +""" + +import numpy as np +import cv2 +from math import atan2, pi + + +def sobel(image): + # type: (numpy.ndarray) -> numpy.ndarray + """ + Find edges using the Sobel approximation to the derivatives. + + Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l196). + """ + from skimage.filters.edges import HSOBEL_WEIGHTS + h1 = np.array(HSOBEL_WEIGHTS) + h1 /= np.sum(abs(h1)) # normalize h1 + + from scipy.ndimage import convolve + strength2 = np.square(convolve(image, h1.T)) + + # Note: https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l59 + thresh2 = 2 * np.sqrt(np.mean(strength2)) + + strength2[strength2 <= thresh2] = 0 + return _simple_thinning(strength2) + + +def _simple_thinning(strength): + # type: (numpy.ndarray) -> numpy.ndarray + """ + Perform a very simple thinning. + + Inspired by the [Octave implementation](https://sourceforge.net/p/octave/image/ci/default/tree/inst/edge.m#l512). + """ + num_rows, num_cols = strength.shape + + zero_column = np.zeros((num_rows, 1)) + zero_row = np.zeros((1, num_cols)) + + x = ( + (strength > np.c_[zero_column, strength[:, :-1]]) & + (strength > np.c_[strength[:, 1:], zero_column]) + ) + + y = ( + (strength > np.r_[zero_row, strength[:-1, :]]) & + (strength > np.r_[strength[1:, :], zero_row]) + ) + + return x | y + + + + + +# threshold to characterize blocks as edge/non-edge blocks +THRESHOLD = 0.002 +# fitting parameter +BETA = 3.6 +# block size +BLOCK_HEIGHT, BLOCK_WIDTH = (64, 64) +# just noticeable widths based on the perceptual experiments +WIDTH_JNB = np.concatenate([5*np.ones(51), 3*np.ones(205)]) + + +def compute(image): + # type: (numpy.ndarray) -> float + """Compute the sharpness metric for the given data.""" + + # convert the image to double for further processing + image = image.astype(np.float64) + + # edge detection using canny and sobel canny edge detection is done to + # classify the blocks as edge or non-edge blocks and sobel edge + # detection is done for the purpose of edge width measurement. + from skimage.feature import canny + canny_edges = canny(image) + sobel_edges = sobel(image) + + # edge width calculation + marziliano_widths = marziliano_method(sobel_edges, image) + + # sharpness metric calculation + return _calculate_sharpness_metric(image, canny_edges, marziliano_widths) + + +def marziliano_method(edges, image): + # type: (numpy.ndarray, numpy.ndarray) -> numpy.ndarray + """ + Calculate the widths of the given edges. + + :return: A matrix with the same dimensions as the given image with 0's at + non-edge locations and edge-widths at the edge locations. + """ + + # `edge_widths` consists of zero and non-zero values. A zero value + # indicates that there is no edge at that position and a non-zero value + # indicates that there is an edge at that position and the value itself + # gives the edge width. + edge_widths = np.zeros(image.shape) + + # find the gradient for the image + gradient_y, gradient_x = np.gradient(image) + + # dimensions of the image + img_height, img_width = image.shape + + # holds the angle information of the edges + edge_angles = np.zeros(image.shape) + + # calculate the angle of the edges + for row in range(img_height): + for col in range(img_width): + if gradient_x[row, col] != 0: + edge_angles[row, col] = atan2(gradient_y[row, col], gradient_x[row, col]) * (180 / pi) + elif gradient_x[row, col] == 0 and gradient_y[row, col] == 0: + edge_angles[row,col] = 0 + elif gradient_x[row, col] == 0 and gradient_y[row, col] == pi/2: + edge_angles[row, col] = 90 + + + if np.any(edge_angles): + + # quantize the angle + quantized_angles = 45 * np.round(edge_angles / 45) + + for row in range(1, img_height - 1): + for col in range(1, img_width - 1): + if edges[row, col] == 1: + + # gradient angle = 180 or -180 + if quantized_angles[row, col] == 180 or quantized_angles[row, col] == -180: + for margin in range(100 + 1): + inner_border = (col - 1) - margin + outer_border = (col - 2) - margin + + # outside image or intensity increasing from left to right + if outer_border < 0 or (image[row, outer_border] - image[row, inner_border]) <= 0: + break + + width_left = margin + 1 + + for margin in range(100 + 1): + inner_border = (col + 1) + margin + outer_border = (col + 2) + margin + + # outside image or intensity increasing from left to right + if outer_border >= img_width or (image[row, outer_border] - image[row, inner_border]) >= 0: + break + + width_right = margin + 1 + + edge_widths[row, col] = width_left + width_right + + + # gradient angle = 0 + if quantized_angles[row, col] == 0: + for margin in range(100 + 1): + inner_border = (col - 1) - margin + outer_border = (col - 2) - margin + + # outside image or intensity decreasing from left to right + if outer_border < 0 or (image[row, outer_border] - image[row, inner_border]) >= 0: + break + + width_left = margin + 1 + + for margin in range(100 + 1): + inner_border = (col + 1) + margin + outer_border = (col + 2) + margin + + # outside image or intensity decreasing from left to right + if outer_border >= img_width or (image[row, outer_border] - image[row, inner_border]) <= 0: + break + + width_right = margin + 1 + + edge_widths[row, col] = width_right + width_left + + return edge_widths + + +def _calculate_sharpness_metric(image, edges, edge_widths): + # type: (numpy.array, numpy.array, numpy.array) -> numpy.float64 + + # get the size of image + img_height, img_width = image.shape + + total_num_edges = 0 + hist_pblur = np.zeros(101) + + # maximum block indices + num_blocks_vertically = int(img_height / BLOCK_HEIGHT) + num_blocks_horizontally = int(img_width / BLOCK_WIDTH) + + # loop over the blocks + for i in range(num_blocks_vertically): + for j in range(num_blocks_horizontally): + + # get the row and col indices for the block pixel positions + rows = slice(BLOCK_HEIGHT * i, BLOCK_HEIGHT * (i + 1)) + cols = slice(BLOCK_WIDTH * j, BLOCK_WIDTH * (j + 1)) + + if is_edge_block(edges[rows, cols], THRESHOLD): + block_widths = edge_widths[rows, cols] + # rotate block to simulate column-major boolean indexing + block_widths = np.rot90(np.flipud(block_widths), 3) + block_widths = block_widths[block_widths != 0] + + block_contrast = get_block_contrast(image[rows, cols]) + block_jnb = WIDTH_JNB[block_contrast] + + # calculate the probability of blur detection at the edges + # detected in the block + prob_blur_detection = 1 - np.exp(-abs(block_widths/block_jnb) ** BETA) + + # update the statistics using the block information + for probability in prob_blur_detection: + bucket = int(round(probability * 100)) + hist_pblur[bucket] += 1 + total_num_edges += 1 + + # normalize the pdf + if total_num_edges > 0: + hist_pblur = hist_pblur / total_num_edges + + # calculate the sharpness metric + return np.sum(hist_pblur[:64]) + + +def is_edge_block(block, threshold): + # type: (numpy.ndarray, float) -> bool + """Decide whether the given block is an edge block.""" + return np.count_nonzero(block) > (block.size * threshold) + + +def get_block_contrast(block): + # type: (numpy.ndarray) -> int + return int(np.max(block) - np.min(block)) + + +def estimate_sharpness(image): + if image.ndim == 3: + if image.shape[2] > 1: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + else: + image = image[...,0] + + return compute(image) diff --git a/core/imagelib/filters.py b/core/imagelib/filters.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6957630bf478a39e0467d734a5fe1ec90ad5ed --- /dev/null +++ b/core/imagelib/filters.py @@ -0,0 +1,245 @@ +import numpy as np +from .blursharpen import LinearMotionBlur, blursharpen +import cv2 + +def apply_random_rgb_levels(img, mask=None, rnd_state=None): + if rnd_state is None: + rnd_state = np.random + np_rnd = rnd_state.rand + + inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) + inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) + inGamma = np.array([0.5+np_rnd(), 0.5+np_rnd(), 0.5+np_rnd()], dtype=np.float32) + + outBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) + outWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) + + result = np.clip( (img - inBlack) / (inWhite - inBlack), 0, 1 ) + result = ( result ** (1/inGamma) ) * (outWhite - outBlack) + outBlack + result = np.clip(result, 0, 1) + + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_hsv_shift(img, mask=None, rnd_state=None): + if rnd_state is None: + rnd_state = np.random + + h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) + h = ( h + rnd_state.randint(360) ) % 360 + s = np.clip ( s + rnd_state.random()-0.5, 0, 1 ) + v = np.clip ( v + rnd_state.random()-0.5, 0, 1 ) + + result = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 ) + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_sharpen( img, chance, kernel_max_size, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + sharp_rnd_kernel = rnd_state.randint(kernel_max_size)+1 + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + if rnd_state.randint(2) == 0: + result = blursharpen(result, 1, sharp_rnd_kernel, rnd_state.randint(10) ) + else: + result = blursharpen(result, 2, sharp_rnd_kernel, rnd_state.randint(50) ) + + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_motion_blur( img, chance, mb_max_size, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + mblur_rnd_kernel = rnd_state.randint(mb_max_size)+1 + mblur_rnd_deg = rnd_state.randint(360) + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + result = LinearMotionBlur (result, mblur_rnd_kernel, mblur_rnd_deg ) + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_gaussian_blur( img, chance, kernel_max_size, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + gblur_rnd_kernel = rnd_state.randint(kernel_max_size)*2+1 + result = cv2.GaussianBlur(result, (gblur_rnd_kernel,)*2 , 0) + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + h,w,c = result.shape + + trg = rnd_state.rand() + rw = w - int( trg * int(w*(max_size_per/100.0)) ) + rh = h - int( trg * int(h*(max_size_per/100.0)) ) + + result = cv2.resize (result, (rw,rh), interpolation=interpolation ) + result = cv2.resize (result, (w,h), interpolation=interpolation ) + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_nearest_resize( img, chance, max_size_per, mask=None, rnd_state=None ): + return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_NEAREST, mask=mask, rnd_state=rnd_state ) + +def apply_random_bilinear_resize( img, chance, max_size_per, mask=None, rnd_state=None ): + return apply_random_resize( img, chance, max_size_per, interpolation=cv2.INTER_LINEAR, mask=mask, rnd_state=rnd_state ) + +def apply_random_jpeg_compress( img, chance, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + result = img + if rnd_state.randint(100) < np.clip(chance, 0, 100): + h,w,c = result.shape + + quality = rnd_state.randint(10,101) + + ret, result = cv2.imencode('.jpg', np.clip(img*255, 0,255).astype(np.uint8), [int(cv2.IMWRITE_JPEG_QUALITY), quality] ) + if ret == True: + result = cv2.imdecode(result, flags=cv2.IMREAD_UNCHANGED) + result = result.astype(np.float32) / 255.0 + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def apply_random_overlay_triangle( img, max_alpha, mask=None, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + h,w,c = img.shape + pt1 = [rnd_state.randint(w), rnd_state.randint(h) ] + pt2 = [rnd_state.randint(w), rnd_state.randint(h) ] + pt3 = [rnd_state.randint(w), rnd_state.randint(h) ] + + alpha = rnd_state.uniform()*max_alpha + + tri_mask = cv2.fillPoly( np.zeros_like(img), [ np.array([pt1,pt2,pt3], np.int32) ], (alpha,)*c ) + + if rnd_state.randint(2) == 0: + result = np.clip(img+tri_mask, 0, 1) + else: + result = np.clip(img-tri_mask, 0, 1) + + if mask is not None: + result = img*(1-mask) + result*mask + + return result + +def _min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = min(s1, s0) + raw_max = min(x.shape[0], x.shape[1]) + return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4) + +def _d_resize(x, d, fac=1.0): + new_min = min(int(d[1] * fac), int(d[0] * fac)) + raw_min = min(x.shape[0], x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation) + return y + +def _get_image_gradient(dist): + cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]])) + rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]])) + return cols, rows + +def _generate_lighting_effects(content): + h512 = content + h256 = cv2.pyrDown(h512) + h128 = cv2.pyrDown(h256) + h64 = cv2.pyrDown(h128) + h32 = cv2.pyrDown(h64) + h16 = cv2.pyrDown(h32) + c512, r512 = _get_image_gradient(h512) + c256, r256 = _get_image_gradient(h256) + c128, r128 = _get_image_gradient(h128) + c64, r64 = _get_image_gradient(h64) + c32, r32 = _get_image_gradient(h32) + c16, r16 = _get_image_gradient(h16) + c = c16 + c = _d_resize(cv2.pyrUp(c), c32.shape) * 4.0 + c32 + c = _d_resize(cv2.pyrUp(c), c64.shape) * 4.0 + c64 + c = _d_resize(cv2.pyrUp(c), c128.shape) * 4.0 + c128 + c = _d_resize(cv2.pyrUp(c), c256.shape) * 4.0 + c256 + c = _d_resize(cv2.pyrUp(c), c512.shape) * 4.0 + c512 + r = r16 + r = _d_resize(cv2.pyrUp(r), r32.shape) * 4.0 + r32 + r = _d_resize(cv2.pyrUp(r), r64.shape) * 4.0 + r64 + r = _d_resize(cv2.pyrUp(r), r128.shape) * 4.0 + r128 + r = _d_resize(cv2.pyrUp(r), r256.shape) * 4.0 + r256 + r = _d_resize(cv2.pyrUp(r), r512.shape) * 4.0 + r512 + coarse_effect_cols = c + coarse_effect_rows = r + EPS = 1e-10 + + max_effect = np.max((coarse_effect_cols**2 + coarse_effect_rows**2)**0.5, axis=0, keepdims=True, ).max(1, keepdims=True) + coarse_effect_cols = (coarse_effect_cols + EPS) / (max_effect + EPS) + coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS) + + return np.stack([ np.zeros_like(coarse_effect_rows), coarse_effect_rows, coarse_effect_cols], axis=-1) + +def apply_random_relight(img, mask=None, rnd_state=None): + if rnd_state is None: + rnd_state = np.random + + def_img = img + + if rnd_state.randint(2) == 0: + light_pos_y = 1.0 if rnd_state.randint(2) == 0 else -1.0 + light_pos_x = rnd_state.uniform()*2-1.0 + else: + light_pos_y = rnd_state.uniform()*2-1.0 + light_pos_x = 1.0 if rnd_state.randint(2) == 0 else -1.0 + + light_source_height = 0.3*rnd_state.uniform()*0.7 + light_intensity = 1.0+rnd_state.uniform() + ambient_intensity = 0.5 + + light_source_location = np.array([[[light_source_height, light_pos_y, light_pos_x ]]], dtype=np.float32) + light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location))) + + lighting_effect = _generate_lighting_effects(img) + lighting_effect = np.sum(lighting_effect * light_source_direction, axis=-1).clip(0, 1) + lighting_effect = np.mean(lighting_effect, axis=-1, keepdims=True) + + result = def_img * (ambient_intensity + lighting_effect * light_intensity) #light_source_color + result = np.clip(result, 0, 1) + + if mask is not None: + result = def_img*(1-mask) + result*mask + + return result \ No newline at end of file diff --git a/core/imagelib/morph.py b/core/imagelib/morph.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa5114c006deaf2c7a2825ab6f5b951569e19e7 --- /dev/null +++ b/core/imagelib/morph.py @@ -0,0 +1,37 @@ +import numpy as np +import cv2 +from scipy.spatial import Delaunay + + +def applyAffineTransform(src, srcTri, dstTri, size) : + warpMat = cv2.getAffineTransform( np.float32(srcTri), np.float32(dstTri) ) + return cv2.warpAffine( src, warpMat, (size[0], size[1]), None, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101 ) + +def morphTriangle(dst_img, src_img, st, dt) : + (h,w,c) = dst_img.shape + sr = np.array( cv2.boundingRect(np.float32(st)) ) + dr = np.array( cv2.boundingRect(np.float32(dt)) ) + sRect = st - sr[0:2] + dRect = dt - dr[0:2] + d_mask = np.zeros((dr[3], dr[2], c), dtype = np.float32) + cv2.fillConvexPoly(d_mask, np.int32(dRect), (1.0,)*c, 8, 0); + imgRect = src_img[sr[1]:sr[1] + sr[3], sr[0]:sr[0] + sr[2]] + size = (dr[2], dr[3]) + warpImage1 = applyAffineTransform(imgRect, sRect, dRect, size) + + if c == 1: + warpImage1 = np.expand_dims( warpImage1, -1 ) + + dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]] = dst_img[dr[1]:dr[1]+dr[3], dr[0]:dr[0]+dr[2]]*(1-d_mask) + warpImage1 * d_mask + +def morph_by_points (image, sp, dp): + if sp.shape != dp.shape: + raise ValueError ('morph_by_points() sp.shape != dp.shape') + (h,w,c) = image.shape + + result_image = np.zeros(image.shape, dtype = image.dtype) + + for tri in Delaunay(dp).simplices: + morphTriangle(result_image, image, sp[tri], dp[tri]) + + return result_image \ No newline at end of file diff --git a/core/imagelib/reduce_colors.py b/core/imagelib/reduce_colors.py new file mode 100644 index 0000000000000000000000000000000000000000..961f00ddf07886227154034b97fccfabf08a205d --- /dev/null +++ b/core/imagelib/reduce_colors.py @@ -0,0 +1,14 @@ +import numpy as np +import cv2 +from PIL import Image + +#n_colors = [0..256] +def reduce_colors (img_bgr, n_colors): + img_rgb = (img_bgr[...,::-1] * 255.0).astype(np.uint8) + img_rgb_pil = Image.fromarray(img_rgb) + img_rgb_pil_p = img_rgb_pil.convert('P', palette=Image.ADAPTIVE, colors=n_colors) + + img_rgb_p = img_rgb_pil_p.convert('RGB') + img_bgr = cv2.cvtColor( np.array(img_rgb_p, dtype=np.float32) / 255.0, cv2.COLOR_RGB2BGR ) + + return img_bgr diff --git a/core/imagelib/sd/__init__.py b/core/imagelib/sd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1cddc19473acb301104519b5eacf01f7f9afa42b --- /dev/null +++ b/core/imagelib/sd/__init__.py @@ -0,0 +1,2 @@ +from .draw import circle_faded, random_circle_faded, bezier, random_bezier_split_faded, random_faded +from .calc import * \ No newline at end of file diff --git a/core/imagelib/sd/calc.py b/core/imagelib/sd/calc.py new file mode 100644 index 0000000000000000000000000000000000000000..2304e6645f8ab1522906c1f436f880370ee1e40a --- /dev/null +++ b/core/imagelib/sd/calc.py @@ -0,0 +1,25 @@ +import numpy as np +import numpy.linalg as npla + +def dist_to_edges(pts, pt, is_closed=False): + """ + returns array of dist from pt to edge and projection pt to edges + """ + if is_closed: + a = pts + b = np.concatenate( (pts[1:,:], pts[0:1,:]), axis=0 ) + else: + a = pts[:-1,:] + b = pts[1:,:] + + pa = pt-a + ba = b-a + + div = np.einsum('ij,ij->i', ba, ba) + div[div==0]=1 + h = np.clip( np.einsum('ij,ij->i', pa, ba) / div, 0, 1 ) + + x = npla.norm ( pa - ba*h[...,None], axis=1 ) + + return x, a+ba*h[...,None] + diff --git a/core/imagelib/sd/draw.py b/core/imagelib/sd/draw.py new file mode 100644 index 0000000000000000000000000000000000000000..711ad33ad1b6bae31fe66b1adcb7a8e808c446e4 --- /dev/null +++ b/core/imagelib/sd/draw.py @@ -0,0 +1,200 @@ +""" +Signed distance drawing functions using numpy. +""" +import math + +import numpy as np +from numpy import linalg as npla + + +def vector2_dot(a,b): + return a[...,0]*b[...,0]+a[...,1]*b[...,1] + +def vector2_dot2(a): + return a[...,0]*a[...,0]+a[...,1]*a[...,1] + +def vector2_cross(a,b): + return a[...,0]*b[...,1]-a[...,1]*b[...,0] + + +def circle_faded( wh, center, fade_dists ): + """ + returns drawn circle in [h,w,1] output range [0..1.0] float32 + + wh = [w,h] resolution + center = [x,y] center of circle + fade_dists = [fade_start, fade_end] fade values + """ + w,h = wh + + pts = np.empty( (h,w,2), dtype=np.float32 ) + pts[...,0] = np.arange(w)[:,None] + pts[...,1] = np.arange(h)[None,:] + + pts = pts.reshape ( (h*w, -1) ) + + pts_dists = np.abs ( npla.norm(pts-center, axis=-1) ) + + if fade_dists[1] == 0: + fade_dists[1] = 1 + + pts_dists = ( pts_dists - fade_dists[0] ) / fade_dists[1] + + pts_dists = np.clip( 1-pts_dists, 0, 1) + + return pts_dists.reshape ( (h,w,1) ).astype(np.float32) + + +def bezier( wh, A, B, C ): + """ + returns drawn bezier in [h,w,1] output range float32, + every pixel contains signed distance to bezier line + + wh [w,h] resolution + A,B,C points [x,y] + """ + + width,height = wh + + A = np.float32(A) + B = np.float32(B) + C = np.float32(C) + + + pos = np.empty( (height,width,2), dtype=np.float32 ) + pos[...,0] = np.arange(width)[:,None] + pos[...,1] = np.arange(height)[None,:] + + + a = B-A + b = A - 2.0*B + C + c = a * 2.0 + d = A - pos + + b_dot = vector2_dot(b,b) + if b_dot == 0.0: + return np.zeros( (height,width), dtype=np.float32 ) + + kk = 1.0 / b_dot + + kx = kk * vector2_dot(a,b) + ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0; + kz = kk * vector2_dot(d,a); + + res = 0.0; + sgn = 0.0; + + p = ky - kx*kx; + + p3 = p*p*p; + q = kx*(2.0*kx*kx - 3.0*ky) + kz; + h = q*q + 4.0*p3; + + hp_sel = h >= 0.0 + + hp_p = h[hp_sel] + hp_p = np.sqrt(hp_p) + + hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0 + hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] ) + hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 ) + + hp_t = hp_t[...,None] + hp_q = d[hp_sel]+(c+b*hp_t)*hp_t + hp_res = vector2_dot2(hp_q) + hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q) + + hl_sel = h < 0.0 + + hl_q = q[hl_sel] + hl_p = p[hl_sel] + hl_z = np.sqrt(-hl_p) + hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0 + + hl_m = np.cos(hl_v) + hl_n = np.sin(hl_v)*1.732050808; + + hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 ); + + hl_d = d[hl_sel] + + hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1] + + hl_dx = vector2_dot2(hl_qx) + hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx) + + hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2] + hl_dy = vector2_dot2(hl_qy) + hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy); + + hl_dx_l_dy = hl_dx=hl_dy + + hl_res = np.empty_like(hl_dx) + hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy] + hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy] + + hl_sgn = np.empty_like(hl_sx) + hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy] + hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy] + + res = np.empty( (height, width), np.float32 ) + res[hp_sel] = hp_res + res[hl_sel] = hl_res + + sgn = np.empty( (height, width), np.float32 ) + sgn[hp_sel] = hp_sgn + sgn[hl_sel] = hl_sgn + + sgn = np.sign(sgn) + res = np.sqrt(res)*sgn + + return res[...,None] + +def random_faded(wh): + """ + apply one of them: + random_circle_faded + random_bezier_split_faded + """ + rnd = np.random.randint(2) + if rnd == 0: + return random_circle_faded(wh) + elif rnd == 1: + return random_bezier_split_faded(wh) + +def random_circle_faded ( wh, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + + w,h = wh + wh_max = max(w,h) + fade_start = rnd_state.randint(wh_max) + fade_end = fade_start + rnd_state.randint(wh_max- fade_start) + + return circle_faded (wh, [ rnd_state.randint(h), rnd_state.randint(w) ], + [fade_start, fade_end] ) + +def random_bezier_split_faded( wh ): + width, height = wh + + degA = np.random.randint(360) + degB = np.random.randint(360) + degC = np.random.randint(360) + + deg_2_rad = math.pi / 180.0 + + center = np.float32([width / 2.0, height / 2.0]) + + radius = max(width, height) + + A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] ) + B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] ) + C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] ) + + x = bezier( (width,height), A, B, C ) + + x = x / (1+np.random.randint(radius)) + 0.5 + + x = np.clip(x, 0, 1) + return x diff --git a/core/imagelib/text.py b/core/imagelib/text.py new file mode 100644 index 0000000000000000000000000000000000000000..5bcf68eedaf05e9de0bb17850775ba0df7c56ea4 --- /dev/null +++ b/core/imagelib/text.py @@ -0,0 +1,64 @@ +import localization +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +pil_fonts = {} +def _get_pil_font (font, size): + global pil_fonts + try: + font_str_id = '%s_%d' % (font, size) + if font_str_id not in pil_fonts.keys(): + pil_fonts[font_str_id] = ImageFont.truetype(font + ".ttf", size=size, encoding="unic") + pil_font = pil_fonts[font_str_id] + return pil_font + except: + return ImageFont.load_default() + +def get_text_image( shape, text, color=(1,1,1), border=0.2, font=None): + h,w,c = shape + try: + pil_font = _get_pil_font( localization.get_default_ttf_font_name() , h-2) + + canvas = Image.new('RGB', (w,h) , (0,0,0) ) + draw = ImageDraw.Draw(canvas) + offset = ( 0, 0) + draw.text(offset, text, font=pil_font, fill=tuple((np.array(color)*255).astype(np.int)) ) + + result = np.asarray(canvas) / 255 + + if c > 3: + result = np.concatenate ( (result, np.ones ((h,w,c-3)) ), axis=-1 ) + elif c < 3: + result = result[...,0:c] + return result + except: + return np.zeros ( (h,w,c) ) + +def draw_text( image, rect, text, color=(1,1,1), border=0.2, font=None): + h,w,c = image.shape + + l,t,r,b = rect + l = np.clip (l, 0, w-1) + r = np.clip (r, 0, w-1) + t = np.clip (t, 0, h-1) + b = np.clip (b, 0, h-1) + + image[t:b, l:r] += get_text_image ( (b-t,r-l,c) , text, color, border, font ) + + +def draw_text_lines (image, rect, text_lines, color=(1,1,1), border=0.2, font=None): + text_lines_len = len(text_lines) + if text_lines_len == 0: + return + + l,t,r,b = rect + h = b-t + h_per_line = h // text_lines_len + + for i in range(0, text_lines_len): + draw_text (image, (l, i*h_per_line, r, (i+1)*h_per_line), text_lines[i], color, border, font) + +def get_draw_text_lines ( image, rect, text_lines, color=(1,1,1), border=0.2, font=None): + image = np.zeros ( image.shape, dtype=np.float ) + draw_text_lines ( image, rect, text_lines, color, border, font) + return image diff --git a/core/imagelib/warp.py b/core/imagelib/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..af77579b623cafb3ab240282123b7becf2447944 --- /dev/null +++ b/core/imagelib/warp.py @@ -0,0 +1,181 @@ +import numpy as np +import numpy.linalg as npla +import cv2 +from core import randomex + +def mls_rigid_deformation(vy, vx, src_pts, dst_pts, alpha=1.0, eps=1e-8): + dst_pts = dst_pts[..., ::-1].astype(np.int16) + src_pts = src_pts[..., ::-1].astype(np.int16) + + src_pts, dst_pts = dst_pts, src_pts + + grow = vx.shape[0] + gcol = vx.shape[1] + ctrls = src_pts.shape[0] + + reshaped_p = src_pts.reshape(ctrls, 2, 1, 1) + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) + + w = 1.0 / (np.sum((reshaped_p - reshaped_v).astype(np.float32) ** 2, axis=1) + eps) ** alpha + w /= np.sum(w, axis=0, keepdims=True) + + pstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + pstar += w[i] * reshaped_p[i] + + vpstar = reshaped_v - pstar + + reshaped_mul_right = np.concatenate((vpstar[:,None,...], + np.concatenate((vpstar[1:2,None,...],-vpstar[0:1,None,...]), 0) + ), axis=1).transpose(2, 3, 0, 1) + + reshaped_q = dst_pts.reshape((ctrls, 2, 1, 1)) + + qstar = np.zeros((2, grow, gcol), np.float32) + for i in range(ctrls): + qstar += w[i] * reshaped_q[i] + + temp = np.zeros((grow, gcol, 2), np.float32) + for i in range(ctrls): + phat = reshaped_p[i] - pstar + qhat = reshaped_q[i] - qstar + + temp += np.matmul(qhat.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1), + + np.matmul( ( w[None, i:i+1,...] * + np.concatenate((phat.reshape(1, 2, grow, gcol), + np.concatenate( (phat[None,1:2], -phat[None,0:1]), 1 )), 0) + ).transpose(2, 3, 0, 1), reshaped_mul_right + ) + ).reshape(grow, gcol, 2) + + temp = temp.transpose(2, 0, 1) + + normed_temp = np.linalg.norm(temp, axis=0, keepdims=True) + normed_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True) + nan_mask = normed_temp[0]==0 + + transformers = np.true_divide(temp, normed_temp, out=np.zeros_like(temp), where= ~nan_mask) * normed_vpstar + qstar + nan_mask_flat = np.flatnonzero(nan_mask) + nan_mask_anti_flat = np.flatnonzero(~nan_mask) + + transformers[0][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[0][~nan_mask]) + transformers[1][nan_mask] = np.interp(nan_mask_flat, nan_mask_anti_flat, transformers[1][~nan_mask]) + + return transformers + +def gen_pts(W, H, rnd_state=None): + + if rnd_state is None: + rnd_state = np.random + + min_pts, max_pts = 4, 8 + n_pts = rnd_state.randint(min_pts, max_pts) + + min_radius_per = 0.00 + max_radius_per = 0.10 + pts = [] + + for i in range(n_pts): + while True: + x, y = rnd_state.randint(W), rnd_state.randint(H) + rad = min_radius_per + rnd_state.rand()*(max_radius_per-min_radius_per) + + intersect = False + for px,py,prad,_,_ in pts: + + dist = npla.norm([x-px, y-py]) + if dist <= (rad+prad)*2: + intersect = True + break + if intersect: + continue + + angle = rnd_state.rand()*(2*np.pi) + x2 = int(x+np.cos(angle)*W*rad) + y2 = int(y+np.sin(angle)*H*rad) + + break + pts.append( (x,y,rad, x2,y2) ) + + pts1 = np.array( [ [pt[0],pt[1]] for pt in pts ] ) + pts2 = np.array( [ [pt[-2],pt[-1]] for pt in pts ] ) + + return pts1, pts2 + + +def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5, 0.5], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05], rnd_state=None, warp_rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + if warp_rnd_state is None: + warp_rnd_state = np.random + rw = None + if w < 64: + rw = w + w = 64 + + rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] ) + scale = rnd_state.uniform(1 +scale_range[0], 1 +scale_range[1]) + tx = rnd_state.uniform( tx_range[0], tx_range[1] ) + ty = rnd_state.uniform( ty_range[0], ty_range[1] ) + p_flip = flip and rnd_state.randint(10) < 4 + + #random warp V1 + cell_size = [ w // (2**i) for i in range(1,4) ] [ warp_rnd_state.randint(3) ] + cell_count = w // cell_size + 1 + grid_points = np.linspace( 0, w, cell_count) + mapx = np.broadcast_to(grid_points, (cell_count, cell_count)).copy() + mapy = mapx.T + mapx[1:-1,1:-1] = mapx[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24) + mapy[1:-1,1:-1] = mapy[1:-1,1:-1] + randomex.random_normal( size=(cell_count-2, cell_count-2), rnd_state=warp_rnd_state )*(cell_size*0.24) + half_cell_size = cell_size // 2 + mapx = cv2.resize(mapx, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) + mapy = cv2.resize(mapy, (w+cell_size,)*2 )[half_cell_size:-half_cell_size,half_cell_size:-half_cell_size].astype(np.float32) + ############## + + # random warp V2 + # pts1, pts2 = gen_pts(w, w, rnd_state) + # gridX = np.arange(w, dtype=np.int16) + # gridY = np.arange(w, dtype=np.int16) + # vy, vx = np.meshgrid(gridX, gridY) + # drigid = mls_rigid_deformation(vy, vx, pts1, pts2) + # mapy, mapx = drigid.astype(np.float32) + ################ + + #random transform + random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale) + random_transform_mat[:, 2] += (tx*w, ty*w) + + params = dict() + params['mapx'] = mapx + params['mapy'] = mapy + params['rmat'] = random_transform_mat + u_mat = random_transform_mat.copy() + u_mat[:,2] /= w + params['umat'] = u_mat + params['w'] = w + params['rw'] = rw + params['flip'] = p_flip + + return params + +def warp_by_params (params, img, can_warp, can_transform, can_flip, border_replicate, cv2_inter=cv2.INTER_CUBIC): + rw = params['rw'] + + if (can_warp or can_transform) and rw is not None: + img = cv2.resize(img, (64,64), interpolation=cv2_inter) + + if can_warp: + img = cv2.remap(img, params['mapx'], params['mapy'], cv2_inter ) + if can_transform: + img = cv2.warpAffine( img, params['rmat'], (params['w'], params['w']), borderMode=(cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT), flags=cv2_inter ) + + + if (can_warp or can_transform) and rw is not None: + img = cv2.resize(img, (rw,rw), interpolation=cv2_inter) + + if len(img.shape) == 2: + img = img[...,None] + if can_flip and params['flip']: + img = img[:,::-1,...] + return img \ No newline at end of file diff --git a/core/interact/__init__.py b/core/interact/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db40e4f25aa8b064e1755afaa30f67a0aa4619d5 --- /dev/null +++ b/core/interact/__init__.py @@ -0,0 +1 @@ +from .interact import interact diff --git a/core/interact/interact.py b/core/interact/interact.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8214a44190d71add6aa565dccded7becf7b79f --- /dev/null +++ b/core/interact/interact.py @@ -0,0 +1,581 @@ +import multiprocessing +import os +import sys +import threading +import time +import types + +import colorama +import cv2 +import numpy as np +from tqdm import tqdm + +from core import stdex + +try: + import IPython #if success we are in colab + from IPython.display import display, clear_output + import PIL + import matplotlib.pyplot as plt + is_colab = True +except: + is_colab = False + +yn_str = {True:'y',False:'n'} + +class InteractBase(object): + EVENT_LBUTTONDOWN = 1 + EVENT_LBUTTONUP = 2 + EVENT_MBUTTONDOWN = 3 + EVENT_MBUTTONUP = 4 + EVENT_RBUTTONDOWN = 5 + EVENT_RBUTTONUP = 6 + EVENT_MOUSEWHEEL = 10 + + def __init__(self): + self.named_windows = {} + self.capture_mouse_windows = {} + self.capture_keys_windows = {} + self.mouse_events = {} + self.key_events = {} + self.pg_bar = None + self.focus_wnd_name = None + self.error_log_line_prefix = '/!\\ ' + + self.process_messages_callbacks = {} + + def is_support_windows(self): + return False + + def is_colab(self): + return False + + def on_destroy_all_windows(self): + raise NotImplemented + + def on_create_window (self, wnd_name): + raise NotImplemented + + def on_destroy_window (self, wnd_name): + raise NotImplemented + + def on_show_image (self, wnd_name, img): + raise NotImplemented + + def on_capture_mouse (self, wnd_name): + raise NotImplemented + + def on_capture_keys (self, wnd_name): + raise NotImplemented + + def on_process_messages(self, sleep_time=0): + raise NotImplemented + + def on_wait_any_key(self): + raise NotImplemented + + def log_info(self, msg, end='\n'): + if self.pg_bar is not None: + print ("\n") + print (msg, end=end) + + def log_err(self, msg, end='\n'): + if self.pg_bar is not None: + print ("\n") + print (f'{self.error_log_line_prefix}{msg}', end=end) + + def named_window(self, wnd_name): + if wnd_name not in self.named_windows: + #we will show window only on first show_image + self.named_windows[wnd_name] = 0 + self.focus_wnd_name = wnd_name + else: print("named_window: ", wnd_name, " already created.") + + def destroy_all_windows(self): + if len( self.named_windows ) != 0: + self.on_destroy_all_windows() + self.named_windows = {} + self.capture_mouse_windows = {} + self.capture_keys_windows = {} + self.mouse_events = {} + self.key_events = {} + self.focus_wnd_name = None + + def destroy_window(self, wnd_name): + if wnd_name in self.named_windows: + self.on_destroy_window(wnd_name) + self.named_windows.pop(wnd_name) + + if wnd_name == self.focus_wnd_name: + self.focus_wnd_name = list(self.named_windows.keys())[-1] if len( self.named_windows ) != 0 else None + + if wnd_name in self.capture_mouse_windows: + self.capture_mouse_windows.pop(wnd_name) + + if wnd_name in self.capture_keys_windows: + self.capture_keys_windows.pop(wnd_name) + + if wnd_name in self.mouse_events: + self.mouse_events.pop(wnd_name) + + if wnd_name in self.key_events: + self.key_events.pop(wnd_name) + + def show_image(self, wnd_name, img): + if wnd_name in self.named_windows: + if self.named_windows[wnd_name] == 0: + self.named_windows[wnd_name] = 1 + self.on_create_window(wnd_name) + if wnd_name in self.capture_mouse_windows: + self.capture_mouse(wnd_name) + self.on_show_image(wnd_name,img) + else: print("show_image: named_window ", wnd_name, " not found.") + + def capture_mouse(self, wnd_name): + if wnd_name in self.named_windows: + self.capture_mouse_windows[wnd_name] = True + if self.named_windows[wnd_name] == 1: + self.on_capture_mouse(wnd_name) + else: print("capture_mouse: named_window ", wnd_name, " not found.") + + def capture_keys(self, wnd_name): + if wnd_name in self.named_windows: + if wnd_name not in self.capture_keys_windows: + self.capture_keys_windows[wnd_name] = True + self.on_capture_keys(wnd_name) + else: print("capture_keys: already set for window ", wnd_name) + else: print("capture_keys: named_window ", wnd_name, " not found.") + + def progress_bar(self, desc, total, leave=True, initial=0): + if self.pg_bar is None: + self.pg_bar = tqdm( total=total, desc=desc, leave=leave, ascii=True, initial=initial ) + else: print("progress_bar: already set.") + + def progress_bar_inc(self, c): + if self.pg_bar is not None: + self.pg_bar.n += c + self.pg_bar.refresh() + else: print("progress_bar not set.") + + def progress_bar_close(self): + if self.pg_bar is not None: + self.pg_bar.close() + self.pg_bar = None + else: print("progress_bar not set.") + + def progress_bar_generator(self, data, desc=None, leave=True, initial=0): + self.pg_bar = tqdm( data, desc=desc, leave=leave, ascii=True, initial=initial ) + for x in self.pg_bar: + yield x + self.pg_bar.close() + self.pg_bar = None + + def add_process_messages_callback(self, func ): + tid = threading.get_ident() + callbacks = self.process_messages_callbacks.get(tid, None) + if callbacks is None: + callbacks = [] + self.process_messages_callbacks[tid] = callbacks + + callbacks.append ( func ) + + def process_messages(self, sleep_time=0): + callbacks = self.process_messages_callbacks.get(threading.get_ident(), None) + if callbacks is not None: + for func in callbacks: + func() + + self.on_process_messages(sleep_time) + + def wait_any_key(self): + self.on_wait_any_key() + + def add_mouse_event(self, wnd_name, x, y, ev, flags): + if wnd_name not in self.mouse_events: + self.mouse_events[wnd_name] = [] + self.mouse_events[wnd_name] += [ (x, y, ev, flags) ] + + def add_key_event(self, wnd_name, ord_key, ctrl_pressed, alt_pressed, shift_pressed): + if wnd_name not in self.key_events: + self.key_events[wnd_name] = [] + self.key_events[wnd_name] += [ (ord_key, chr(ord_key) if ord_key <= 255 else chr(0), ctrl_pressed, alt_pressed, shift_pressed) ] + + def get_mouse_events(self, wnd_name): + ar = self.mouse_events.get(wnd_name, []) + self.mouse_events[wnd_name] = [] + return ar + + def get_key_events(self, wnd_name): + ar = self.key_events.get(wnd_name, []) + self.key_events[wnd_name] = [] + return ar + + def input(self, s): + return input(s) + + def input_number(self, s, default_value, valid_list=None, show_default_value=True, add_info=None, help_message=None): + if show_default_value and default_value is not None: + s = f"[{default_value}] {s}" + + if add_info is not None or \ + help_message is not None: + s += " (" + + if add_info is not None: + s += f" {add_info}" + if help_message is not None: + s += " ?:help" + + if add_info is not None or \ + help_message is not None: + s += " )" + + s += " : " + + while True: + try: + inp = input(s) + if len(inp) == 0: + result = default_value + break + + if help_message is not None and inp == '?': + print (help_message) + continue + + i = float(inp) + if (valid_list is not None) and (i not in valid_list): + result = default_value + break + result = i + break + except: + result = default_value + break + + print(result) + return result + + def input_int(self, s, default_value, valid_range=None, valid_list=None, add_info=None, show_default_value=True, help_message=None): + if show_default_value: + if len(s) != 0: + s = f"[{default_value}] {s}" + else: + s = f"[{default_value}]" + + if add_info is not None or \ + valid_range is not None or \ + help_message is not None: + s += " (" + + if valid_range is not None: + s += f" {valid_range[0]}-{valid_range[1]}" + + if add_info is not None: + s += f" {add_info}" + + if help_message is not None: + s += " ?:help" + + if add_info is not None or \ + valid_range is not None or \ + help_message is not None: + s += " )" + + s += " : " + + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + i = int(inp) + if valid_range is not None: + i = int(np.clip(i, valid_range[0], valid_range[1])) + + if (valid_list is not None) and (i not in valid_list): + i = default_value + + result = i + break + except: + result = default_value + break + print (result) + return result + + def input_bool(self, s, default_value, help_message=None): + s = f"[{yn_str[default_value]}] {s} ( y/n" + + if help_message is not None: + s += " ?:help" + s += " ) : " + + while True: + try: + inp = input(s) + if len(inp) == 0: + raise ValueError("") + + if help_message is not None and inp == '?': + print (help_message) + continue + + return bool ( {"y":True,"n":False}.get(inp.lower(), default_value) ) + except: + print ( "y" if default_value else "n" ) + return default_value + + def input_str(self, s, default_value=None, valid_list=None, show_default_value=True, help_message=None): + if show_default_value and default_value is not None: + s = f"[{default_value}] {s}" + + if valid_list is not None or \ + help_message is not None: + s += " (" + + if valid_list is not None: + s += " " + "/".join(valid_list) + + if help_message is not None: + s += " ?:help" + + if valid_list is not None or \ + help_message is not None: + s += " )" + + s += " : " + + + while True: + try: + inp = input(s) + + if len(inp) == 0: + if default_value is None: + print("") + return None + result = default_value + break + + if help_message is not None and inp == '?': + print(help_message) + continue + + if valid_list is not None: + if inp.lower() in valid_list: + result = inp.lower() + break + if inp in valid_list: + result = inp + break + continue + + result = inp + break + except: + result = default_value + break + + print(result) + return result + + def input_process(self, stdin_fd, sq, str): + sys.stdin = os.fdopen(stdin_fd) + try: + inp = input (str) + sq.put (True) + except: + sq.put (False) + + def input_in_time (self, str, max_time_sec): + sq = multiprocessing.Queue() + p = multiprocessing.Process(target=self.input_process, args=( sys.stdin.fileno(), sq, str)) + p.daemon = True + p.start() + t = time.time() + inp = False + while True: + if not sq.empty(): + inp = sq.get() + break + if time.time() - t > max_time_sec: + break + + + p.terminate() + p.join() + + old_stdin = sys.stdin + sys.stdin = os.fdopen( os.dup(sys.stdin.fileno()) ) + old_stdin.close() + return inp + + def input_process_skip_pending(self, stdin_fd): + sys.stdin = os.fdopen(stdin_fd) + while True: + try: + if sys.stdin.isatty(): + sys.stdin.read() + except: + pass + + def input_skip_pending(self): + if is_colab: + # currently it does not work on Colab + return + """ + skips unnecessary inputs between the dialogs + """ + p = multiprocessing.Process(target=self.input_process_skip_pending, args=( sys.stdin.fileno(), )) + p.daemon = True + p.start() + time.sleep(0.5) + p.terminate() + p.join() + sys.stdin = os.fdopen( sys.stdin.fileno() ) + + +class InteractDesktop(InteractBase): + def __init__(self): + colorama.init() + super().__init__() + + def color_red(self): + pass + + + def is_support_windows(self): + return True + + def on_destroy_all_windows(self): + cv2.destroyAllWindows() + + def on_create_window (self, wnd_name): + cv2.namedWindow(wnd_name) + + def on_destroy_window (self, wnd_name): + cv2.destroyWindow(wnd_name) + + def on_show_image (self, wnd_name, img): + cv2.imshow (wnd_name, img) + + def on_capture_mouse (self, wnd_name): + self.last_xy = (0,0) + + def onMouse(event, x, y, flags, param): + (inst, wnd_name) = param + if event == cv2.EVENT_LBUTTONDOWN: ev = InteractBase.EVENT_LBUTTONDOWN + elif event == cv2.EVENT_LBUTTONUP: ev = InteractBase.EVENT_LBUTTONUP + elif event == cv2.EVENT_RBUTTONDOWN: ev = InteractBase.EVENT_RBUTTONDOWN + elif event == cv2.EVENT_RBUTTONUP: ev = InteractBase.EVENT_RBUTTONUP + elif event == cv2.EVENT_MBUTTONDOWN: ev = InteractBase.EVENT_MBUTTONDOWN + elif event == cv2.EVENT_MBUTTONUP: ev = InteractBase.EVENT_MBUTTONUP + elif event == cv2.EVENT_MOUSEWHEEL: + ev = InteractBase.EVENT_MOUSEWHEEL + x,y = self.last_xy #fix opencv bug when window size more than screen size + else: ev = 0 + + self.last_xy = (x,y) + inst.add_mouse_event (wnd_name, x, y, ev, flags) + cv2.setMouseCallback(wnd_name, onMouse, (self,wnd_name) ) + + def on_capture_keys (self, wnd_name): + pass + + def on_process_messages(self, sleep_time=0): + + has_windows = False + has_capture_keys = False + + if len(self.named_windows) != 0: + has_windows = True + + if len(self.capture_keys_windows) != 0: + has_capture_keys = True + + if has_windows or has_capture_keys: + wait_key_time = max(1, int(sleep_time*1000) ) + ord_key = cv2.waitKeyEx(wait_key_time) + + shift_pressed = False + if ord_key != -1: + chr_key = chr(ord_key) if ord_key <= 255 else chr(0) + + if chr_key >= 'A' and chr_key <= 'Z': + shift_pressed = True + ord_key += 32 + elif chr_key == '?': + shift_pressed = True + ord_key = ord('/') + elif chr_key == '<': + shift_pressed = True + ord_key = ord(',') + elif chr_key == '>': + shift_pressed = True + ord_key = ord('.') + else: + if sleep_time != 0: + time.sleep(sleep_time) + + if has_capture_keys and ord_key != -1: + self.add_key_event ( self.focus_wnd_name, ord_key, False, False, shift_pressed) + + def on_wait_any_key(self): + cv2.waitKey(0) + +class InteractColab(InteractBase): + + def is_support_windows(self): + return False + + def is_colab(self): + return True + + def on_destroy_all_windows(self): + pass + #clear_output() + + def on_create_window (self, wnd_name): + pass + #clear_output() + + def on_destroy_window (self, wnd_name): + pass + + def on_show_image (self, wnd_name, img): + pass + # # cv2 stores colors as BGR; convert to RGB + # if img.ndim == 3: + # if img.shape[2] == 4: + # img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) + # else: + # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # img = PIL.Image.fromarray(img) + # plt.imshow(img) + # plt.show() + + def on_capture_mouse (self, wnd_name): + pass + #print("on_capture_mouse(): Colab does not support") + + def on_capture_keys (self, wnd_name): + pass + #print("on_capture_keys(): Colab does not support") + + def on_process_messages(self, sleep_time=0): + time.sleep(sleep_time) + + def on_wait_any_key(self): + pass + #print("on_wait_any_key(): Colab does not support") + +if is_colab: + interact = InteractColab() +else: + interact = InteractDesktop() diff --git a/core/joblib/MPClassFuncOnDemand.py b/core/joblib/MPClassFuncOnDemand.py new file mode 100644 index 0000000000000000000000000000000000000000..dad924a111b532b3e36cca6f1cbcf89ffe2fc3d9 --- /dev/null +++ b/core/joblib/MPClassFuncOnDemand.py @@ -0,0 +1,32 @@ +import multiprocessing +from core.interact import interact as io + +class MPClassFuncOnDemand(): + def __init__(self, class_handle, class_func_name, **class_kwargs): + self.class_handle = class_handle + self.class_func_name = class_func_name + self.class_kwargs = class_kwargs + + self.class_func = None + + self.s2c = multiprocessing.Queue() + self.c2s = multiprocessing.Queue() + self.lock = multiprocessing.Lock() + + io.add_process_messages_callback(self.io_callback) + + def io_callback(self): + while not self.c2s.empty(): + func_args, func_kwargs = self.c2s.get() + if self.class_func is None: + self.class_func = getattr( self.class_handle(**self.class_kwargs), self.class_func_name) + self.s2c.put ( self.class_func (*func_args, **func_kwargs) ) + + def __call__(self, *args, **kwargs): + with self.lock: + self.c2s.put ( (args, kwargs) ) + return self.s2c.get() + + def __getstate__(self): + return {'s2c':self.s2c, 'c2s':self.c2s, 'lock':self.lock} + diff --git a/core/joblib/MPFunc.py b/core/joblib/MPFunc.py new file mode 100644 index 0000000000000000000000000000000000000000..94512ed1270951631bccaaaaeaf94b9cb06d0619 --- /dev/null +++ b/core/joblib/MPFunc.py @@ -0,0 +1,25 @@ +import multiprocessing +from core.interact import interact as io + +class MPFunc(): + def __init__(self, func): + self.func = func + + self.s2c = multiprocessing.Queue() + self.c2s = multiprocessing.Queue() + self.lock = multiprocessing.Lock() + + io.add_process_messages_callback(self.io_callback) + + def io_callback(self): + while not self.c2s.empty(): + func_args, func_kwargs = self.c2s.get() + self.s2c.put ( self.func (*func_args, **func_kwargs) ) + + def __call__(self, *args, **kwargs): + with self.lock: + self.c2s.put ( (args, kwargs) ) + return self.s2c.get() + + def __getstate__(self): + return {'s2c':self.s2c, 'c2s':self.c2s, 'lock':self.lock} diff --git a/core/joblib/SubprocessGenerator.py b/core/joblib/SubprocessGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..84a593773de1bda3c39d848bb6e3c25f3cf7b1bc --- /dev/null +++ b/core/joblib/SubprocessGenerator.py @@ -0,0 +1,79 @@ +import multiprocessing +import queue as Queue +import threading +import time + + +class SubprocessGenerator(object): + + @staticmethod + def launch_thread(generator): + generator._start() + + @staticmethod + def start_in_parallel( generator_list ): + """ + Start list of generators in parallel + """ + for generator in generator_list: + thread = threading.Thread(target=SubprocessGenerator.launch_thread, args=(generator,) ) + thread.daemon = True + thread.start() + + while not all ([generator._is_started() for generator in generator_list]): + time.sleep(0.005) + + def __init__(self, generator_func, user_param=None, prefetch=2, start_now=True): + super().__init__() + self.prefetch = prefetch + self.generator_func = generator_func + self.user_param = user_param + self.sc_queue = multiprocessing.Queue() + self.cs_queue = multiprocessing.Queue() + self.p = None + if start_now: + self._start() + + def _start(self): + if self.p == None: + user_param = self.user_param + self.user_param = None + p = multiprocessing.Process(target=self.process_func, args=(user_param,) ) + p.daemon = True + p.start() + self.p = p + + def _is_started(self): + return self.p is not None + + def process_func(self, user_param): + self.generator_func = self.generator_func(user_param) + while True: + while self.prefetch > -1: + try: + gen_data = next (self.generator_func) + except StopIteration: + self.cs_queue.put (None) + return + self.cs_queue.put (gen_data) + self.prefetch -= 1 + self.sc_queue.get() + self.prefetch += 1 + + def __iter__(self): + return self + + def __getstate__(self): + self_dict = self.__dict__.copy() + del self_dict['p'] + return self_dict + + def __next__(self): + self._start() + gen_data = self.cs_queue.get() + if gen_data is None: + self.p.terminate() + self.p.join() + raise StopIteration() + self.sc_queue.put (1) + return gen_data diff --git a/core/joblib/SubprocessorBase.py b/core/joblib/SubprocessorBase.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7056343f959ec8b5801a5f83f58a5a81cac70 --- /dev/null +++ b/core/joblib/SubprocessorBase.py @@ -0,0 +1,302 @@ +import traceback +import multiprocessing +import time +import sys +from core.interact import interact as io + + +class Subprocessor(object): + + class SilenceException(Exception): + pass + + class Cli(object): + def __init__ ( self, client_dict ): + s2c = multiprocessing.Queue() + c2s = multiprocessing.Queue() + self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,s2c,c2s) ) + self.s2c = s2c + self.c2s = c2s + self.p.daemon = True + self.p.start() + + self.state = None + self.sent_time = None + self.sent_data = None + self.name = None + self.host_dict = None + + def kill(self): + self.p.terminate() + self.p.join() + + #overridable optional + def on_initialize(self, client_dict): + #initialize your subprocess here using client_dict + pass + + #overridable optional + def on_finalize(self): + #finalize your subprocess here + pass + + #overridable + def process_data(self, data): + #process 'data' given from host and return result + raise NotImplementedError + + #overridable optional + def get_data_name (self, data): + #return string identificator of your 'data' + return "undefined" + + def log_info(self, msg): self.c2s.put ( {'op': 'log_info', 'msg':msg } ) + def log_err(self, msg): self.c2s.put ( {'op': 'log_err' , 'msg':msg } ) + def progress_bar_inc(self, c): self.c2s.put ( {'op': 'progress_bar_inc' , 'c':c } ) + + def _subprocess_run(self, client_dict, s2c, c2s): + self.c2s = c2s + data = None + is_error = False + try: + self.on_initialize(client_dict) + + c2s.put ( {'op': 'init_ok'} ) + + while True: + msg = s2c.get() + op = msg.get('op','') + if op == 'data': + data = msg['data'] + result = self.process_data (data) + c2s.put ( {'op': 'success', 'data' : data, 'result' : result} ) + data = None + elif op == 'close': + break + + time.sleep(0.001) + + self.on_finalize() + c2s.put ( {'op': 'finalized'} ) + except Subprocessor.SilenceException as e: + c2s.put ( {'op': 'error', 'data' : data} ) + except Exception as e: + err_msg = traceback.format_exc() + c2s.put ( {'op': 'error', 'data' : data, 'err_msg' : err_msg} ) + + c2s.close() + s2c.close() + self.c2s = None + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) + + #overridable + def __init__(self, name, SubprocessorCli_class, no_response_time_sec = 0, io_loop_sleep_time=0.005, initialize_subprocesses_in_serial=False): + if not issubclass(SubprocessorCli_class, Subprocessor.Cli): + raise ValueError("SubprocessorCli_class must be subclass of Subprocessor.Cli") + + self.name = name + self.SubprocessorCli_class = SubprocessorCli_class + self.no_response_time_sec = no_response_time_sec + self.io_loop_sleep_time = io_loop_sleep_time + self.initialize_subprocesses_in_serial = initialize_subprocesses_in_serial + + #overridable + def process_info_generator(self): + #yield per process (name, host_dict, client_dict) + raise NotImplementedError + + #overridable optional + def on_clients_initialized(self): + #logic when all subprocesses initialized and ready + pass + + #overridable optional + def on_clients_finalized(self): + #logic when all subprocess finalized + pass + + #overridable + def get_data(self, host_dict): + #return data for processing here + raise NotImplementedError + + #overridable + def on_data_return (self, host_dict, data): + #you have to place returned 'data' back to your queue + raise NotImplementedError + + #overridable + def on_result (self, host_dict, data, result): + #your logic what to do with 'result' of 'data' + raise NotImplementedError + + #overridable + def get_result(self): + #return result that will be returned in func run() + return None + + #overridable + def on_tick(self): + #tick in main loop + #return True if system can be finalized when no data in get_data, orelse False + return True + + #overridable + def on_check_run(self): + return True + + def run(self): + if not self.on_check_run(): + return self.get_result() + + self.clis = [] + + def cli_init_dispatcher(cli): + while not cli.c2s.empty(): + obj = cli.c2s.get() + op = obj.get('op','') + if op == 'init_ok': + cli.state = 0 + elif op == 'log_info': + io.log_info(obj['msg']) + elif op == 'log_err': + io.log_err(obj['msg']) + elif op == 'error': + err_msg = obj.get('err_msg', None) + if err_msg is not None: + io.log_info(f'Error while subprocess initialization: {err_msg}') + cli.kill() + self.clis.remove(cli) + break + + #getting info about name of subprocesses, host and client dicts, and spawning them + for name, host_dict, client_dict in self.process_info_generator(): + try: + cli = self.SubprocessorCli_class(client_dict) + cli.state = 1 + cli.sent_time = 0 + cli.sent_data = None + cli.name = name + cli.host_dict = host_dict + + self.clis.append (cli) + + if self.initialize_subprocesses_in_serial: + while True: + cli_init_dispatcher(cli) + if cli.state == 0: + break + io.process_messages(0.005) + except: + raise Exception (f"Unable to start subprocess {name}. Error: {traceback.format_exc()}") + + if len(self.clis) == 0: + raise Exception ("Unable to start Subprocessor '%s' " % (self.name)) + + #waiting subprocesses their success(or not) initialization + while True: + for cli in self.clis[:]: + cli_init_dispatcher(cli) + if all ([cli.state == 0 for cli in self.clis]): + break + io.process_messages(0.005) + + if len(self.clis) == 0: + raise Exception ( "Unable to start subprocesses." ) + + #ok some processes survived, initialize host logic + + self.on_clients_initialized() + + #main loop of data processing + while True: + for cli in self.clis[:]: + while not cli.c2s.empty(): + obj = cli.c2s.get() + op = obj.get('op','') + if op == 'success': + #success processed data, return data and result to on_result + self.on_result (cli.host_dict, obj['data'], obj['result']) + self.sent_data = None + cli.state = 0 + elif op == 'error': + #some error occured while process data, returning chunk to on_data_return + err_msg = obj.get('err_msg', None) + if err_msg is not None: + io.log_info(f'Error while processing data: {err_msg}') + + if 'data' in obj.keys(): + self.on_data_return (cli.host_dict, obj['data'] ) + #and killing process + cli.kill() + self.clis.remove(cli) + elif op == 'log_info': + io.log_info(obj['msg']) + elif op == 'log_err': + io.log_err(obj['msg']) + elif op == 'progress_bar_inc': + io.progress_bar_inc(obj['c']) + + for cli in self.clis[:]: + if cli.state == 1: + if cli.sent_time != 0 and self.no_response_time_sec != 0 and (time.time() - cli.sent_time) > self.no_response_time_sec: + #subprocess busy too long + print ( '%s doesnt response, terminating it.' % (cli.name) ) + self.on_data_return (cli.host_dict, cli.sent_data ) + cli.kill() + self.clis.remove(cli) + + for cli in self.clis[:]: + if cli.state == 0: + #free state of subprocess, get some data from get_data + data = self.get_data(cli.host_dict) + if data is not None: + #and send it to subprocess + cli.s2c.put ( {'op': 'data', 'data' : data} ) + cli.sent_time = time.time() + cli.sent_data = data + cli.state = 1 + + if self.io_loop_sleep_time != 0: + io.process_messages(self.io_loop_sleep_time) + + if self.on_tick() and all ([cli.state == 0 for cli in self.clis]): + #all subprocesses free and no more data available to process, ending loop + break + + + + #gracefully terminating subprocesses + for cli in self.clis[:]: + cli.s2c.put ( {'op': 'close'} ) + cli.sent_time = time.time() + + while True: + for cli in self.clis[:]: + terminate_it = False + while not cli.c2s.empty(): + obj = cli.c2s.get() + obj_op = obj['op'] + if obj_op == 'finalized': + terminate_it = True + break + + if (time.time() - cli.sent_time) > 30: + terminate_it = True + + if terminate_it: + cli.state = 2 + cli.kill() + + if all ([cli.state == 2 for cli in self.clis]): + break + + #finalizing host logic and return result + self.on_clients_finalized() + + return self.get_result() diff --git a/core/joblib/ThisThreadGenerator.py b/core/joblib/ThisThreadGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f77a4093716f07a5b6844bf7ba0b92954d7bff --- /dev/null +++ b/core/joblib/ThisThreadGenerator.py @@ -0,0 +1,16 @@ +class ThisThreadGenerator(object): + def __init__(self, generator_func, user_param=None): + super().__init__() + self.generator_func = generator_func + self.user_param = user_param + self.initialized = False + + def __iter__(self): + return self + + def __next__(self): + if not self.initialized: + self.initialized = True + self.generator_func = self.generator_func(self.user_param) + + return next(self.generator_func) \ No newline at end of file diff --git a/core/joblib/__init__.py b/core/joblib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d5613992bca5c1620906626acf6322baad0b20eb --- /dev/null +++ b/core/joblib/__init__.py @@ -0,0 +1,5 @@ +from .SubprocessorBase import Subprocessor +from .ThisThreadGenerator import ThisThreadGenerator +from .SubprocessGenerator import SubprocessGenerator +from .MPFunc import MPFunc +from .MPClassFuncOnDemand import MPClassFuncOnDemand \ No newline at end of file diff --git a/core/leras/__init__.py b/core/leras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9fb2b0ea9e79230e9e7b00fe4cc28654149340 --- /dev/null +++ b/core/leras/__init__.py @@ -0,0 +1 @@ +from .nn import nn \ No newline at end of file diff --git a/core/leras/archis/0DeepFakeArchi.py b/core/leras/archis/0DeepFakeArchi.py new file mode 100644 index 0000000000000000000000000000000000000000..8f82aaaa4560e7479f08f8b1dfdd3870a8699963 --- /dev/null +++ b/core/leras/archis/0DeepFakeArchi.py @@ -0,0 +1,474 @@ +from core.leras import nn +tf = nn.tf + +class DeepFakeArchi(nn.ArchiBase): + """ + resolution + + mod None - default + 'uhd' + 'quick' + """ + def __init__(self, resolution, mod=None): + super().__init__() + + if mod is None: + class Downscale(nn.ModelBase): + def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.dilations = dilations + self.subpixel = subpixel + self.use_activator = use_activator + super().__init__(*kwargs) + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D( self.in_ch, + self.out_ch // (4 if self.subpixel else 1), + kernel_size=self.kernel_size, + strides=1 if self.subpixel else 2, + padding='SAME', dilations=self.dilations) + + def forward(self, x): + x = self.conv1(x) + if self.subpixel: + x = nn.space_to_depth(x, 2) + if self.use_activator: + x = tf.nn.leaky_relu(x, 0.1) + return x + + def get_out_ch(self): + return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch + + class DownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): + self.downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) + last_ch = self.downs[-1].get_out_ch() + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + class UpdownResidualBlock(nn.ModelBase): + def on_build(self, ch, inner_ch, kernel_size=3 ): + self.up = Upscale (ch, inner_ch, kernel_size=kernel_size) + self.res = ResidualBlock (inner_ch, kernel_size=kernel_size) + self.down = Downscale (inner_ch, ch, kernel_size=kernel_size, use_activator=False) + + def forward(self, inp): + x = self.up(inp) + x = upx = self.res(x) + x = self.down(x) + x = x + inp + x = tf.nn.leaky_relu(x, 0.2) + return x, upx + + class Encoder(nn.ModelBase): + def on_build(self, in_ch, e_ch, is_hd): + self.is_hd=is_hd + if self.is_hd: + self.down1 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=3, dilations=1) + self.down2 = DownscaleBlock(in_ch, e_ch*2, n_downscales=4, kernel_size=5, dilations=1) + self.down3 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=5, dilations=2) + self.down4 = DownscaleBlock(in_ch, e_ch//2, n_downscales=4, kernel_size=7, dilations=2) + else: + self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) + + def forward(self, inp): + if self.is_hd: + x = tf.concat([ nn.flatten(self.down1(inp)), + nn.flatten(self.down2(inp)), + nn.flatten(self.down3(inp)), + nn.flatten(self.down4(inp)) ], -1 ) + else: + x = nn.flatten(self.down1(inp)) + return x + + lowest_dense_res = resolution // 16 + + class Inter(nn.ModelBase): + def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs): + self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch + super().__init__(**kwargs) + + def on_build(self): + in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch + + self.dense1 = nn.Dense( in_ch, ae_ch ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + + def forward(self, inp): + x = self.dense1(inp) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + x = self.upscale1(x) + return x + + @staticmethod + def get_code_res(): + return lowest_dense_res + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch, d_mask_ch, is_hd ): + self.is_hd = is_hd + + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + + if is_hd: + self.res0 = UpdownResidualBlock(in_ch, d_ch*8, kernel_size=3) + self.res1 = UpdownResidualBlock(d_ch*8, d_ch*4, kernel_size=3) + self.res2 = UpdownResidualBlock(d_ch*4, d_ch*2, kernel_size=3) + self.res3 = UpdownResidualBlock(d_ch*2, d_ch, kernel_size=3) + else: + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*4, kernel_size=3) + self.res2 = ResidualBlock(d_ch*2, kernel_size=3) + + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME') + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') + + def forward(self, inp): + z = inp + + if self.is_hd: + x, upx = self.res0(z) + x = self.upscale0(x) + x = tf.nn.leaky_relu(x + upx, 0.2) + x, upx = self.res1(x) + + x = self.upscale1(x) + x = tf.nn.leaky_relu(x + upx, 0.2) + x, upx = self.res2(x) + + x = self.upscale2(x) + x = tf.nn.leaky_relu(x + upx, 0.2) + x, upx = self.res3(x) + else: + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + + return tf.nn.sigmoid(self.out_conv(x)), \ + tf.nn.sigmoid(self.out_convm(m)) + + elif mod == 'quick': + class Downscale(nn.ModelBase): + def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.dilations = dilations + self.subpixel = subpixel + self.use_activator = use_activator + super().__init__(*kwargs) + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D( self.in_ch, + self.out_ch // (4 if self.subpixel else 1), + kernel_size=self.kernel_size, + strides=1 if self.subpixel else 2, + padding='SAME', dilations=self.dilations ) + + def forward(self, x): + x = self.conv1(x) + + if self.subpixel: + x = nn.space_to_depth(x, 2) + + if self.use_activator: + x = nn.gelu(x) + return x + + def get_out_ch(self): + return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch + + class DownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): + self.downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) + last_ch = self.downs[-1].get_out_ch() + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = nn.gelu(x) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + + def forward(self, inp): + x = self.conv1(inp) + x = nn.gelu(x) + x = self.conv2(x) + x = inp + x + x = nn.gelu(x) + return x + + class Encoder(nn.ModelBase): + def on_build(self, in_ch, e_ch): + self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5) + def forward(self, inp): + return nn.flatten(self.down1(inp)) + + lowest_dense_res = resolution // 16 + + class Inter(nn.ModelBase): + def __init__(self, in_ch, ae_ch, ae_out_ch, d_ch, **kwargs): + self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch = in_ch, ae_ch, ae_out_ch, d_ch + super().__init__(**kwargs) + + def on_build(self): + in_ch, ae_ch, ae_out_ch, d_ch = self.in_ch, self.ae_ch, self.ae_out_ch, self.d_ch + + self.dense1 = nn.Dense( in_ch, ae_ch, kernel_initializer=tf.initializers.orthogonal ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch, kernel_initializer=tf.initializers.orthogonal ) + self.upscale1 = Upscale(ae_out_ch, d_ch*8) + self.res1 = ResidualBlock(d_ch*8) + + def forward(self, inp): + x = self.dense1(inp) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + x = self.upscale1(x) + x = self.res1(x) + return x + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch): + self.upscale1 = Upscale(in_ch, d_ch*4) + self.res1 = ResidualBlock(d_ch*4) + self.upscale2 = Upscale(d_ch*4, d_ch*2) + self.res2 = ResidualBlock(d_ch*2) + self.upscale3 = Upscale(d_ch*2, d_ch*1) + self.res3 = ResidualBlock(d_ch*1) + + self.upscalem1 = Upscale(in_ch, d_ch) + self.upscalem2 = Upscale(d_ch, d_ch//2) + self.upscalem3 = Upscale(d_ch//2, d_ch//2) + + self.out_conv = nn.Conv2D( d_ch*1, 3, kernel_size=1, padding='SAME') + self.out_convm = nn.Conv2D( d_ch//2, 1, kernel_size=1, padding='SAME') + + def forward(self, inp): + z = inp + x = self.upscale1 (z) + x = self.res1 (x) + x = self.upscale2 (x) + x = self.res2 (x) + x = self.upscale3 (x) + x = self.res3 (x) + + y = self.upscalem1 (z) + y = self.upscalem2 (y) + y = self.upscalem3 (y) + + return tf.nn.sigmoid(self.out_conv(x)), \ + tf.nn.sigmoid(self.out_convm(y)) + elif mod == 'uhd': + + class Downscale(nn.ModelBase): + def __init__(self, in_ch, out_ch, kernel_size=5, dilations=1, subpixel=True, use_activator=True, *kwargs ): + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.dilations = dilations + self.subpixel = subpixel + self.use_activator = use_activator + super().__init__(*kwargs) + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D( self.in_ch, + self.out_ch // (4 if self.subpixel else 1), + kernel_size=self.kernel_size, + strides=1 if self.subpixel else 2, + padding='SAME', dilations=self.dilations) + + def forward(self, x): + x = self.conv1(x) + if self.subpixel: + x = nn.space_to_depth(x, 2) + if self.use_activator: + x = tf.nn.leaky_relu(x, 0.1) + return x + + def get_out_ch(self): + return (self.out_ch // 4) * 4 if self.subpixel else self.out_ch + + class DownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales, kernel_size, dilations=1, subpixel=True): + self.downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size, dilations=dilations, subpixel=subpixel) ) + last_ch = self.downs[-1].get_out_ch() + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self, in_ch, e_ch, **kwargs): + self.down1 = DownscaleBlock(in_ch, e_ch, n_downscales=4, kernel_size=5, dilations=1, subpixel=False) + + def forward(self, inp): + x = nn.flatten(self.down1(inp)) + return x + + lowest_dense_res = resolution // 16 + + class Inter(nn.ModelBase): + def on_build(self, in_ch, ae_ch, ae_out_ch, **kwargs): + self.ae_out_ch = ae_out_ch + self.dense_norm = nn.DenseNorm() + self.dense1 = nn.Dense( in_ch, ae_ch ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + + def forward(self, inp): + x = self.dense_norm(inp) + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + x = self.upscale1(x) + return x + + @staticmethod + def get_code_res(): + return lowest_dense_res + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ): + + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*4, kernel_size=3) + self.res2 = ResidualBlock(d_ch*2, kernel_size=3) + + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME') + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') + + def forward(self, inp): + z = inp + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + + return tf.nn.sigmoid(self.out_conv(x)), \ + tf.nn.sigmoid(self.out_convm(m)) + + self.Encoder = Encoder + self.Inter = Inter + self.Decoder = Decoder + +nn.DeepFakeArchi = DeepFakeArchi \ No newline at end of file diff --git a/core/leras/archis/ArchiBase.py b/core/leras/archis/ArchiBase.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfe6d9dabfe37ff2d008296dd21b3214797367b --- /dev/null +++ b/core/leras/archis/ArchiBase.py @@ -0,0 +1,17 @@ +from core.leras import nn + +class ArchiBase(): + + def __init__(self, *args, name=None, **kwargs): + self.name=name + + + #overridable + def flow(self, *args, **kwargs): + raise Exception("this archi does not support flow. Use model classes directly.") + + #overridable + def get_weights(self): + pass + +nn.ArchiBase = ArchiBase \ No newline at end of file diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py new file mode 100644 index 0000000000000000000000000000000000000000..93ff13c615b5c3155e6e30a20374faf8a9ecacb6 --- /dev/null +++ b/core/leras/archis/DeepFakeArchi.py @@ -0,0 +1,265 @@ +from core.leras import nn +tf = nn.tf + +class DeepFakeArchi(nn.ArchiBase): + """ + resolution + + mod None - default + 'quick' + + opts '' + '' + 't' + """ + def __init__(self, resolution, use_fp16=False, mod=None, opts=None): + super().__init__() + + if opts is None: + opts = '' + + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + if 'c' in opts: + def act(x, alpha=0.1): + return x*tf.cos(x) + else: + def act(x, alpha=0.1): + return tf.nn.leaky_relu(x, alpha) + + if mod is None: + class Downscale(nn.ModelBase): + def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + super().__init__(*kwargs) + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = self.conv1(x) + x = act(x, 0.1) + return x + + def get_out_ch(self): + return self.out_ch + + class DownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales, kernel_size): + self.downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size)) + last_ch = self.downs[-1].get_out_ch() + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = self.conv1(x) + x = act(x, 0.1) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = act(x, 0.2) + x = self.conv2(x) + x = act(inp + x, 0.2) + return x + + class Encoder(nn.ModelBase): + def __init__(self, in_ch, e_ch, **kwargs ): + self.in_ch = in_ch + self.e_ch = e_ch + super().__init__(**kwargs) + + def on_build(self): + if 't' in opts: + self.down1 = Downscale(self.in_ch, self.e_ch, kernel_size=5) + self.res1 = ResidualBlock(self.e_ch) + self.down2 = Downscale(self.e_ch, self.e_ch*2, kernel_size=5) + self.down3 = Downscale(self.e_ch*2, self.e_ch*4, kernel_size=5) + self.down4 = Downscale(self.e_ch*4, self.e_ch*8, kernel_size=5) + self.down5 = Downscale(self.e_ch*8, self.e_ch*8, kernel_size=5) + self.res5 = ResidualBlock(self.e_ch*8) + else: + self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4 if 't' not in opts else 5, kernel_size=5) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + + if 't' in opts: + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + else: + x = self.down1(x) + x = nn.flatten(x) + if 'u' in opts: + x = nn.pixel_norm(x, axes=-1) + + if use_fp16: + x = tf.cast(x, tf.float32) + return x + + def get_out_res(self, res): + return res // ( (2**4) if 't' not in opts else (2**5) ) + + def get_out_ch(self): + return self.e_ch * 8 + + lowest_dense_res = resolution // (32 if 'd' in opts else 16) + + class Inter(nn.ModelBase): + def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs): + self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch + super().__init__(**kwargs) + + def on_build(self): + in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch + + self.dense1 = nn.Dense( in_ch, ae_ch ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) + if 't' not in opts: + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + + def forward(self, inp): + x = inp + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + + if use_fp16: + x = tf.cast(x, tf.float16) + + if 't' not in opts: + x = self.upscale1(x) + + return x + + def get_out_res(self): + return lowest_dense_res * 2 if 't' not in opts else lowest_dense_res + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch, d_mask_ch): + if 't' not in opts: + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*4, kernel_size=3) + self.res2 = ResidualBlock(d_ch*2, kernel_size=3) + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*8, kernel_size=3) + self.upscale2 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale3 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*8, kernel_size=3) + self.res2 = ResidualBlock(d_ch*4, kernel_size=3) + self.res3 = ResidualBlock(d_ch*2, kernel_size=3) + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + + + def forward(self, z): + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + + if 't' in opts: + x = self.upscale3(x) + x = self.res3(x) + + if 'd' in opts: + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + else: + x = tf.nn.sigmoid(self.out_conv(x)) + + + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + + if 't' in opts: + m = self.upscalem3(m) + if 'd' in opts: + m = self.upscalem4(m) + else: + if 'd' in opts: + m = self.upscalem3(m) + + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + + return x, m + + self.Encoder = Encoder + self.Inter = Inter + self.Decoder = Decoder + +nn.DeepFakeArchi = DeepFakeArchi \ No newline at end of file diff --git a/core/leras/archis/DeepFakeArchi1.py b/core/leras/archis/DeepFakeArchi1.py new file mode 100644 index 0000000000000000000000000000000000000000..82913d1200794b465383e474a8925be06ff841df --- /dev/null +++ b/core/leras/archis/DeepFakeArchi1.py @@ -0,0 +1,339 @@ +from core.leras import nn +tf = nn.tf + +class RB(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.relu(x) + x = self.conv2(x) + return x + + +class LCA(nn.ModelBase): + def on_build(self, ch): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=1, padding='VALID') + + def forward(self, inp): + x = inp + x = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True) + x = self.conv1(x) + x = tf.nn.sigmoid(x) + x = inp * x + return x + +class RAFG(nn.ModelBase): + def on_build(self, ch): + self.rb1 = RB(ch) + self.lca1 = LCA(ch) + self.rb2 = RB(ch) + self.lca2 = LCA(ch) + self.rb3 = RB(ch) + self.lca3 = LCA(ch) + + self.ab_conv = nn.Conv2D( ch*3, ch, kernel_size=1, padding='VALID') + self.ab_lca = LCA(ch) + self.fb_conv = nn.Conv2D( ch*4, ch, kernel_size=1, padding='VALID') + + def forward(self, inp): + x = inp + + rb1 = self.rb1(x) + lca1 = self.lca1(rb1) + + x = x+rb1 + + rb2 = self.rb2(x) + lca2 = self.lca2(rb2) + + x = x+rb2 + + rb3 = self.rb3(x) + lca3 = self.lca3(rb3) + + lca = tf.concat([lca1,lca2,lca3], axis=nn.conv2d_ch_axis) + lca = self.ab_conv(lca) + lca = self.ab_lca(lca) + + rb = tf.concat([inp,rb1,rb2,rb2], axis=nn.conv2d_ch_axis) + rb = self.fb_conv(rb) + + return rb+lca, lca + + +class HRAN(nn.ModelBase): + """ + Hierarchical Residual Attention Network for Single Image Super-Resolution + + https://arxiv.org/pdf/2012.04578v1.pdf + """ + + + def on_build(self, ch): + self.rafg1 = RAFG(ch) + self.rafg2 = RAFG(ch) + self.rafg3 = RAFG(ch) + + self.ab_conv = nn.Conv2D( ch*3, ch, kernel_size=1, padding='VALID') + self.ab_lca = LCA(ch) + self.fb_conv = nn.Conv2D( ch*4, ch, kernel_size=1, padding='VALID') + + def forward(self, inp): + x = inp + + rafg1, rafg1_lca = self.rafg1(x) + + rafg2, rafg2_lca = self.rafg2(x) + rafg3, rafg3_lca = self.rafg3(x) + + rafg_lca = tf.concat([rafg1_lca,rafg2_lca,rafg3_lca], axis=nn.conv2d_ch_axis) + rafg_lca = self.ab_conv(rafg_lca) + rafg_lca = self.ab_lca(rafg_lca) + + rafg = tf.concat([x,rafg1,rafg2,rafg3], axis=nn.conv2d_ch_axis) + rafg = self.fb_conv(rafg) + + x = x + rafg + rafg_lca + x = tf.nn.leaky_relu(x, 0.2) + + return x + +class DeepFakeArchi(nn.ArchiBase): + """ + resolution + + mod None - default + 'quick' + """ + def __init__(self, resolution, mod=None, opts=None): + super().__init__() + + if opts is None: + opts = '' + + if mod is None: + class Downscale(nn.ModelBase): + def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ): + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + super().__init__(*kwargs) + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D( self.in_ch, self.out_ch, kernel_size=self.kernel_size, strides=2, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + return x + + def get_out_ch(self): + return self.out_ch + + class DownscaleBlock(nn.ModelBase): + def on_build(self, in_ch, ch, n_downscales, kernel_size): + self.downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + self.downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size) ) + last_ch = self.downs[-1].get_out_ch() + + def forward(self, inp): + x = inp + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME') + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + class Encoder(nn.ModelBase): + def __init__(self, in_ch, e_ch, **kwargs ): + self.in_ch = in_ch + self.e_ch = e_ch + super().__init__(**kwargs) + + def on_build(self): + self.down1 = DownscaleBlock(self.in_ch, self.e_ch, n_downscales=4, kernel_size=5) + + def forward(self, inp): + return nn.flatten(self.down1(inp)) + + def get_out_res(self, res): + return res // (2**4) + + def get_out_ch(self): + return self.e_ch * 8 + + lowest_dense_res = resolution // 16 + + if 'h' in opts: + lowest_dense_res //= 2 + + if 'd' in opts: + lowest_dense_res //= 2 + + class Inter(nn.ModelBase): + def __init__(self, in_ch, ae_ch, ae_out_ch, **kwargs): + self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch + super().__init__(**kwargs) + + def on_build(self): + in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch + if 'u' in opts: + self.dense_norm = nn.DenseNorm() + + self.dense1 = nn.Dense( in_ch, ae_ch ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + + def forward(self, inp): + x = inp + if 'u' in opts: + x = self.dense_norm(x) + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + x = self.upscale1(x) + return x + + def get_out_res(self): + return lowest_dense_res * 2 + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch, d_mask_ch ): + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + + self.res0 = HRAN(d_ch*8) + self.res1 = HRAN(d_ch*4) + self.res2 = HRAN(d_ch*2) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='SAME') + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') + + if 'd' in opts: + self.out_conv1 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME') + self.out_conv2 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME') + self.out_conv3 = nn.Conv2D( d_ch*2, 3, kernel_size=3, padding='SAME') + + if 'h' in opts and 'd' in opts: + self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + elif 'h' in opts or 'd' in opts: + self.upscalem3 = Upscale(d_mask_ch*2, d_mask_ch*1, kernel_size=3) + + if 'h' in opts or 'd' in opts: + self.out_convm = nn.Conv2D( d_mask_ch*1, 1, kernel_size=1, padding='SAME') + else: + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='SAME') + + if 'h' in opts: + self.hran = HRAN(3, 64) + + + + def forward(self, inp): + z = inp + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + + + if 'd' in opts: + x0 = tf.nn.sigmoid(self.out_conv(x)) + x0 = nn.upsample2d(x0) + x1 = tf.nn.sigmoid(self.out_conv1(x)) + x1 = nn.upsample2d(x1) + x2 = tf.nn.sigmoid(self.out_conv2(x)) + x2 = nn.upsample2d(x2) + x3 = tf.nn.sigmoid(self.out_conv3(x)) + x3 = nn.upsample2d(x3) + + tile_res = resolution // 2 + if 'h' in opts: + tile_res //= 2 + + if nn.data_format == "NHWC": + tile_cfg = ( 1, tile_res, tile_res, 1) + else: + tile_cfg = ( 1, 1, tile_res, tile_res) + + z0 = tf.concat ( ( tf.concat ( ( tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ), + tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] ) + + z0 = tf.tile ( z0, tile_cfg ) + + z1 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ), + tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] ) + z1 = tf.tile ( z1, tile_cfg ) + + z2 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ), + tf.concat ( ( tf.ones ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] ) + z2 = tf.tile ( z2, tile_cfg ) + + z3 = tf.concat ( ( tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.zeros ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ), + tf.concat ( ( tf.zeros ( (1,1,1,1) ), tf.ones ( (1,1,1,1) ) ), axis=nn.conv2d_spatial_axes[1] ) ), axis=nn.conv2d_spatial_axes[0] ) + z3 = tf.tile ( z3, tile_cfg ) + + x = x0*z0 + x1*z1 + x2*z2 + x3*z3 + else: + x = tf.nn.sigmoid(self.out_conv(x)) + + + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + + if 'h' in opts and 'd' in opts: + m = self.upscalem3(m) + m = self.upscalem4(m) + elif 'h' in opts or 'd' in opts: + m = self.upscalem3(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if 'h' in opts: + x = self.hran(x) + + return x, m + + self.Encoder = Encoder + self.Inter = Inter + self.Decoder = Decoder + +nn.DeepFakeArchi = DeepFakeArchi \ No newline at end of file diff --git a/core/leras/archis/__init__.py b/core/leras/archis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3734ddd68904c89e0f494017094c2c1cb0447b42 --- /dev/null +++ b/core/leras/archis/__init__.py @@ -0,0 +1,2 @@ +from .ArchiBase import * +from .DeepFakeArchi import * \ No newline at end of file diff --git a/core/leras/device.py b/core/leras/device.py new file mode 100644 index 0000000000000000000000000000000000000000..31d2f880926b8ccf59f0a35af307f9adf14441cf --- /dev/null +++ b/core/leras/device.py @@ -0,0 +1,272 @@ +import sys +import ctypes +import os +import multiprocessing +import json +import time +from pathlib import Path +from core.interact import interact as io + + +class Device(object): + def __init__(self, index, tf_dev_type, name, total_mem, free_mem): + self.index = index + self.tf_dev_type = tf_dev_type + self.name = name + + self.total_mem = total_mem + self.total_mem_gb = total_mem / 1024**3 + self.free_mem = free_mem + self.free_mem_gb = free_mem / 1024**3 + + def __str__(self): + return f"[{self.index}]:[{self.name}][{self.free_mem_gb:.3}/{self.total_mem_gb :.3}]" + +class Devices(object): + all_devices = None + + def __init__(self, devices): + self.devices = devices + + def __len__(self): + return len(self.devices) + + def __getitem__(self, key): + result = self.devices[key] + if isinstance(key, slice): + return Devices(result) + return result + + def __iter__(self): + for device in self.devices: + yield device + + def get_best_device(self): + result = None + idx_mem = 0 + for device in self.devices: + mem = device.total_mem + if mem > idx_mem: + result = device + idx_mem = mem + return result + + def get_worst_device(self): + result = None + idx_mem = sys.maxsize + for device in self.devices: + mem = device.total_mem + if mem < idx_mem: + result = device + idx_mem = mem + return result + + def get_device_by_index(self, idx): + for device in self.devices: + if device.index == idx: + return device + return None + + def get_devices_from_index_list(self, idx_list): + result = [] + for device in self.devices: + if device.index in idx_list: + result += [device] + return Devices(result) + + def get_equal_devices(self, device): + device_name = device.name + result = [] + for device in self.devices: + if device.name == device_name: + result.append (device) + return Devices(result) + + def get_devices_at_least_mem(self, totalmemsize_gb): + result = [] + for device in self.devices: + if device.total_mem >= totalmemsize_gb*(1024**3): + result.append (device) + return Devices(result) + + @staticmethod + def _get_tf_devices_proc(q : multiprocessing.Queue): + + if sys.platform[0:3] == 'win': + compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache_ALL') + os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path) + if not compute_cache_path.exists(): + io.log_info("Caching GPU kernels...") + compute_cache_path.mkdir(parents=True, exist_ok=True) + + import tensorflow + + tf_version = tensorflow.version.VERSION + #if tf_version is None: + # tf_version = tensorflow.version.GIT_VERSION + if tf_version[0] == 'v': + tf_version = tf_version[1:] + if tf_version[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + + import logging + # Disable tensorflow warnings + tf_logger = logging.getLogger('tensorflow') + tf_logger.setLevel(logging.ERROR) + + from tensorflow.python.client import device_lib + + devices = [] + + physical_devices = device_lib.list_local_devices() + physical_devices_f = {} + for dev in physical_devices: + dev_type = dev.device_type + dev_tf_name = dev.name + dev_tf_name = dev_tf_name[ dev_tf_name.index(dev_type) : ] + + dev_idx = int(dev_tf_name.split(':')[-1]) + + if dev_type in ['GPU','DML']: + dev_name = dev_tf_name + + dev_desc = dev.physical_device_desc + if len(dev_desc) != 0: + if dev_desc[0] == '{': + dev_desc_json = json.loads(dev_desc) + dev_desc_json_name = dev_desc_json.get('name',None) + if dev_desc_json_name is not None: + dev_name = dev_desc_json_name + else: + for param, value in ( v.split(':') for v in dev_desc.split(',') ): + param = param.strip() + value = value.strip() + if param == 'name': + dev_name = value + break + + physical_devices_f[dev_idx] = (dev_type, dev_name, dev.memory_limit) + + q.put(physical_devices_f) + time.sleep(0.1) + + + @staticmethod + def initialize_main_env(): + if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 0: + return + + if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): + os.environ.pop('CUDA_VISIBLE_DEVICES') + + os.environ['CUDA_​CACHE_​MAXSIZE'] = '2147483647' + os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # tf log errors only + + q = multiprocessing.Queue() + p = multiprocessing.Process(target=Devices._get_tf_devices_proc, args=(q,), daemon=True) + p.start() + p.join() + + visible_devices = q.get() + + os.environ['NN_DEVICES_INITIALIZED'] = '1' + os.environ['NN_DEVICES_COUNT'] = str(len(visible_devices)) + + for i in visible_devices: + dev_type, name, total_mem = visible_devices[i] + + os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'] = dev_type + os.environ[f'NN_DEVICE_{i}_NAME'] = name + os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(total_mem) + os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(total_mem) + + + + @staticmethod + def getDevices(): + if Devices.all_devices is None: + if int(os.environ.get("NN_DEVICES_INITIALIZED", 0)) != 1: + raise Exception("nn devices are not initialized. Run initialize_main_env() in main process.") + devices = [] + for i in range ( int(os.environ['NN_DEVICES_COUNT']) ): + devices.append ( Device(index=i, + tf_dev_type=os.environ[f'NN_DEVICE_{i}_TF_DEV_TYPE'], + name=os.environ[f'NN_DEVICE_{i}_NAME'], + total_mem=int(os.environ[f'NN_DEVICE_{i}_TOTAL_MEM']), + free_mem=int(os.environ[f'NN_DEVICE_{i}_FREE_MEM']), ) + ) + Devices.all_devices = Devices(devices) + + return Devices.all_devices + +""" + + + # {'name' : name.split(b'\0', 1)[0].decode(), + # 'total_mem' : totalMem.value + # } + + + + + + return + + + + + min_cc = int(os.environ.get("TF_MIN_REQ_CAP", 35)) + libnames = ('libcuda.so', 'libcuda.dylib', 'nvcuda.dll') + for libname in libnames: + try: + cuda = ctypes.CDLL(libname) + except: + continue + else: + break + else: + return Devices([]) + + nGpus = ctypes.c_int() + name = b' ' * 200 + cc_major = ctypes.c_int() + cc_minor = ctypes.c_int() + freeMem = ctypes.c_size_t() + totalMem = ctypes.c_size_t() + + result = ctypes.c_int() + device = ctypes.c_int() + context = ctypes.c_void_p() + error_str = ctypes.c_char_p() + + devices = [] + + if cuda.cuInit(0) == 0 and \ + cuda.cuDeviceGetCount(ctypes.byref(nGpus)) == 0: + for i in range(nGpus.value): + if cuda.cuDeviceGet(ctypes.byref(device), i) != 0 or \ + cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) != 0 or \ + cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) != 0: + continue + + if cuda.cuCtxCreate_v2(ctypes.byref(context), 0, device) == 0: + if cuda.cuMemGetInfo_v2(ctypes.byref(freeMem), ctypes.byref(totalMem)) == 0: + cc = cc_major.value * 10 + cc_minor.value + if cc >= min_cc: + devices.append ( {'name' : name.split(b'\0', 1)[0].decode(), + 'total_mem' : totalMem.value, + 'free_mem' : freeMem.value, + 'cc' : cc + }) + cuda.cuCtxDetach(context) + + os.environ['NN_DEVICES_COUNT'] = str(len(devices)) + for i, device in enumerate(devices): + os.environ[f'NN_DEVICE_{i}_NAME'] = device['name'] + os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) + os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) + os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc']) +""" \ No newline at end of file diff --git a/core/leras/initializers/CA.py b/core/leras/initializers/CA.py new file mode 100644 index 0000000000000000000000000000000000000000..d05f698a215b27e06789db26276477bba5fe23d8 --- /dev/null +++ b/core/leras/initializers/CA.py @@ -0,0 +1,82 @@ +import multiprocessing +from core.joblib import Subprocessor +import numpy as np + +class CAInitializerSubprocessor(Subprocessor): + @staticmethod + def generate(shape, dtype=np.float32, eps_std=0.05): + """ + Super fast implementation of Convolution Aware Initialization for 4D shapes + Convolution Aware Initialization https://arxiv.org/abs/1702.06295 + """ + if len(shape) != 4: + raise ValueError("only shape with rank 4 supported.") + + row, column, stack_size, filters_size = shape + + fan_in = stack_size * (row * column) + + kernel_shape = (row, column) + + kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape + + basis_size = np.prod(kernel_fft_shape) + if basis_size == 1: + x = np.random.normal( 0.0, eps_std, (filters_size, stack_size, basis_size) ) + else: + nbb = stack_size // basis_size + 1 + x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size)) + x = x + np.transpose(x, (0,1,3,2) ) * (1-np.eye(basis_size)) + u, _, v = np.linalg.svd(x) + x = np.transpose(u, (0,1,3,2) ) + x = np.reshape(x, (filters_size, -1, basis_size) ) + x = x[:,:stack_size,:] + + x = np.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) ) + + x = np.fft.irfft2( x, kernel_shape ) \ + + np.random.normal(0, eps_std, (filters_size,stack_size,)+kernel_shape) + + x = x * np.sqrt( (2/fan_in) / np.var(x) ) + x = np.transpose( x, (2, 3, 1, 0) ) + return x.astype(dtype) + + class Cli(Subprocessor.Cli): + #override + def process_data(self, data): + idx, shape, dtype = data + weights = CAInitializerSubprocessor.generate (shape, dtype) + return idx, weights + + #override + def __init__(self, data_list): + self.data_list = data_list + self.data_list_idxs = [*range(len(data_list))] + self.result = [None]*len(data_list) + super().__init__('CAInitializerSubprocessor', CAInitializerSubprocessor.Cli) + + #override + def process_info_generator(self): + for i in range( min(multiprocessing.cpu_count(), len(self.data_list)) ): + yield 'CPU%d' % (i), {}, {} + + #override + def get_data(self, host_dict): + if len (self.data_list_idxs) > 0: + idx = self.data_list_idxs.pop(0) + shape, dtype = self.data_list[idx] + return idx, shape, dtype + return None + + #override + def on_data_return (self, host_dict, data): + self.data_list_idxs.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + idx, weights = result + self.result[idx] = weights + + #override + def get_result(self): + return self.result diff --git a/core/leras/initializers/__init__.py b/core/leras/initializers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4211fbe7b50fe6eac641e2b943c4cd781da91e0d --- /dev/null +++ b/core/leras/initializers/__init__.py @@ -0,0 +1,20 @@ +import numpy as np +from tensorflow.python.ops import init_ops + +from core.leras import nn + +tf = nn.tf + +from .CA import CAInitializerSubprocessor + +class initializers(): + class ca (init_ops.Initializer): + def __call__(self, shape, dtype=None, partition_info=None): + return tf.zeros( shape, dtype=dtype, name="_cai_") + + @staticmethod + def generate_batch( data_list, eps_std=0.05 ): + # list of (shape, np.dtype) + return CAInitializerSubprocessor (data_list).run() + +nn.initializers = initializers diff --git a/core/leras/layers/AdaIN.py b/core/leras/layers/AdaIN.py new file mode 100644 index 0000000000000000000000000000000000000000..fd25038ca10efab2b4f31d8e91e73b22bbec8bca --- /dev/null +++ b/core/leras/layers/AdaIN.py @@ -0,0 +1,56 @@ +from core.leras import nn +tf = nn.tf + +class AdaIN(nn.LayerBase): + """ + """ + def __init__(self, in_ch, mlp_ch, kernel_initializer=None, dtype=None, **kwargs): + self.in_ch = in_ch + self.mlp_ch = mlp_ch + self.kernel_initializer = kernel_initializer + + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if kernel_initializer is None: + kernel_initializer = tf.initializers.he_normal() + + self.weight1 = tf.get_variable("weight1", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer) + self.bias1 = tf.get_variable("bias1", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros()) + self.weight2 = tf.get_variable("weight2", (self.mlp_ch, self.in_ch), dtype=self.dtype, initializer=kernel_initializer) + self.bias2 = tf.get_variable("bias2", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros()) + + def get_weights(self): + return [self.weight1, self.bias1, self.weight2, self.bias2] + + def forward(self, inputs): + x, mlp = inputs + + gamma = tf.matmul(mlp, self.weight1) + gamma = tf.add(gamma, tf.reshape(self.bias1, (1,self.in_ch) ) ) + + beta = tf.matmul(mlp, self.weight2) + beta = tf.add(beta, tf.reshape(self.bias2, (1,self.in_ch) ) ) + + + if nn.data_format == "NHWC": + shape = (-1,1,1,self.in_ch) + else: + shape = (-1,self.in_ch,1,1) + + x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5 + + x = (x - x_mean) / x_std + x *= tf.reshape(gamma, shape) + + x += tf.reshape(beta, shape) + + return x + +nn.AdaIN = AdaIN \ No newline at end of file diff --git a/core/leras/layers/BatchNorm2D.py b/core/leras/layers/BatchNorm2D.py new file mode 100644 index 0000000000000000000000000000000000000000..62de521c1a7d7b73e838fbfe7b3fbf38306da479 --- /dev/null +++ b/core/leras/layers/BatchNorm2D.py @@ -0,0 +1,42 @@ +from core.leras import nn +tf = nn.tf + +class BatchNorm2D(nn.LayerBase): + """ + currently not for training + """ + def __init__(self, dim, eps=1e-05, momentum=0.1, dtype=None, **kwargs): + self.dim = dim + self.eps = eps + self.momentum = momentum + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + self.weight = tf.get_variable("weight", (self.dim,), dtype=self.dtype, initializer=tf.initializers.ones() ) + self.bias = tf.get_variable("bias", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + self.running_mean = tf.get_variable("running_mean", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False ) + self.running_var = tf.get_variable("running_var", (self.dim,), dtype=self.dtype, initializer=tf.initializers.zeros(), trainable=False ) + + def get_weights(self): + return [self.weight, self.bias, self.running_mean, self.running_var] + + def forward(self, x): + if nn.data_format == "NHWC": + shape = (1,1,1,self.dim) + else: + shape = (1,self.dim,1,1) + + weight = tf.reshape ( self.weight , shape ) + bias = tf.reshape ( self.bias , shape ) + running_mean = tf.reshape ( self.running_mean, shape ) + running_var = tf.reshape ( self.running_var , shape ) + + x = (x - running_mean) / tf.sqrt( running_var + self.eps ) + x *= weight + x += bias + return x + +nn.BatchNorm2D = BatchNorm2D \ No newline at end of file diff --git a/core/leras/layers/BlurPool.py b/core/leras/layers/BlurPool.py new file mode 100644 index 0000000000000000000000000000000000000000..3a2b26db0c29b691ab28d2c0e03bb173e77499db --- /dev/null +++ b/core/leras/layers/BlurPool.py @@ -0,0 +1,50 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class BlurPool(nn.LayerBase): + def __init__(self, filt_size=3, stride=2, **kwargs ): + + if nn.data_format == "NHWC": + self.strides = [1,stride,stride,1] + else: + self.strides = [1,1,stride,stride] + + self.filt_size = filt_size + pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ] + + if nn.data_format == "NHWC": + self.padding = [ [0,0], pad, pad, [0,0] ] + else: + self.padding = [ [0,0], [0,0], pad, pad ] + + if(self.filt_size==1): + a = np.array([1.,]) + elif(self.filt_size==2): + a = np.array([1., 1.]) + elif(self.filt_size==3): + a = np.array([1., 2., 1.]) + elif(self.filt_size==4): + a = np.array([1., 3., 3., 1.]) + elif(self.filt_size==5): + a = np.array([1., 4., 6., 4., 1.]) + elif(self.filt_size==6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(self.filt_size==7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + a = a[:,None]*a[None,:] + a = a / np.sum(a) + a = a[:,:,None,None] + self.a = a + super().__init__(**kwargs) + + def build_weights(self): + self.k = tf.constant (self.a, dtype=nn.floatx ) + + def forward(self, x): + k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) ) + x = tf.pad(x, self.padding ) + x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format) + return x +nn.BlurPool = BlurPool \ No newline at end of file diff --git a/core/leras/layers/Conv2D.py b/core/leras/layers/Conv2D.py new file mode 100644 index 0000000000000000000000000000000000000000..a5febf0866fa5a5dc7c13ba13ec40a7456b05299 --- /dev/null +++ b/core/leras/layers/Conv2D.py @@ -0,0 +1,114 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class Conv2D(nn.LayerBase): + """ + default kernel_initializer - CA + use_wscale bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal + + + """ + def __init__(self, in_ch, out_ch, kernel_size, strides=1, padding='SAME', dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ): + if not isinstance(strides, int): + raise ValueError ("strides must be an int type") + if not isinstance(dilations, int): + raise ValueError ("dilations must be an int type") + kernel_size = int(kernel_size) + + if dtype is None: + dtype = nn.floatx + + if isinstance(padding, str): + if padding == "SAME": + padding = ( (kernel_size - 1) * dilations + 1 ) // 2 + elif padding == "VALID": + padding = None + else: + raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") + else: + padding = int(padding) + + + + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.dilations = dilations + self.use_bias = use_bias + self.use_wscale = use_wscale + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.trainable = trainable + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if self.use_wscale: + gain = 1.0 if self.kernel_size == 1 else np.sqrt(2) + fan_in = self.kernel_size*self.kernel_size*self.in_ch + he_std = gain / np.sqrt(fan_in) + self.wscale = tf.constant(he_std, dtype=self.dtype ) + if kernel_initializer is None: + kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) + + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() + + self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.out_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) + + if self.use_bias: + bias_initializer = self.bias_initializer + if bias_initializer is None: + bias_initializer = tf.initializers.zeros(dtype=self.dtype) + + self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable ) + + def get_weights(self): + weights = [self.weight] + if self.use_bias: + weights += [self.bias] + return weights + + def forward(self, x): + weight = self.weight + if self.use_wscale: + weight = weight * self.wscale + + padding = self.padding + if padding is not None: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + x = tf.pad (x, padding, mode='CONSTANT') + + strides = self.strides + if nn.data_format == "NHWC": + strides = [1,strides,strides,1] + else: + strides = [1,1,strides,strides] + + dilations = self.dilations + if nn.data_format == "NHWC": + dilations = [1,dilations,dilations,1] + else: + dilations = [1,1,dilations,dilations] + + x = tf.nn.conv2d(x, weight, strides, 'VALID', dilations=dilations, data_format=nn.data_format) + if self.use_bias: + if nn.data_format == "NHWC": + bias = tf.reshape (self.bias, (1,1,1,self.out_ch) ) + else: + bias = tf.reshape (self.bias, (1,self.out_ch,1,1) ) + x = tf.add(x, bias) + return x + + def __str__(self): + r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} " + + return r +nn.Conv2D = Conv2D \ No newline at end of file diff --git a/core/leras/layers/Conv2DTranspose.py b/core/leras/layers/Conv2DTranspose.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e97dc3177ea387122db270a692a26574a57323 --- /dev/null +++ b/core/leras/layers/Conv2DTranspose.py @@ -0,0 +1,107 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class Conv2DTranspose(nn.LayerBase): + """ + use_wscale enables weight scale (equalized learning rate) + if kernel_initializer is None, it will be forced to random_normal + """ + def __init__(self, in_ch, out_ch, kernel_size, strides=2, padding='SAME', use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ): + if not isinstance(strides, int): + raise ValueError ("strides must be an int type") + kernel_size = int(kernel_size) + + if dtype is None: + dtype = nn.floatx + + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.use_bias = use_bias + self.use_wscale = use_wscale + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.trainable = trainable + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if self.use_wscale: + gain = 1.0 if self.kernel_size == 1 else np.sqrt(2) + fan_in = self.kernel_size*self.kernel_size*self.in_ch + he_std = gain / np.sqrt(fan_in) # He init + self.wscale = tf.constant(he_std, dtype=self.dtype ) + if kernel_initializer is None: + kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) + + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() + self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.out_ch,self.in_ch), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) + + if self.use_bias: + bias_initializer = self.bias_initializer + if bias_initializer is None: + bias_initializer = tf.initializers.zeros(dtype=self.dtype) + + self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable ) + + def get_weights(self): + weights = [self.weight] + if self.use_bias: + weights += [self.bias] + return weights + + def forward(self, x): + shape = x.shape + + if nn.data_format == "NHWC": + h,w,c = shape[1], shape[2], shape[3] + output_shape = tf.stack ( (tf.shape(x)[0], + self.deconv_length(w, self.strides, self.kernel_size, self.padding), + self.deconv_length(h, self.strides, self.kernel_size, self.padding), + self.out_ch) ) + + strides = [1,self.strides,self.strides,1] + else: + c,h,w = shape[1], shape[2], shape[3] + output_shape = tf.stack ( (tf.shape(x)[0], + self.out_ch, + self.deconv_length(w, self.strides, self.kernel_size, self.padding), + self.deconv_length(h, self.strides, self.kernel_size, self.padding), + ) ) + strides = [1,1,self.strides,self.strides] + weight = self.weight + if self.use_wscale: + weight = weight * self.wscale + + x = tf.nn.conv2d_transpose(x, weight, output_shape, strides, padding=self.padding, data_format=nn.data_format) + + if self.use_bias: + if nn.data_format == "NHWC": + bias = tf.reshape (self.bias, (1,1,1,self.out_ch) ) + else: + bias = tf.reshape (self.bias, (1,self.out_ch,1,1) ) + x = tf.add(x, bias) + return x + + def __str__(self): + r = f"{self.__class__.__name__} : in_ch:{self.in_ch} out_ch:{self.out_ch} " + + return r + + def deconv_length(self, dim_size, stride_size, kernel_size, padding): + assert padding in {'SAME', 'VALID', 'FULL'} + if dim_size is None: + return None + if padding == 'VALID': + dim_size = dim_size * stride_size + max(kernel_size - stride_size, 0) + elif padding == 'FULL': + dim_size = dim_size * stride_size - (stride_size + kernel_size - 2) + elif padding == 'SAME': + dim_size = dim_size * stride_size + return dim_size +nn.Conv2DTranspose = Conv2DTranspose \ No newline at end of file diff --git a/core/leras/layers/Dense.py b/core/leras/layers/Dense.py new file mode 100644 index 0000000000000000000000000000000000000000..54d3ba62e6d7f2a69d27f1fd02fd21da704a1233 --- /dev/null +++ b/core/leras/layers/Dense.py @@ -0,0 +1,76 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class Dense(nn.LayerBase): + def __init__(self, in_ch, out_ch, use_bias=True, use_wscale=False, maxout_ch=0, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ): + """ + use_wscale enables weight scale (equalized learning rate) + if kernel_initializer is None, it will be forced to random_normal + + maxout_ch https://link.springer.com/article/10.1186/s40537-019-0233-0 + typical 2-4 if you want to enable DenseMaxout behaviour + """ + self.in_ch = in_ch + self.out_ch = out_ch + self.use_bias = use_bias + self.use_wscale = use_wscale + self.maxout_ch = maxout_ch + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.trainable = trainable + if dtype is None: + dtype = nn.floatx + + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + if self.maxout_ch > 1: + weight_shape = (self.in_ch,self.out_ch*self.maxout_ch) + else: + weight_shape = (self.in_ch,self.out_ch) + + kernel_initializer = self.kernel_initializer + + if self.use_wscale: + gain = 1.0 + fan_in = np.prod( weight_shape[:-1] ) + he_std = gain / np.sqrt(fan_in) # He init + self.wscale = tf.constant(he_std, dtype=self.dtype ) + if kernel_initializer is None: + kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) + + if kernel_initializer is None: + kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype) + + self.weight = tf.get_variable("weight", weight_shape, dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) + + if self.use_bias: + bias_initializer = self.bias_initializer + if bias_initializer is None: + bias_initializer = tf.initializers.zeros(dtype=self.dtype) + self.bias = tf.get_variable("bias", (self.out_ch,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable ) + + def get_weights(self): + weights = [self.weight] + if self.use_bias: + weights += [self.bias] + return weights + + def forward(self, x): + weight = self.weight + if self.use_wscale: + weight = weight * self.wscale + + x = tf.matmul(x, weight) + + if self.maxout_ch > 1: + x = tf.reshape (x, (-1, self.out_ch, self.maxout_ch) ) + x = tf.reduce_max(x, axis=-1) + + if self.use_bias: + x = tf.add(x, tf.reshape(self.bias, (1,self.out_ch) ) ) + + return x +nn.Dense = Dense \ No newline at end of file diff --git a/core/leras/layers/DenseNorm.py b/core/leras/layers/DenseNorm.py new file mode 100644 index 0000000000000000000000000000000000000000..594bf575d52ed88c4585ef7a3118b404fa5f9041 --- /dev/null +++ b/core/leras/layers/DenseNorm.py @@ -0,0 +1,16 @@ +from core.leras import nn +tf = nn.tf + +class DenseNorm(nn.LayerBase): + def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs): + self.dense = dense + if dtype is None: + dtype = nn.floatx + self.eps = tf.constant(eps, dtype=dtype, name="epsilon") + + super().__init__(**kwargs) + + def __call__(self, x): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps) + +nn.DenseNorm = DenseNorm \ No newline at end of file diff --git a/core/leras/layers/DepthwiseConv2D.py b/core/leras/layers/DepthwiseConv2D.py new file mode 100644 index 0000000000000000000000000000000000000000..2916f01d4178c6c5ae922c070c07d3da41ec11e0 --- /dev/null +++ b/core/leras/layers/DepthwiseConv2D.py @@ -0,0 +1,110 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class DepthwiseConv2D(nn.LayerBase): + """ + default kernel_initializer - CA + use_wscale bool enables equalized learning rate, if kernel_initializer is None, it will be forced to random_normal + """ + def __init__(self, in_ch, kernel_size, strides=1, padding='SAME', depth_multiplier=1, dilations=1, use_bias=True, use_wscale=False, kernel_initializer=None, bias_initializer=None, trainable=True, dtype=None, **kwargs ): + if not isinstance(strides, int): + raise ValueError ("strides must be an int type") + if not isinstance(dilations, int): + raise ValueError ("dilations must be an int type") + kernel_size = int(kernel_size) + + if dtype is None: + dtype = nn.floatx + + if isinstance(padding, str): + if padding == "SAME": + padding = ( (kernel_size - 1) * dilations + 1 ) // 2 + elif padding == "VALID": + padding = 0 + else: + raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs") + + if isinstance(padding, int): + if padding != 0: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + else: + padding = None + + if nn.data_format == "NHWC": + strides = [1,strides,strides,1] + else: + strides = [1,1,strides,strides] + + if nn.data_format == "NHWC": + dilations = [1,dilations,dilations,1] + else: + dilations = [1,1,dilations,dilations] + + self.in_ch = in_ch + self.depth_multiplier = depth_multiplier + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.dilations = dilations + self.use_bias = use_bias + self.use_wscale = use_wscale + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.trainable = trainable + self.dtype = dtype + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = self.kernel_initializer + if self.use_wscale: + gain = 1.0 if self.kernel_size == 1 else np.sqrt(2) + fan_in = self.kernel_size*self.kernel_size*self.in_ch + he_std = gain / np.sqrt(fan_in) + self.wscale = tf.constant(he_std, dtype=self.dtype ) + if kernel_initializer is None: + kernel_initializer = tf.initializers.random_normal(0, 1.0, dtype=self.dtype) + + #if kernel_initializer is None: + # kernel_initializer = nn.initializers.ca() + + self.weight = tf.get_variable("weight", (self.kernel_size,self.kernel_size,self.in_ch,self.depth_multiplier), dtype=self.dtype, initializer=kernel_initializer, trainable=self.trainable ) + + if self.use_bias: + bias_initializer = self.bias_initializer + if bias_initializer is None: + bias_initializer = tf.initializers.zeros(dtype=self.dtype) + + self.bias = tf.get_variable("bias", (self.in_ch*self.depth_multiplier,), dtype=self.dtype, initializer=bias_initializer, trainable=self.trainable ) + + def get_weights(self): + weights = [self.weight] + if self.use_bias: + weights += [self.bias] + return weights + + def forward(self, x): + weight = self.weight + if self.use_wscale: + weight = weight * self.wscale + + if self.padding is not None: + x = tf.pad (x, self.padding, mode='CONSTANT') + + x = tf.nn.depthwise_conv2d(x, weight, self.strides, 'VALID', data_format=nn.data_format) + if self.use_bias: + if nn.data_format == "NHWC": + bias = tf.reshape (self.bias, (1,1,1,self.in_ch*self.depth_multiplier) ) + else: + bias = tf.reshape (self.bias, (1,self.in_ch*self.depth_multiplier,1,1) ) + x = tf.add(x, bias) + return x + + def __str__(self): + r = f"{self.__class__.__name__} : in_ch:{self.in_ch} depth_multiplier:{self.depth_multiplier} " + return r + +nn.DepthwiseConv2D = DepthwiseConv2D \ No newline at end of file diff --git a/core/leras/layers/FRNorm2D.py b/core/leras/layers/FRNorm2D.py new file mode 100644 index 0000000000000000000000000000000000000000..80f05972e804996fd9739bd59c05c208ee62aab2 --- /dev/null +++ b/core/leras/layers/FRNorm2D.py @@ -0,0 +1,38 @@ +from core.leras import nn +tf = nn.tf + +class FRNorm2D(nn.LayerBase): + """ + Tensorflow implementation of + Filter Response Normalization Layer: Eliminating Batch Dependence in theTraining of Deep Neural Networks + https://arxiv.org/pdf/1911.09737.pdf + """ + def __init__(self, in_ch, dtype=None, **kwargs): + self.in_ch = in_ch + + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.ones() ) + self.bias = tf.get_variable("bias", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + self.eps = tf.get_variable("eps", (1,), dtype=self.dtype, initializer=tf.initializers.constant(1e-6) ) + + def get_weights(self): + return [self.weight, self.bias, self.eps] + + def forward(self, x): + if nn.data_format == "NHWC": + shape = (1,1,1,self.in_ch) + else: + shape = (1,self.in_ch,1,1) + weight = tf.reshape ( self.weight, shape ) + bias = tf.reshape ( self.bias , shape ) + nu2 = tf.reduce_mean(tf.square(x), axis=nn.conv2d_spatial_axes, keepdims=True) + x = x * ( 1.0/tf.sqrt(nu2 + tf.abs(self.eps) ) ) + + return x*weight + bias +nn.FRNorm2D = FRNorm2D \ No newline at end of file diff --git a/core/leras/layers/InstanceNorm2D.py b/core/leras/layers/InstanceNorm2D.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1a43fe3fb9243b154d9430d7a35ba2249d17b9 --- /dev/null +++ b/core/leras/layers/InstanceNorm2D.py @@ -0,0 +1,40 @@ +from core.leras import nn +tf = nn.tf + +class InstanceNorm2D(nn.LayerBase): + def __init__(self, in_ch, dtype=None, **kwargs): + self.in_ch = in_ch + + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + kernel_initializer = tf.initializers.glorot_uniform(dtype=self.dtype) + self.weight = tf.get_variable("weight", (self.in_ch,), dtype=self.dtype, initializer=kernel_initializer ) + self.bias = tf.get_variable("bias", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + + def get_weights(self): + return [self.weight, self.bias] + + def forward(self, x): + if nn.data_format == "NHWC": + shape = (1,1,1,self.in_ch) + else: + shape = (1,self.in_ch,1,1) + + weight = tf.reshape ( self.weight , shape ) + bias = tf.reshape ( self.bias , shape ) + + x_mean = tf.reduce_mean(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + x_std = tf.math.reduce_std(x, axis=nn.conv2d_spatial_axes, keepdims=True ) + 1e-5 + + x = (x - x_mean) / x_std + x *= weight + x += bias + + return x + +nn.InstanceNorm2D = InstanceNorm2D \ No newline at end of file diff --git a/core/leras/layers/LayerBase.py b/core/leras/layers/LayerBase.py new file mode 100644 index 0000000000000000000000000000000000000000..a71a11186f3f7ef64a504efd44071314b4657743 --- /dev/null +++ b/core/leras/layers/LayerBase.py @@ -0,0 +1,16 @@ +from core.leras import nn +tf = nn.tf + +class LayerBase(nn.Saveable): + #override + def build_weights(self): + pass + + #override + def forward(self, *args, **kwargs): + pass + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + +nn.LayerBase = LayerBase \ No newline at end of file diff --git a/core/leras/layers/Saveable.py b/core/leras/layers/Saveable.py new file mode 100644 index 0000000000000000000000000000000000000000..e72f5942ee76f1ccff86268866f155543ca13cfe --- /dev/null +++ b/core/leras/layers/Saveable.py @@ -0,0 +1,108 @@ +import pickle +from pathlib import Path +from core import pathex +import numpy as np + +from core.leras import nn + +tf = nn.tf + +class Saveable(): + def __init__(self, name=None): + self.name = name + + #override + def get_weights(self): + #return tf tensors that should be initialized/loaded/saved + return [] + + #override + def get_weights_np(self): + weights = self.get_weights() + if len(weights) == 0: + return [] + return nn.tf_sess.run (weights) + + def set_weights(self, new_weights): + weights = self.get_weights() + if len(weights) != len(new_weights): + raise ValueError ('len of lists mismatch') + + tuples = [] + for w, new_w in zip(weights, new_weights): + + if len(w.shape) != new_w.shape: + new_w = new_w.reshape(w.shape) + + tuples.append ( (w, new_w) ) + + nn.batch_set_value (tuples) + + def save_weights(self, filename, force_dtype=None): + d = {} + weights = self.get_weights() + + if self.name is None: + raise Exception("name must be defined.") + + name = self.name + + for w in weights: + w_val = nn.tf_sess.run (w).copy() + w_name_split = w.name.split('/', 1) + if name != w_name_split[0]: + raise Exception("weight first name != Saveable.name") + + if force_dtype is not None: + w_val = w_val.astype(force_dtype) + + d[ w_name_split[1] ] = w_val + + d_dumped = pickle.dumps (d, 4) + pathex.write_bytes_safe ( Path(filename), d_dumped ) + + def load_weights(self, filename): + """ + returns True if file exists + """ + filepath = Path(filename) + if filepath.exists(): + result = True + d_dumped = filepath.read_bytes() + d = pickle.loads(d_dumped) + else: + return False + + weights = self.get_weights() + + if self.name is None: + raise Exception("name must be defined.") + + try: + tuples = [] + for w in weights: + w_name_split = w.name.split('/') + if self.name != w_name_split[0]: + raise Exception("weight first name != Saveable.name") + + sub_w_name = "/".join(w_name_split[1:]) + + w_val = d.get(sub_w_name, None) + + if w_val is None: + #io.log_err(f"Weight {w.name} was not loaded from file {filename}") + tuples.append ( (w, w.initializer) ) + else: + w_val = np.reshape( w_val, w.shape.as_list() ) + tuples.append ( (w, w_val) ) + + nn.batch_set_value(tuples) + except: + return False + + return True + + def init_weights(self): + nn.init_weights(self.get_weights()) + +nn.Saveable = Saveable diff --git a/core/leras/layers/ScaleAdd.py b/core/leras/layers/ScaleAdd.py new file mode 100644 index 0000000000000000000000000000000000000000..06188b876fb08bc45ee1490d9646efddc52a9e8d --- /dev/null +++ b/core/leras/layers/ScaleAdd.py @@ -0,0 +1,31 @@ +from core.leras import nn +tf = nn.tf + +class ScaleAdd(nn.LayerBase): + def __init__(self, ch, dtype=None, **kwargs): + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + self.ch = ch + + super().__init__(**kwargs) + + def build_weights(self): + self.weight = tf.get_variable("weight",(self.ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + + def get_weights(self): + return [self.weight] + + def forward(self, inputs): + if nn.data_format == "NHWC": + shape = (1,1,1,self.ch) + else: + shape = (1,self.ch,1,1) + + weight = tf.reshape ( self.weight, shape ) + + x0, x1 = inputs + x = x0 + x1*weight + + return x +nn.ScaleAdd = ScaleAdd \ No newline at end of file diff --git a/core/leras/layers/TLU.py b/core/leras/layers/TLU.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4c25488dcb3f212affe163ac8dbe05b8a4add9 --- /dev/null +++ b/core/leras/layers/TLU.py @@ -0,0 +1,33 @@ +from core.leras import nn +tf = nn.tf + +class TLU(nn.LayerBase): + """ + Tensorflow implementation of + Filter Response Normalization Layer: Eliminating Batch Dependence in theTraining of Deep Neural Networks + https://arxiv.org/pdf/1911.09737.pdf + """ + def __init__(self, in_ch, dtype=None, **kwargs): + self.in_ch = in_ch + + if dtype is None: + dtype = nn.floatx + self.dtype = dtype + + super().__init__(**kwargs) + + def build_weights(self): + self.tau = tf.get_variable("tau", (self.in_ch,), dtype=self.dtype, initializer=tf.initializers.zeros() ) + + def get_weights(self): + return [self.tau] + + def forward(self, x): + if nn.data_format == "NHWC": + shape = (1,1,1,self.in_ch) + else: + shape = (1,self.in_ch,1,1) + + tau = tf.reshape ( self.tau, shape ) + return tf.math.maximum(x, tau) +nn.TLU = TLU \ No newline at end of file diff --git a/core/leras/layers/TanhPolar.py b/core/leras/layers/TanhPolar.py new file mode 100644 index 0000000000000000000000000000000000000000..8955f32b7d19fd29f28983f7623532093682eb23 --- /dev/null +++ b/core/leras/layers/TanhPolar.py @@ -0,0 +1,104 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +class TanhPolar(nn.LayerBase): + """ + RoI Tanh-polar Transformer Network for Face Parsing in the Wild + https://github.com/hhj1897/roi_tanh_warping + """ + + def __init__(self, width, height, angular_offset_deg=270, **kwargs): + self.width = width + self.height = height + + warp_gridx, warp_gridy = TanhPolar._get_tanh_polar_warp_grids(width,height,angular_offset_deg=angular_offset_deg) + restore_gridx, restore_gridy = TanhPolar._get_tanh_polar_restore_grids(width,height,angular_offset_deg=angular_offset_deg) + + self.warp_gridx_t = tf.constant(warp_gridx[None, ...]) + self.warp_gridy_t = tf.constant(warp_gridy[None, ...]) + self.restore_gridx_t = tf.constant(restore_gridx[None, ...]) + self.restore_gridy_t = tf.constant(restore_gridy[None, ...]) + + super().__init__(**kwargs) + + def warp(self, inp_t): + batch_t = tf.shape(inp_t)[0] + warp_gridx_t = tf.tile(self.warp_gridx_t, (batch_t,1,1) ) + warp_gridy_t = tf.tile(self.warp_gridy_t, (batch_t,1,1) ) + + if nn.data_format == "NCHW": + inp_t = tf.transpose(inp_t,(0,2,3,1)) + + out_t = nn.bilinear_sampler(inp_t, warp_gridx_t, warp_gridy_t) + + if nn.data_format == "NCHW": + out_t = tf.transpose(out_t,(0,3,1,2)) + + return out_t + + def restore(self, inp_t): + batch_t = tf.shape(inp_t)[0] + restore_gridx_t = tf.tile(self.restore_gridx_t, (batch_t,1,1) ) + restore_gridy_t = tf.tile(self.restore_gridy_t, (batch_t,1,1) ) + + if nn.data_format == "NCHW": + inp_t = tf.transpose(inp_t,(0,2,3,1)) + + inp_t = tf.pad(inp_t, [(0,0), (1, 1), (1, 0), (0, 0)], "SYMMETRIC") + + out_t = nn.bilinear_sampler(inp_t, restore_gridx_t, restore_gridy_t) + + if nn.data_format == "NCHW": + out_t = tf.transpose(out_t,(0,3,1,2)) + + return out_t + + @staticmethod + def _get_tanh_polar_warp_grids(W,H,angular_offset_deg): + angular_offset_pi = angular_offset_deg * np.pi / 180.0 + + roi_center = np.array([ W//2, H//2], np.float32 ) + roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5 + cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi) + normalised_dest_indices = np.stack(np.meshgrid(np.arange(0.0, 1.0, 1.0 / W),np.arange(0.0, 2.0 * np.pi, 2.0 * np.pi / H)), axis=-1) + radii = normalised_dest_indices[..., 0] + orientation_x = np.cos(normalised_dest_indices[..., 1]) + orientation_y = np.sin(normalised_dest_indices[..., 1]) + + src_radii = np.arctanh(radii) * (roi_radii[0] * roi_radii[1] / np.sqrt(roi_radii[1] ** 2 * orientation_x ** 2 + roi_radii[0] ** 2 * orientation_y ** 2)) + src_x_indices = src_radii * orientation_x + src_y_indices = src_radii * orientation_y + src_x_indices, src_y_indices = (roi_center[0] + cos_offset * src_x_indices - sin_offset * src_y_indices, + roi_center[1] + cos_offset * src_y_indices + sin_offset * src_x_indices) + + return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32) + + @staticmethod + def _get_tanh_polar_restore_grids(W,H,angular_offset_deg): + angular_offset_pi = angular_offset_deg * np.pi / 180.0 + + roi_center = np.array([ W//2, H//2], np.float32 ) + roi_radii = np.array([W, H], np.float32 ) / np.pi ** 0.5 + cos_offset, sin_offset = np.cos(angular_offset_pi), np.sin(angular_offset_pi) + + dest_indices = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=-1).astype(float) + normalised_dest_indices = np.matmul(dest_indices - roi_center, np.array([[cos_offset, -sin_offset], + [sin_offset, cos_offset]])) + radii = np.linalg.norm(normalised_dest_indices, axis=-1) + normalised_dest_indices[..., 0] /= np.clip(radii, 1e-9, None) + normalised_dest_indices[..., 1] /= np.clip(radii, 1e-9, None) + radii *= np.sqrt(roi_radii[1] ** 2 * normalised_dest_indices[..., 0] ** 2 + + roi_radii[0] ** 2 * normalised_dest_indices[..., 1] ** 2) / roi_radii[0] / roi_radii[1] + + src_radii = np.tanh(radii) + + + src_x_indices = src_radii * W + 1.0 + src_y_indices = np.mod((np.arctan2(normalised_dest_indices[..., 1], normalised_dest_indices[..., 0]) / + 2.0 / np.pi) * H, H) + 1.0 + + return src_x_indices.astype(np.float32), src_y_indices.astype(np.float32) + + +nn.TanhPolar = TanhPolar \ No newline at end of file diff --git a/core/leras/layers/__init__.py b/core/leras/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4accaf640041e197c830811faf4bbdaa30c8df89 --- /dev/null +++ b/core/leras/layers/__init__.py @@ -0,0 +1,18 @@ +from .Saveable import * +from .LayerBase import * + +from .Conv2D import * +from .Conv2DTranspose import * +from .DepthwiseConv2D import * +from .Dense import * +from .BlurPool import * + +from .BatchNorm2D import * +from .InstanceNorm2D import * +from .FRNorm2D import * + +from .TLU import * +from .ScaleAdd import * +from .DenseNorm import * +from .AdaIN import * +from .TanhPolar import * \ No newline at end of file diff --git a/core/leras/models/CodeDiscriminator.py b/core/leras/models/CodeDiscriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..770cd2e656a601ec657ad8ef943a34c0c14fd796 --- /dev/null +++ b/core/leras/models/CodeDiscriminator.py @@ -0,0 +1,22 @@ +from core.leras import nn +tf = nn.tf + +class CodeDiscriminator(nn.ModelBase): + def on_build(self, in_ch, code_res, ch=256, conv_kernel_initializer=None): + n_downscales = 1 + code_res // 8 + + self.convs = [] + prev_ch = in_ch + for i in range(n_downscales): + cur_ch = ch * min( (2**i), 8 ) + self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=4 if i == 0 else 3, strides=2, padding='SAME', kernel_initializer=conv_kernel_initializer) ) + prev_ch = cur_ch + + self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer) + + def forward(self, x): + for conv in self.convs: + x = tf.nn.leaky_relu( conv(x), 0.1 ) + return self.out_conv(x) + +nn.CodeDiscriminator = CodeDiscriminator \ No newline at end of file diff --git a/core/leras/models/ModelBase.py b/core/leras/models/ModelBase.py new file mode 100644 index 0000000000000000000000000000000000000000..cc558a4bcb6066e6c21335b4116aaa6edb8db272 --- /dev/null +++ b/core/leras/models/ModelBase.py @@ -0,0 +1,244 @@ +import types +import numpy as np +from core.interact import interact as io +from core.leras import nn +tf = nn.tf + +class ModelBase(nn.Saveable): + def __init__(self, *args, name=None, **kwargs): + super().__init__(name=name) + self.layers = [] + self.layers_by_name = {} + self.built = False + self.args = args + self.kwargs = kwargs + self.run_placeholders = None + + def _build_sub(self, layer, name): + if isinstance (layer, list): + for i,sublayer in enumerate(layer): + self._build_sub(sublayer, f"{name}_{i}") + elif isinstance (layer, dict): + for subname in layer.keys(): + sublayer = layer[subname] + self._build_sub(sublayer, f"{name}_{subname}") + elif isinstance (layer, nn.LayerBase) or \ + isinstance (layer, ModelBase): + + if layer.name is None: + layer.name = name + + if isinstance (layer, nn.LayerBase): + with tf.variable_scope(layer.name): + layer.build_weights() + elif isinstance (layer, ModelBase): + layer.build() + + self.layers.append (layer) + self.layers_by_name[layer.name] = layer + + def xor_list(self, lst1, lst2): + return [value for value in lst1+lst2 if (value not in lst1) or (value not in lst2) ] + + def build(self): + with tf.variable_scope(self.name): + + current_vars = [] + generator = None + while True: + + if generator is None: + generator = self.on_build(*self.args, **self.kwargs) + if not isinstance(generator, types.GeneratorType): + generator = None + + if generator is not None: + try: + next(generator) + except StopIteration: + generator = None + + v = vars(self) + new_vars = self.xor_list (current_vars, list(v.keys()) ) + + for name in new_vars: + self._build_sub(v[name],name) + + current_vars += new_vars + + if generator is None: + break + + self.built = True + + #override + def get_weights(self): + if not self.built: + self.build() + + weights = [] + for layer in self.layers: + weights += layer.get_weights() + return weights + + def get_layer_by_name(self, name): + return self.layers_by_name.get(name, None) + + def get_layers(self): + if not self.built: + self.build() + layers = [] + for layer in self.layers: + if isinstance (layer, nn.LayerBase): + layers.append(layer) + else: + layers += layer.get_layers() + return layers + + #override + def on_build(self, *args, **kwargs): + """ + init model layers here + + return 'yield' if build is not finished + therefore dependency models will be initialized + """ + pass + + #override + def forward(self, *args, **kwargs): + #flow layers/models/tensors here + pass + + def __call__(self, *args, **kwargs): + if not self.built: + self.build() + + return self.forward(*args, **kwargs) + + # def compute_output_shape(self, shapes): + # if not self.built: + # self.build() + + # not_list = False + # if not isinstance(shapes, list): + # not_list = True + # shapes = [shapes] + + # with tf.device('/CPU:0'): + # # CPU tensors will not impact any performance, only slightly RAM "leakage" + # phs = [] + # for dtype,sh in shapes: + # phs += [ tf.placeholder(dtype, sh) ] + + # result = self.__call__(phs[0] if not_list else phs) + + # if not isinstance(result, list): + # result = [result] + + # result_shapes = [] + + # for t in result: + # result_shapes += [ t.shape.as_list() ] + + # return result_shapes[0] if not_list else result_shapes + + def build_for_run(self, shapes_list): + if not isinstance(shapes_list, list): + raise ValueError("shapes_list must be a list.") + + self.run_placeholders = [] + for dtype,sh in shapes_list: + self.run_placeholders.append ( tf.placeholder(dtype, sh) ) + + self.run_output = self.__call__(self.run_placeholders) + + def run (self, inputs): + if self.run_placeholders is None: + raise Exception ("Model didn't build for run.") + + if len(inputs) != len(self.run_placeholders): + raise ValueError("len(inputs) != self.run_placeholders") + + feed_dict = {} + for ph, inp in zip(self.run_placeholders, inputs): + feed_dict[ph] = inp + + return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict) + + def summary(self): + layers = self.get_layers() + layers_names = [] + layers_params = [] + + max_len_str = 0 + max_len_param_str = 0 + delim_str = "-" + + total_params = 0 + + #Get layers names and str lenght for delim + for l in layers: + if len(str(l))>max_len_str: + max_len_str = len(str(l)) + layers_names+=[str(l).capitalize()] + + #Get params for each layer + layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ] + total_params = np.sum(layers_params) + + #Get str lenght for delim + for p in layers_params: + if len(str(p))>max_len_param_str: + max_len_param_str=len(str(p)) + + #Set delim + for i in range(max_len_str+max_len_param_str+3): + delim_str += "-" + + output = "\n"+delim_str+"\n" + + #Format model name str + model_name_str = "| "+self.name.capitalize() + len_model_name_str = len(model_name_str) + for i in range(len(delim_str)-len_model_name_str): + model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |" + + output += model_name_str +"\n" + output += delim_str +"\n" + + + #Format layers table + for i in range(len(layers_names)): + output += delim_str +"\n" + + l_name = layers_names[i] + l_param = str(layers_params[i]) + l_param_str = "" + if len(l_name)<=max_len_str: + for i in range(max_len_str - len(l_name)): + l_name+= " " + + if len(l_param)<=max_len_param_str: + for i in range(max_len_param_str - len(l_param)): + l_param_str+= " " + + l_param_str += l_param + + + output +="| "+l_name+"|"+l_param_str+"| \n" + + output += delim_str +"\n" + + #Format sum of params + total_params_str = "| Total params count: "+str(total_params) + len_total_params_str = len(total_params_str) + for i in range(len(delim_str)-len_total_params_str): + total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |" + + output += total_params_str +"\n" + output += delim_str +"\n" + + io.log_info(output) + +nn.ModelBase = ModelBase diff --git a/core/leras/models/PatchDiscriminator.py b/core/leras/models/PatchDiscriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..9b94e9f342af0372beed57dbda5f6aeb794817c0 --- /dev/null +++ b/core/leras/models/PatchDiscriminator.py @@ -0,0 +1,194 @@ +import numpy as np +from core.leras import nn +tf = nn.tf + +patch_discriminator_kernels = \ + { 1 : (512, [ [1,1] ]), + 2 : (512, [ [2,1] ]), + 3 : (512, [ [2,1], [2,1] ]), + 4 : (512, [ [2,2], [2,2] ]), + 5 : (512, [ [3,2], [2,2] ]), + 6 : (512, [ [4,2], [2,2] ]), + 7 : (512, [ [3,2], [3,2] ]), + 8 : (512, [ [4,2], [3,2] ]), + 9 : (512, [ [3,2], [4,2] ]), + 10 : (512, [ [4,2], [4,2] ]), + 11 : (512, [ [3,2], [3,2], [2,1] ]), + 12 : (512, [ [4,2], [3,2], [2,1] ]), + 13 : (512, [ [3,2], [4,2], [2,1] ]), + 14 : (512, [ [4,2], [4,2], [2,1] ]), + 15 : (512, [ [3,2], [3,2], [3,1] ]), + 16 : (512, [ [4,2], [3,2], [3,1] ]), + 17 : (512, [ [3,2], [4,2], [3,1] ]), + 18 : (512, [ [4,2], [4,2], [3,1] ]), + 19 : (512, [ [3,2], [3,2], [4,1] ]), + 20 : (512, [ [4,2], [3,2], [4,1] ]), + 21 : (512, [ [3,2], [4,2], [4,1] ]), + 22 : (512, [ [4,2], [4,2], [4,1] ]), + 23 : (256, [ [3,2], [3,2], [3,2], [2,1] ]), + 24 : (256, [ [4,2], [3,2], [3,2], [2,1] ]), + 25 : (256, [ [3,2], [4,2], [3,2], [2,1] ]), + 26 : (256, [ [4,2], [4,2], [3,2], [2,1] ]), + 27 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), + 28 : (256, [ [4,2], [3,2], [4,2], [2,1] ]), + 29 : (256, [ [3,2], [4,2], [4,2], [2,1] ]), + 30 : (256, [ [4,2], [4,2], [4,2], [2,1] ]), + 31 : (256, [ [3,2], [3,2], [3,2], [3,1] ]), + 32 : (256, [ [4,2], [3,2], [3,2], [3,1] ]), + 33 : (256, [ [3,2], [4,2], [3,2], [3,1] ]), + 34 : (256, [ [4,2], [4,2], [3,2], [3,1] ]), + 35 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), + 36 : (256, [ [4,2], [3,2], [4,2], [3,1] ]), + 37 : (256, [ [3,2], [4,2], [4,2], [3,1] ]), + 38 : (256, [ [4,2], [4,2], [4,2], [3,1] ]), + 39 : (256, [ [3,2], [3,2], [3,2], [4,1] ]), + 40 : (256, [ [4,2], [3,2], [3,2], [4,1] ]), + 41 : (256, [ [3,2], [4,2], [3,2], [4,1] ]), + 42 : (256, [ [4,2], [4,2], [3,2], [4,1] ]), + 43 : (256, [ [3,2], [4,2], [4,2], [4,1] ]), + 44 : (256, [ [4,2], [3,2], [4,2], [4,1] ]), + 45 : (256, [ [3,2], [4,2], [4,2], [4,1] ]), + 46 : (256, [ [4,2], [4,2], [4,2], [4,1] ]), + } + + +class PatchDiscriminator(nn.ModelBase): + def on_build(self, patch_size, in_ch, base_ch=None, conv_kernel_initializer=None): + suggested_base_ch, kernels_strides = patch_discriminator_kernels[patch_size] + + if base_ch is None: + base_ch = suggested_base_ch + + prev_ch = in_ch + self.convs = [] + for i, (kernel_size, strides) in enumerate(kernels_strides): + cur_ch = base_ch * min( (2**i), 8 ) + + self.convs.append ( nn.Conv2D( prev_ch, cur_ch, kernel_size=kernel_size, strides=strides, padding='SAME', kernel_initializer=conv_kernel_initializer) ) + prev_ch = cur_ch + + self.out_conv = nn.Conv2D( prev_ch, 1, kernel_size=1, padding='VALID', kernel_initializer=conv_kernel_initializer) + + def forward(self, x): + for conv in self.convs: + x = tf.nn.leaky_relu( conv(x), 0.1 ) + return self.out_conv(x) + +nn.PatchDiscriminator = PatchDiscriminator + +class UNetPatchDiscriminator(nn.ModelBase): + """ + Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks" + """ + def calc_receptive_field_size(self, layers): + """ + result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html + """ + rf = 0 + ts = 1 + for i, (k, s) in enumerate(layers): + if i == 0: + rf = k + else: + rf += (k-1)*ts + ts *= s + return rf + + def find_archi(self, target_patch_size, max_layers=9): + """ + Find the best configuration of layers using only 3x3 convs for target patch size + """ + s = {} + for layers_count in range(1,max_layers+1): + val = 1 << (layers_count-1) + while True: + val -= 1 + + layers = [] + sum_st = 0 + layers.append ( [3, 2]) + sum_st += 2 + for i in range(layers_count-1): + st = 1 + (1 if val & (1 << i) !=0 else 0 ) + layers.append ( [3, st ]) + sum_st += st + + rf = self.calc_receptive_field_size(layers) + + s_rf = s.get(rf, None) + if s_rf is None: + s[rf] = (layers_count, sum_st, layers) + else: + if layers_count < s_rf[0] or \ + ( layers_count == s_rf[0] and sum_st > s_rf[1] ): + s[rf] = (layers_count, sum_st, layers) + + if val == 0: + break + + x = sorted(list(s.keys())) + q=x[np.abs(np.array(x)-target_patch_size).argmin()] + return s[q][2] + + def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False): + self.use_fp16 = use_fp16 + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp + x, 0.2) + return x + + prev_ch = in_ch + self.convs = [] + self.upconvs = [] + layers = self.find_archi(patch_size) + + level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) } + + self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + for i, (kernel_size, strides) in enumerate(layers): + self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) ) + + self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + + self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype) + self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype) + + + def forward(self, x): + if self.use_fp16: + x = tf.cast(x, tf.float16) + + x = tf.nn.leaky_relu( self.in_conv(x), 0.2 ) + + encs = [] + for conv in self.convs: + encs.insert(0, x) + x = tf.nn.leaky_relu( conv(x), 0.2 ) + + center_out, x = self.center_out(x), tf.nn.leaky_relu( self.center_conv(x), 0.2 ) + + for i, (upconv, enc) in enumerate(zip(self.upconvs, encs)): + x = tf.nn.leaky_relu( upconv(x), 0.2 ) + x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis) + + x = self.out_conv(x) + + if self.use_fp16: + center_out = tf.cast(center_out, tf.float32) + x = tf.cast(x, tf.float32) + + return center_out, x + +nn.UNetPatchDiscriminator = UNetPatchDiscriminator diff --git "a/core/leras/models/XSeg \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" "b/core/leras/models/XSeg \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" new file mode 100644 index 0000000000000000000000000000000000000000..95e942c5d847132bfafed0745d1ef5c47b95c3e4 --- /dev/null +++ "b/core/leras/models/XSeg \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" @@ -0,0 +1,166 @@ +from core.leras import nn +tf = nn.tf + +class XSeg(nn.ModelBase): + + def on_build (self, in_ch, base_ch, out_ch): + + class ConvBlock(nn.ModelBase): + def on_build(self, in_ch, out_ch): + self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME') + self.frn = nn.FRNorm2D(out_ch) + self.tlu = nn.TLU(out_ch) + + def forward(self, x): + x = self.conv(x) + x = self.frn(x) + x = self.tlu(x) + return x + + class UpConvBlock(nn.ModelBase): + def on_build(self, in_ch, out_ch): + self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME') + self.frn = nn.FRNorm2D(out_ch) + self.tlu = nn.TLU(out_ch) + + def forward(self, x): + x = self.conv(x) + x = self.frn(x) + x = self.tlu(x) + return x + + self.conv01 = ConvBlock(in_ch, base_ch) + self.conv02 = ConvBlock(base_ch, base_ch) + self.bp0 = nn.BlurPool (filt_size=3) + + + self.conv11 = ConvBlock(base_ch, base_ch*2) + self.conv12 = ConvBlock(base_ch*2, base_ch*2) + self.bp1 = nn.BlurPool (filt_size=3) + + self.conv21 = ConvBlock(base_ch*2, base_ch*4) + self.conv22 = ConvBlock(base_ch*4, base_ch*4) + self.conv23 = ConvBlock(base_ch*4, base_ch*4) + self.bp2 = nn.BlurPool (filt_size=3) + + + self.conv31 = ConvBlock(base_ch*4, base_ch*8) + self.conv32 = ConvBlock(base_ch*8, base_ch*8) + self.conv33 = ConvBlock(base_ch*8, base_ch*8) + self.bp3 = nn.BlurPool (filt_size=3) + + self.conv41 = ConvBlock(base_ch*8, base_ch*8) + self.conv42 = ConvBlock(base_ch*8, base_ch*8) + self.conv43 = ConvBlock(base_ch*8, base_ch*8) + self.bp4 = nn.BlurPool (filt_size=3) + + self.up4 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv43 = ConvBlock(base_ch*12, base_ch*8) + self.uconv42 = ConvBlock(base_ch*8, base_ch*8) + self.uconv41 = ConvBlock(base_ch*8, base_ch*8) + + self.up3 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv33 = ConvBlock(base_ch*12, base_ch*8) + self.uconv32 = ConvBlock(base_ch*8, base_ch*8) + self.uconv31 = ConvBlock(base_ch*8, base_ch*8) + + self.up2 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv23 = ConvBlock(base_ch*8, base_ch*4) + self.uconv22 = ConvBlock(base_ch*4, base_ch*4) + self.uconv21 = ConvBlock(base_ch*4, base_ch*4) + + self.up1 = UpConvBlock (base_ch*4, base_ch*2) + self.uconv12 = ConvBlock(base_ch*4, base_ch*2) + self.uconv11 = ConvBlock(base_ch*2, base_ch*2) + + self.up0 = UpConvBlock (base_ch*2, base_ch) + self.uconv02 = ConvBlock(base_ch*2, base_ch) + self.uconv01 = ConvBlock(base_ch, base_ch) + + self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') + + self.conv_center = ConvBlock(base_ch*8, base_ch*8) + + #self.ae_latent_enc = nn.Dense( base_ch*8, 64 ) + #self.ae_latent_dec = nn.Dense( 64, base_ch*8 ) + + #self.ae_up4 = nn.Conv2D( base_ch*8, base_ch*8 *4, kernel_size=3, padding='SAME') + #self.ae_up3 = nn.Conv2D( base_ch*8, base_ch*8 *4, kernel_size=3, padding='SAME') + #self.ae_up2 = nn.Conv2D( base_ch*8, base_ch*4 *4, kernel_size=3, padding='SAME') + #self.ae_up1 = nn.Conv2D( base_ch*4, base_ch*2 *4, kernel_size=3, padding='SAME') + #self.ae_up0 = nn.Conv2D( base_ch*2, base_ch *4, kernel_size=3, padding='SAME') + + + + def forward(self, inp): + x = inp + + x = self.conv01(x) + x = x0 = self.conv02(x) + x = self.bp0(x) + + x = self.conv11(x) + x = x1 = self.conv12(x) + x = self.bp1(x) + + x = self.conv21(x) + x = self.conv22(x) + x = x2 = self.conv23(x) + x = self.bp2(x) + + x = self.conv31(x) + x = self.conv32(x) + x = x3 = self.conv33(x) + x = self.bp3(x) + + x = self.conv41(x) + x = self.conv42(x) + x = x4 = self.conv43(x) + x = self.bp4(x) + + ae_x = x = self.conv_center(x) + + + + x = self.up4(x) + x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) + x = self.uconv42(x) + x = self.uconv41(x) + + x = self.up3(x) + x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) + x = self.uconv32(x) + x = self.uconv31(x) + + x = self.up2(x) + x = self.uconv23(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) + x = self.uconv22(x) + x = self.uconv21(x) + + x = self.up1(x) + x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) + x = self.uconv11(x) + + x = self.up0(x) + x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) + x = self.uconv01(x) + + """ + ae_x = nn.flatten(x) + ae_x = self.ae_latent_enc(ae_x) + ae_x = self.ae_latent_dec(ae_x) + ae_x = nn.reshape_4D (ae_x, 8, 8, 64) + + ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up4(ae_x), 0.1), 2) + ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up3(ae_x), 0.1), 2) + ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up2(ae_x), 0.1), 2) + ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up1(ae_x), 0.1), 2) + ae_x = nn.depth_to_space(tf.nn.leaky_relu(self.ae_up0(ae_x), 0.1), 2) + + x = tf.concat([x,ae_x],axis=nn.conv2d_ch_axis) + """ + + logits = self.out_conv(x) + return logits, tf.nn.sigmoid(logits) + +nn.XSeg = XSeg \ No newline at end of file diff --git a/core/leras/models/XSeg.py b/core/leras/models/XSeg.py new file mode 100644 index 0000000000000000000000000000000000000000..f59eb8cde8804bee5b5513e989f81be0132c74e2 --- /dev/null +++ b/core/leras/models/XSeg.py @@ -0,0 +1,170 @@ +from core.leras import nn +tf = nn.tf + +class XSeg(nn.ModelBase): + + def on_build (self, in_ch, base_ch, out_ch): + + class ConvBlock(nn.ModelBase): + def on_build(self, in_ch, out_ch): + self.conv = nn.Conv2D (in_ch, out_ch, kernel_size=3, padding='SAME') + self.frn = nn.FRNorm2D(out_ch) + self.tlu = nn.TLU(out_ch) + + def forward(self, x): + x = self.conv(x) + x = self.frn(x) + x = self.tlu(x) + return x + + class UpConvBlock(nn.ModelBase): + def on_build(self, in_ch, out_ch): + self.conv = nn.Conv2DTranspose (in_ch, out_ch, kernel_size=3, padding='SAME') + self.frn = nn.FRNorm2D(out_ch) + self.tlu = nn.TLU(out_ch) + + def forward(self, x): + x = self.conv(x) + x = self.frn(x) + x = self.tlu(x) + return x + + self.base_ch = base_ch + + self.conv01 = ConvBlock(in_ch, base_ch) + self.conv02 = ConvBlock(base_ch, base_ch) + self.bp0 = nn.BlurPool (filt_size=4) + + self.conv11 = ConvBlock(base_ch, base_ch*2) + self.conv12 = ConvBlock(base_ch*2, base_ch*2) + self.bp1 = nn.BlurPool (filt_size=3) + + self.conv21 = ConvBlock(base_ch*2, base_ch*4) + self.conv22 = ConvBlock(base_ch*4, base_ch*4) + self.bp2 = nn.BlurPool (filt_size=2) + + self.conv31 = ConvBlock(base_ch*4, base_ch*8) + self.conv32 = ConvBlock(base_ch*8, base_ch*8) + self.conv33 = ConvBlock(base_ch*8, base_ch*8) + self.bp3 = nn.BlurPool (filt_size=2) + + self.conv41 = ConvBlock(base_ch*8, base_ch*8) + self.conv42 = ConvBlock(base_ch*8, base_ch*8) + self.conv43 = ConvBlock(base_ch*8, base_ch*8) + self.bp4 = nn.BlurPool (filt_size=2) + + self.conv51 = ConvBlock(base_ch*8, base_ch*8) + self.conv52 = ConvBlock(base_ch*8, base_ch*8) + self.conv53 = ConvBlock(base_ch*8, base_ch*8) + self.bp5 = nn.BlurPool (filt_size=2) + + self.dense1 = nn.Dense ( 4*4* base_ch*8, 512) + self.dense2 = nn.Dense ( 512, 4*4* base_ch*8) + + self.up5 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv53 = ConvBlock(base_ch*12, base_ch*8) + self.uconv52 = ConvBlock(base_ch*8, base_ch*8) + self.uconv51 = ConvBlock(base_ch*8, base_ch*8) + + self.up4 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv43 = ConvBlock(base_ch*12, base_ch*8) + self.uconv42 = ConvBlock(base_ch*8, base_ch*8) + self.uconv41 = ConvBlock(base_ch*8, base_ch*8) + + self.up3 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv33 = ConvBlock(base_ch*12, base_ch*8) + self.uconv32 = ConvBlock(base_ch*8, base_ch*8) + self.uconv31 = ConvBlock(base_ch*8, base_ch*8) + + self.up2 = UpConvBlock (base_ch*8, base_ch*4) + self.uconv22 = ConvBlock(base_ch*8, base_ch*4) + self.uconv21 = ConvBlock(base_ch*4, base_ch*4) + + self.up1 = UpConvBlock (base_ch*4, base_ch*2) + self.uconv12 = ConvBlock(base_ch*4, base_ch*2) + self.uconv11 = ConvBlock(base_ch*2, base_ch*2) + + self.up0 = UpConvBlock (base_ch*2, base_ch) + self.uconv02 = ConvBlock(base_ch*2, base_ch) + self.uconv01 = ConvBlock(base_ch, base_ch) + self.out_conv = nn.Conv2D (base_ch, out_ch, kernel_size=3, padding='SAME') + + + def forward(self, inp, pretrain=False): + x = inp + + x = self.conv01(x) + x = x0 = self.conv02(x) + x = self.bp0(x) + + x = self.conv11(x) + x = x1 = self.conv12(x) + x = self.bp1(x) + + x = self.conv21(x) + x = x2 = self.conv22(x) + x = self.bp2(x) + + x = self.conv31(x) + x = self.conv32(x) + x = x3 = self.conv33(x) + x = self.bp3(x) + + x = self.conv41(x) + x = self.conv42(x) + x = x4 = self.conv43(x) + x = self.bp4(x) + + x = self.conv51(x) + x = self.conv52(x) + x = x5 = self.conv53(x) + x = self.bp5(x) + + x = nn.flatten(x) + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, 4, 4, self.base_ch*8 ) + + x = self.up5(x) + if pretrain: + x5 = tf.zeros_like(x5) + x = self.uconv53(tf.concat([x,x5],axis=nn.conv2d_ch_axis)) + x = self.uconv52(x) + x = self.uconv51(x) + + x = self.up4(x) + if pretrain: + x4 = tf.zeros_like(x4) + x = self.uconv43(tf.concat([x,x4],axis=nn.conv2d_ch_axis)) + x = self.uconv42(x) + x = self.uconv41(x) + + x = self.up3(x) + if pretrain: + x3 = tf.zeros_like(x3) + x = self.uconv33(tf.concat([x,x3],axis=nn.conv2d_ch_axis)) + x = self.uconv32(x) + x = self.uconv31(x) + + x = self.up2(x) + if pretrain: + x2 = tf.zeros_like(x2) + x = self.uconv22(tf.concat([x,x2],axis=nn.conv2d_ch_axis)) + x = self.uconv21(x) + + x = self.up1(x) + if pretrain: + x1 = tf.zeros_like(x1) + x = self.uconv12(tf.concat([x,x1],axis=nn.conv2d_ch_axis)) + x = self.uconv11(x) + + x = self.up0(x) + if pretrain: + x0 = tf.zeros_like(x0) + x = self.uconv02(tf.concat([x,x0],axis=nn.conv2d_ch_axis)) + x = self.uconv01(x) + + logits = self.out_conv(x) + return logits, tf.nn.sigmoid(logits) + +nn.XSeg = XSeg \ No newline at end of file diff --git a/core/leras/models/__init__.py b/core/leras/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7e5456f74106a0f78268a7bded82eeefee83b4 --- /dev/null +++ b/core/leras/models/__init__.py @@ -0,0 +1,4 @@ +from .ModelBase import * +from .PatchDiscriminator import * +from .CodeDiscriminator import * +from .XSeg import * \ No newline at end of file diff --git a/core/leras/nn.py b/core/leras/nn.py new file mode 100644 index 0000000000000000000000000000000000000000..f392aaf9b115d110b6557a38bdd3d1ffe69bca75 --- /dev/null +++ b/core/leras/nn.py @@ -0,0 +1,300 @@ +""" +Leras. + +like lighter keras. +This is my lightweight neural network library written from scratch +based on pure tensorflow without keras. + +Provides: ++ full freedom of tensorflow operations without keras model's restrictions ++ easy model operations like in PyTorch, but in graph mode (no eager execution) ++ convenient and understandable logic + +Reasons why we cannot import tensorflow or any tensorflow.sub modules right here: +1) program is changing env variables based on DeviceConfig before import tensorflow +2) multiprocesses will import tensorflow every spawn + +NCHW speed up training for 10-20%. +""" + +import os +import sys +import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) +from pathlib import Path +import numpy as np +from core.interact import interact as io +from .device import Devices + + +class nn(): + current_DeviceConfig = None + + tf = None + tf_sess = None + tf_sess_config = None + tf_default_device_name = None + + data_format = None + conv2d_ch_axis = None + conv2d_spatial_axes = None + + floatx = None + + @staticmethod + def initialize(device_config=None, floatx="float32", data_format="NHWC"): + + if nn.tf is None: + if device_config is None: + device_config = nn.getCurrentDeviceConfig() + nn.setCurrentDeviceConfig(device_config) + + # Manipulate environment variables before import tensorflow + + first_run = False + if len(device_config.devices) != 0: + if sys.platform[0:3] == 'win': + # Windows specific env vars + if all( [ x.name == device_config.devices[0].name for x in device_config.devices ] ): + devices_str = "_" + device_config.devices[0].name.replace(' ','_') + else: + devices_str = "" + for device in device_config.devices: + devices_str += "_" + device.name.replace(' ','_') + + compute_cache_path = Path(os.environ['APPDATA']) / 'NVIDIA' / ('ComputeCache' + devices_str) + if not compute_cache_path.exists(): + first_run = True + compute_cache_path.mkdir(parents=True, exist_ok=True) + os.environ['CUDA_CACHE_PATH'] = str(compute_cache_path) + + if first_run: + io.log_info("Caching GPU kernels...") + + import tensorflow + + tf_version = tensorflow.version.VERSION + #if tf_version is None: + # tf_version = tensorflow.version.GIT_VERSION + if tf_version[0] == 'v': + tf_version = tf_version[1:] + if tf_version[0] == '2': + tf = tensorflow.compat.v1 + else: + tf = tensorflow + + import logging + # Disable tensorflow warnings + tf_logger = logging.getLogger('tensorflow') + tf_logger.setLevel(logging.ERROR) + + if tf_version[0] == '2': + tf.disable_v2_behavior() + nn.tf = tf + + # Initialize framework + import core.leras.ops + import core.leras.layers + import core.leras.initializers + import core.leras.optimizers + import core.leras.models + import core.leras.archis + + # Configure tensorflow session-config + if len(device_config.devices) == 0: + config = tf.ConfigProto(device_count={'GPU': 0}) + nn.tf_default_device_name = '/CPU:0' + else: + nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0' + + config = tf.ConfigProto() + config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) + + config.gpu_options.force_gpu_compatible = True + config.gpu_options.allow_growth = True + nn.tf_sess_config = config + + if nn.tf_sess is None: + nn.tf_sess = tf.Session(config=nn.tf_sess_config) + + if floatx == "float32": + floatx = nn.tf.float32 + elif floatx == "float16": + floatx = nn.tf.float16 + else: + raise ValueError(f"unsupported floatx {floatx}") + nn.set_floatx(floatx) + nn.set_data_format(data_format) + + @staticmethod + def initialize_main_env(): + Devices.initialize_main_env() + + @staticmethod + def set_floatx(tf_dtype): + """ + set default float type for all layers when dtype is None for them + """ + nn.floatx = tf_dtype + + @staticmethod + def set_data_format(data_format): + if data_format != "NHWC" and data_format != "NCHW": + raise ValueError(f"unsupported data_format {data_format}") + nn.data_format = data_format + + if data_format == "NHWC": + nn.conv2d_ch_axis = 3 + nn.conv2d_spatial_axes = [1,2] + elif data_format == "NCHW": + nn.conv2d_ch_axis = 1 + nn.conv2d_spatial_axes = [2,3] + + @staticmethod + def get4Dshape ( w, h, c ): + """ + returns 4D shape based on current data_format + """ + if nn.data_format == "NHWC": + return (None,h,w,c) + else: + return (None,c,h,w) + + @staticmethod + def to_data_format( x, to_data_format, from_data_format): + if to_data_format == from_data_format: + return x + + if to_data_format == "NHWC": + return np.transpose(x, (0,2,3,1) ) + elif to_data_format == "NCHW": + return np.transpose(x, (0,3,1,2) ) + else: + raise ValueError(f"unsupported to_data_format {to_data_format}") + + @staticmethod + def getCurrentDeviceConfig(): + if nn.current_DeviceConfig is None: + nn.current_DeviceConfig = DeviceConfig.BestGPU() + return nn.current_DeviceConfig + + @staticmethod + def setCurrentDeviceConfig(device_config): + nn.current_DeviceConfig = device_config + + @staticmethod + def reset_session(): + if nn.tf is not None: + if nn.tf_sess is not None: + nn.tf.reset_default_graph() + nn.tf_sess.close() + nn.tf_sess = nn.tf.Session(config=nn.tf_sess_config) + + @staticmethod + def close_session(): + if nn.tf_sess is not None: + nn.tf.reset_default_graph() + nn.tf_sess.close() + nn.tf_sess = None + + @staticmethod + def ask_choose_device_idxs(choose_only_one=False, allow_cpu=True, suggest_best_multi_gpu=False, suggest_all_gpu=False): + devices = Devices.getDevices() + if len(devices) == 0: + return [] + + all_devices_indexes = [device.index for device in devices] + + if choose_only_one: + suggest_best_multi_gpu = False + suggest_all_gpu = False + + if suggest_all_gpu: + best_device_indexes = all_devices_indexes + elif suggest_best_multi_gpu: + best_device_indexes = [device.index for device in devices.get_equal_devices(devices.get_best_device()) ] + else: + best_device_indexes = [ devices.get_best_device().index ] + best_device_indexes = ",".join([str(x) for x in best_device_indexes]) + + io.log_info ("") + if choose_only_one: + io.log_info ("Choose one GPU idx.") + else: + io.log_info ("Choose one or several GPU idxs (separated by comma).") + io.log_info ("") + + if allow_cpu: + io.log_info ("[CPU] : CPU") + for device in devices: + io.log_info (f" [{device.index}] : {device.name}") + + io.log_info ("") + + while True: + try: + if choose_only_one: + choosed_idxs = io.input_str("Which GPU index to choose?", best_device_indexes) + else: + choosed_idxs = io.input_str("Which GPU indexes to choose?", best_device_indexes) + + if allow_cpu and choosed_idxs.lower() == "cpu": + choosed_idxs = [] + break + + choosed_idxs = [ int(x) for x in choosed_idxs.split(',') ] + + if choose_only_one: + if len(choosed_idxs) == 1: + break + else: + if all( [idx in all_devices_indexes for idx in choosed_idxs] ): + break + except: + pass + io.log_info ("") + + return choosed_idxs + + class DeviceConfig(): + @staticmethod + def ask_choose_device(*args, **kwargs): + return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) ) + + def __init__ (self, devices=None): + devices = devices or [] + + if not isinstance(devices, Devices): + devices = Devices(devices) + + self.devices = devices + self.cpu_only = len(devices) == 0 + + @staticmethod + def BestGPU(): + devices = Devices.getDevices() + if len(devices) == 0: + return nn.DeviceConfig.CPU() + + return nn.DeviceConfig([devices.get_best_device()]) + + @staticmethod + def WorstGPU(): + devices = Devices.getDevices() + if len(devices) == 0: + return nn.DeviceConfig.CPU() + + return nn.DeviceConfig([devices.get_worst_device()]) + + @staticmethod + def GPUIndexes(indexes): + if len(indexes) != 0: + devices = Devices.getDevices().get_devices_from_index_list(indexes) + else: + devices = [] + + return nn.DeviceConfig(devices) + + @staticmethod + def CPU(): + return nn.DeviceConfig([]) diff --git a/core/leras/ops/__init__.py b/core/leras/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd690da363cf0c1799d01808507d526c370f825b --- /dev/null +++ b/core/leras/ops/__init__.py @@ -0,0 +1,478 @@ +import numpy as np +from core.leras import nn +tf = nn.tf +from tensorflow.python.ops import array_ops, random_ops, math_ops, sparse_ops, gradients +from tensorflow.python.framework import sparse_tensor + +def tf_get_value(tensor): + return nn.tf_sess.run (tensor) +nn.tf_get_value = tf_get_value + + +def batch_set_value(tuples): + if len(tuples) != 0: + with nn.tf.device('/CPU:0'): + assign_ops = [] + feed_dict = {} + + for x, value in tuples: + if isinstance(value, nn.tf.Operation) or \ + isinstance(value, nn.tf.Variable): + assign_ops.append(value) + else: + value = np.asarray(value, dtype=x.dtype.as_numpy_dtype) + assign_placeholder = nn.tf.placeholder( x.dtype.base_dtype, shape=[None]*value.ndim ) + assign_op = nn.tf.assign (x, assign_placeholder ) + assign_ops.append(assign_op) + feed_dict[assign_placeholder] = value + + nn.tf_sess.run(assign_ops, feed_dict=feed_dict) +nn.batch_set_value = batch_set_value + +def init_weights(weights): + ops = [] + + ca_tuples_w = [] + ca_tuples = [] + for w in weights: + initializer = w.initializer + for input in initializer.inputs: + if "_cai_" in input.name: + ca_tuples_w.append (w) + ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) ) + break + else: + ops.append (initializer) + + if len(ops) != 0: + nn.tf_sess.run (ops) + + if len(ca_tuples) != 0: + nn.batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] ) +nn.init_weights = init_weights + +def tf_gradients ( loss, vars ): + grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True ) + gv = [*zip(grads,vars)] + for g,v in gv: + if g is None: + raise Exception(f"Variable {v.name} is declared as trainable, but no tensors flow through it.") + return gv +nn.gradients = tf_gradients + +def average_gv_list(grad_var_list, tf_device_string=None): + if len(grad_var_list) == 1: + return grad_var_list[0] + + e = tf.device(tf_device_string) if tf_device_string is not None else None + if e is not None: e.__enter__() + result = [] + for i, (gv) in enumerate(grad_var_list): + for j,(g,v) in enumerate(gv): + g = tf.expand_dims(g, 0) + if i == 0: + result += [ [[g], v] ] + else: + result[j][0] += [g] + + for i,(gs,v) in enumerate(result): + result[i] = ( tf.reduce_mean( tf.concat (gs, 0), 0 ), v ) + if e is not None: e.__exit__(None,None,None) + return result +nn.average_gv_list = average_gv_list + +def average_tensor_list(tensors_list, tf_device_string=None): + if len(tensors_list) == 1: + return tensors_list[0] + + e = tf.device(tf_device_string) if tf_device_string is not None else None + if e is not None: e.__enter__() + result = tf.reduce_mean(tf.concat ([tf.expand_dims(t, 0) for t in tensors_list], 0), 0) + if e is not None: e.__exit__(None,None,None) + return result +nn.average_tensor_list = average_tensor_list + +def concat (tensors_list, axis): + """ + Better version. + """ + if len(tensors_list) == 1: + return tensors_list[0] + return tf.concat(tensors_list, axis) +nn.concat = concat + +def gelu(x): + cdf = 0.5 * (1.0 + tf.nn.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3))))) + return x * cdf +nn.gelu = gelu + +def upsample2d(x, size=2): + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) + x = tf.transpose(x, (0,3,1,2)) + + + # b,c,h,w = x.shape.as_list() + # x = tf.reshape (x, (-1,c,h,1,w,1) ) + # x = tf.tile(x, (1,1,1,size,1,size) ) + # x = tf.reshape (x, (-1,c,h*size,w*size) ) + return x + else: + return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) ) +nn.upsample2d = upsample2d + +def resize2d_bilinear(x, size=2): + h = x.shape[nn.conv2d_spatial_axes[0]].value + w = x.shape[nn.conv2d_spatial_axes[1]].value + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + + if size > 0: + new_size = (h*size,w*size) + else: + new_size = (h//-size,w//-size) + + x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BILINEAR) + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,3,1,2)) + + return x +nn.resize2d_bilinear = resize2d_bilinear + +def resize2d_nearest(x, size=2): + if size in [-1,0,1]: + return x + + + if size > 0: + raise Exception("") + else: + if nn.data_format == "NCHW": + x = x[:,:,::-size,::-size] + else: + x = x[:,::-size,::-size,:] + return x + + h = x.shape[nn.conv2d_spatial_axes[0]].value + w = x.shape[nn.conv2d_spatial_axes[1]].value + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,2,3,1)) + + if size > 0: + new_size = (h*size,w*size) + else: + new_size = (h//-size,w//-size) + + x = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) + + if nn.data_format == "NCHW": + x = tf.transpose(x, (0,3,1,2)) + + return x +nn.resize2d_nearest = resize2d_nearest + +def flatten(x): + if nn.data_format == "NHWC": + # match NCHW version in order to switch data_format without problems + x = tf.transpose(x, (0,3,1,2) ) + return tf.reshape (x, (-1, np.prod(x.shape[1:])) ) + +nn.flatten = flatten + +def max_pool(x, kernel_size=2, strides=2): + if nn.data_format == "NHWC": + return tf.nn.max_pool(x, [1,kernel_size,kernel_size,1], [1,strides,strides,1], 'SAME', data_format=nn.data_format) + else: + return tf.nn.max_pool(x, [1,1,kernel_size,kernel_size], [1,1,strides,strides], 'SAME', data_format=nn.data_format) + +nn.max_pool = max_pool + +def reshape_4D(x, w,h,c): + if nn.data_format == "NHWC": + # match NCHW version in order to switch data_format without problems + x = tf.reshape (x, (-1,c,h,w)) + x = tf.transpose(x, (0,2,3,1) ) + return x + else: + return tf.reshape (x, (-1,c,h,w)) +nn.reshape_4D = reshape_4D + +def random_binomial(shape, p=0.0, dtype=None, seed=None): + if dtype is None: + dtype=tf.float32 + + if seed is None: + seed = np.random.randint(10e6) + return array_ops.where( + random_ops.random_uniform(shape, dtype=tf.float16, seed=seed) < p, + array_ops.ones(shape, dtype=dtype), array_ops.zeros(shape, dtype=dtype)) +nn.random_binomial = random_binomial + +def gaussian_blur(input, radius=2.0): + def gaussian(x, mu, sigma): + return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2)) + + def make_kernel(sigma): + kernel_size = max(3, int(2 * 2 * sigma)) + if kernel_size % 2 == 0: + kernel_size += 1 + mean = np.floor(0.5 * kernel_size) + kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)]) + np_kernel = np.outer(kernel_1d, kernel_1d).astype(np.float32) + kernel = np_kernel / np.sum(np_kernel) + return kernel, kernel_size + + gauss_kernel, kernel_size = make_kernel(radius) + padding = kernel_size//2 + if padding != 0: + if nn.data_format == "NHWC": + padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ] + else: + padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ] + else: + padding = None + gauss_kernel = gauss_kernel[:,:,None,None] + + x = input + k = tf.tile (gauss_kernel, (1,1,x.shape[nn.conv2d_ch_axis],1) ) + x = tf.pad(x, padding ) + x = tf.nn.depthwise_conv2d(x, k, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format) + return x +nn.gaussian_blur = gaussian_blur + +def style_loss(target, style, gaussian_blur_radius=0.0, loss_weight=1.0, step_size=1): + def sd(content, style, loss_weight): + content_nc = content.shape[ nn.conv2d_ch_axis ] + style_nc = style.shape[nn.conv2d_ch_axis] + if content_nc != style_nc: + raise Exception("style_loss() content_nc != style_nc") + c_mean, c_var = tf.nn.moments(content, axes=nn.conv2d_spatial_axes, keep_dims=True) + s_mean, s_var = tf.nn.moments(style, axes=nn.conv2d_spatial_axes, keep_dims=True) + c_std, s_std = tf.sqrt(c_var + 1e-5), tf.sqrt(s_var + 1e-5) + mean_loss = tf.reduce_sum(tf.square(c_mean-s_mean), axis=[1,2,3]) + std_loss = tf.reduce_sum(tf.square(c_std-s_std), axis=[1,2,3]) + return (mean_loss + std_loss) * ( loss_weight / content_nc.value ) + + if gaussian_blur_radius > 0.0: + target = gaussian_blur(target, gaussian_blur_radius) + style = gaussian_blur(style, gaussian_blur_radius) + + return sd( target, style, loss_weight=loss_weight ) + +nn.style_loss = style_loss + +def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): + if img1.dtype != img2.dtype: + raise ValueError("img1.dtype != img2.dtype") + + not_float32 = img1.dtype != tf.float32 + + if not_float32: + img_dtype = img1.dtype + img1 = tf.cast(img1, tf.float32) + img2 = tf.cast(img2, tf.float32) + + filter_size = max(1, filter_size) + + kernel = np.arange(0, filter_size, dtype=np.float32) + kernel -= (filter_size - 1 ) / 2.0 + kernel = kernel**2 + kernel *= ( -0.5 / (filter_sigma**2) ) + kernel = np.reshape (kernel, (1,-1)) + np.reshape(kernel, (-1,1) ) + kernel = tf.constant ( np.reshape (kernel, (1,-1)), dtype=tf.float32 ) + kernel = tf.nn.softmax(kernel) + kernel = tf.reshape (kernel, (filter_size, filter_size, 1, 1)) + kernel = tf.tile (kernel, (1,1, img1.shape[ nn.conv2d_ch_axis ] ,1)) + + def reducer(x): + return tf.nn.depthwise_conv2d(x, kernel, strides=[1,1,1,1], padding='VALID', data_format=nn.data_format) + + c1 = (k1 * max_val) ** 2 + c2 = (k2 * max_val) ** 2 + + mean0 = reducer(img1) + mean1 = reducer(img2) + num0 = mean0 * mean1 * 2.0 + den0 = tf.square(mean0) + tf.square(mean1) + luminance = (num0 + c1) / (den0 + c1) + + num1 = reducer(img1 * img2) * 2.0 + den1 = reducer(tf.square(img1) + tf.square(img2)) + c2 *= 1.0 #compensation factor + cs = (num1 - num0 + c2) / (den1 - den0 + c2) + + ssim_val = tf.reduce_mean(luminance * cs, axis=nn.conv2d_spatial_axes ) + dssim = (1.0 - ssim_val ) / 2.0 + + if not_float32: + dssim = tf.cast(dssim, img_dtype) + return dssim + +nn.dssim = dssim + +def space_to_depth(x, size): + if nn.data_format == "NHWC": + # match NCHW version in order to switch data_format without problems + b,h,w,c = x.shape.as_list() + oh, ow = h // size, w // size + x = tf.reshape(x, (-1, size, oh, size, ow, c)) + x = tf.transpose(x, (0, 2, 4, 1, 3, 5)) + x = tf.reshape(x, (-1, oh, ow, size* size* c )) + return x + else: + return tf.space_to_depth(x, size, data_format=nn.data_format) +nn.space_to_depth = space_to_depth + +def depth_to_space(x, size): + if nn.data_format == "NHWC": + # match NCHW version in order to switch data_format without problems + + b,h,w,c = x.shape.as_list() + oh, ow = h * size, w * size + oc = c // (size * size) + + x = tf.reshape(x, (-1, h, w, size, size, oc, ) ) + x = tf.transpose(x, (0, 1, 3, 2, 4, 5)) + x = tf.reshape(x, (-1, oh, ow, oc, )) + return x + else: + cfg = nn.getCurrentDeviceConfig() + if not cfg.cpu_only: + return tf.depth_to_space(x, size, data_format=nn.data_format) + b,c,h,w = x.shape.as_list() + oh, ow = h * size, w * size + oc = c // (size * size) + + x = tf.reshape(x, (-1, size, size, oc, h, w, ) ) + x = tf.transpose(x, (0, 3, 4, 1, 5, 2)) + x = tf.reshape(x, (-1, oc, oh, ow)) + return x +nn.depth_to_space = depth_to_space + +def rgb_to_lab(srgb): + srgb_pixels = tf.reshape(srgb, [-1, 3]) + linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) + exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32) + rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask + rgb_to_xyz = tf.constant([ + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B + ]) + xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz) + + xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754]) + + epsilon = 6/29 + linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32) + exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32) + fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask + + fxfyfz_to_lab = tf.constant([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]) + lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0]) + return tf.reshape(lab_pixels, tf.shape(srgb)) +nn.rgb_to_lab = rgb_to_lab + +def total_variation_mse(images): + """ + Same as generic total_variation, but MSE diff instead of MAE + """ + pixel_dif1 = images[:, 1:, :, :] - images[:, :-1, :, :] + pixel_dif2 = images[:, :, 1:, :] - images[:, :, :-1, :] + + tot_var = ( tf.reduce_sum(tf.square(pixel_dif1), axis=[1,2,3]) + + tf.reduce_sum(tf.square(pixel_dif2), axis=[1,2,3]) ) + return tot_var +nn.total_variation_mse = total_variation_mse + + +def pixel_norm(x, axes): + return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=axes, keepdims=True) + 1e-06) +nn.pixel_norm = pixel_norm + +""" +def tf_suppress_lower_mean(t, eps=0.00001): + if t.shape.ndims != 1: + raise ValueError("tf_suppress_lower_mean: t rank must be 1") + t_mean_eps = tf.reduce_mean(t) - eps + q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) ) + q = tf.clip_by_value(q-t_mean_eps, 0, eps) + q = q * (t/eps) + return q +""" + + + +def _get_pixel_value(img, x, y): + shape = tf.shape(x) + batch_size = shape[0] + height = shape[1] + width = shape[2] + + batch_idx = tf.range(0, batch_size) + batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1)) + b = tf.tile(batch_idx, (1, height, width)) + + indices = tf.stack([b, y, x], 3) + + return tf.gather_nd(img, indices) + +def bilinear_sampler(img, x, y): + H = tf.shape(img)[1] + W = tf.shape(img)[2] + H_MAX = tf.cast(H - 1, tf.int32) + W_MAX = tf.cast(W - 1, tf.int32) + + # grab 4 nearest corner points for each (x_i, y_i) + x0 = tf.cast(tf.floor(x), tf.int32) + x1 = x0 + 1 + y0 = tf.cast(tf.floor(y), tf.int32) + y1 = y0 + 1 + + # clip to range [0, H-1/W-1] to not violate img boundaries + x0 = tf.clip_by_value(x0, 0, W_MAX) + x1 = tf.clip_by_value(x1, 0, W_MAX) + y0 = tf.clip_by_value(y0, 0, H_MAX) + y1 = tf.clip_by_value(y1, 0, H_MAX) + + # get pixel value at corner coords + Ia = _get_pixel_value(img, x0, y0) + Ib = _get_pixel_value(img, x0, y1) + Ic = _get_pixel_value(img, x1, y0) + Id = _get_pixel_value(img, x1, y1) + + # recast as float for delta calculation + x0 = tf.cast(x0, tf.float32) + x1 = tf.cast(x1, tf.float32) + y0 = tf.cast(y0, tf.float32) + y1 = tf.cast(y1, tf.float32) + + # calculate deltas + wa = (x1-x) * (y1-y) + wb = (x1-x) * (y-y0) + wc = (x-x0) * (y1-y) + wd = (x-x0) * (y-y0) + + # add dimension for addition + wa = tf.expand_dims(wa, axis=3) + wb = tf.expand_dims(wb, axis=3) + wc = tf.expand_dims(wc, axis=3) + wd = tf.expand_dims(wd, axis=3) + + # compute output + out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id]) + + return out + +nn.bilinear_sampler = bilinear_sampler + diff --git a/core/leras/optimizers/AdaBelief.py b/core/leras/optimizers/AdaBelief.py new file mode 100644 index 0000000000000000000000000000000000000000..da6e1a2f8ac5ecea2bdf4c2b89b89adf3c2e81b6 --- /dev/null +++ b/core/leras/optimizers/AdaBelief.py @@ -0,0 +1,81 @@ +import numpy as np +from core.leras import nn +from tensorflow.python.ops import control_flow_ops, state_ops + +tf = nn.tf + +class AdaBelief(nn.OptimizerBase): + def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs): + super().__init__(name=name) + + if name is None: + raise ValueError('name must be defined.') + + self.lr = lr + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.lr_dropout = lr_dropout + self.lr_cos = lr_cos + self.clipnorm = clipnorm + + with tf.device('/CPU:0') : + with tf.variable_scope(self.name): + self.iterations = tf.Variable(0, dtype=tf.int64, name='iters') + + self.ms_dict = {} + self.vs_dict = {} + self.lr_rnds_dict = {} + + def get_weights(self): + return [self.iterations] + list(self.ms_dict.values()) + list(self.vs_dict.values()) + + def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False): + # Initialize here all trainable variables used in training + e = tf.device('/CPU:0') if vars_on_cpu else None + if e: e.__enter__() + with tf.variable_scope(self.name): + ms = { v.name : tf.get_variable ( f'ms_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights } + vs = { v.name : tf.get_variable ( f'vs_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights } + self.ms_dict.update (ms) + self.vs_dict.update (vs) + + if self.lr_dropout != 1.0: + e = tf.device('/CPU:0') if lr_dropout_on_cpu else None + if e: e.__enter__() + lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ] + if e: e.__exit__(None, None, None) + self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } ) + if e: e.__exit__(None, None, None) + + def get_update_op(self, grads_vars): + updates = [] + + if self.clipnorm > 0.0: + norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars])) + updates += [ state_ops.assign_add( self.iterations, 1) ] + for i, (g,v) in enumerate(grads_vars): + if self.clipnorm > 0.0: + g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) ) + + ms = self.ms_dict[ v.name ] + vs = self.vs_dict[ v.name ] + + m_t = self.beta_1*ms + (1.0-self.beta_1) * g + v_t = self.beta_2*vs + (1.0-self.beta_2) * tf.square(g-m_t) + + lr = tf.constant(self.lr, g.dtype) + if self.lr_cos != 0: + lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0 + + v_diff = - lr * m_t / (tf.sqrt(v_t) + np.finfo( g.dtype.as_numpy_dtype ).resolution ) + if self.lr_dropout != 1.0: + lr_rnd = self.lr_rnds_dict[v.name] + v_diff *= lr_rnd + new_v = v + v_diff + + updates.append (state_ops.assign(ms, m_t)) + updates.append (state_ops.assign(vs, v_t)) + updates.append (state_ops.assign(v, new_v)) + + return control_flow_ops.group ( *updates, name=self.name+'_updates') +nn.AdaBelief = AdaBelief diff --git a/core/leras/optimizers/OptimizerBase.py b/core/leras/optimizers/OptimizerBase.py new file mode 100644 index 0000000000000000000000000000000000000000..e112363ee70efb074be10aa4eda6c53aa8005db8 --- /dev/null +++ b/core/leras/optimizers/OptimizerBase.py @@ -0,0 +1,42 @@ +import copy +from core.leras import nn +tf = nn.tf + +class OptimizerBase(nn.Saveable): + def __init__(self, name=None): + super().__init__(name=name) + + def tf_clip_norm(self, g, c, n): + """Clip the gradient `g` if the L2 norm `n` exceeds `c`. + # Arguments + g: Tensor, the gradient tensor + c: float >= 0. Gradients will be clipped + when their L2 norm exceeds this value. + n: Tensor, actual norm of `g`. + # Returns + Tensor, the gradient clipped if required. + """ + if c <= 0: # if clipnorm == 0 no need to add ops to the graph + return g + + condition = n >= c + then_expression = tf.scalar_mul(c / n, g) + else_expression = g + + # saving the shape to avoid converting sparse tensor to dense + if isinstance(then_expression, tf.Tensor): + g_shape = copy.copy(then_expression.get_shape()) + elif isinstance(then_expression, tf.IndexedSlices): + g_shape = copy.copy(then_expression.dense_shape) + if condition.dtype != tf.bool: + condition = tf.cast(condition, 'bool') + g = tf.cond(condition, + lambda: then_expression, + lambda: else_expression) + if isinstance(then_expression, tf.Tensor): + g.set_shape(g_shape) + elif isinstance(then_expression, tf.IndexedSlices): + g._dense_shape = g_shape + + return g +nn.OptimizerBase = OptimizerBase diff --git a/core/leras/optimizers/RMSprop.py b/core/leras/optimizers/RMSprop.py new file mode 100644 index 0000000000000000000000000000000000000000..0b20fbfb59278b8ec86d42e9732b61b3c7e4b773 --- /dev/null +++ b/core/leras/optimizers/RMSprop.py @@ -0,0 +1,74 @@ +import numpy as np +from tensorflow.python.ops import control_flow_ops, state_ops +from core.leras import nn +tf = nn.tf + +class RMSprop(nn.OptimizerBase): + def __init__(self, lr=0.001, rho=0.9, lr_dropout=1.0, lr_cos=0, clipnorm=0.0, name=None, **kwargs): + super().__init__(name=name) + + if name is None: + raise ValueError('name must be defined.') + + self.lr_dropout = lr_dropout + self.lr_cos = lr_cos + self.lr = lr + self.rho = rho + self.clipnorm = clipnorm + + with tf.device('/CPU:0') : + with tf.variable_scope(self.name): + + self.iterations = tf.Variable(0, dtype=tf.int64, name='iters') + + self.accumulators_dict = {} + self.lr_rnds_dict = {} + + def get_weights(self): + return [self.iterations] + list(self.accumulators_dict.values()) + + def initialize_variables(self, trainable_weights, vars_on_cpu=True, lr_dropout_on_cpu=False): + # Initialize here all trainable variables used in training + e = tf.device('/CPU:0') if vars_on_cpu else None + if e: e.__enter__() + with tf.variable_scope(self.name): + accumulators = { v.name : tf.get_variable ( f'acc_{v.name}'.replace(':','_'), v.shape, dtype=v.dtype, initializer=tf.initializers.constant(0.0), trainable=False) for v in trainable_weights } + self.accumulators_dict.update ( accumulators) + + if self.lr_dropout != 1.0: + e = tf.device('/CPU:0') if lr_dropout_on_cpu else None + if e: e.__enter__() + lr_rnds = [ nn.random_binomial( v.shape, p=self.lr_dropout, dtype=v.dtype) for v in trainable_weights ] + if e: e.__exit__(None, None, None) + self.lr_rnds_dict.update ( { v.name : rnd for v,rnd in zip(trainable_weights,lr_rnds) } ) + if e: e.__exit__(None, None, None) + + def get_update_op(self, grads_vars): + updates = [] + + if self.clipnorm > 0.0: + norm = tf.sqrt( sum([tf.reduce_sum(tf.square(tf.cast(g, tf.float32))) for g,v in grads_vars])) + updates += [ state_ops.assign_add( self.iterations, 1) ] + for i, (g,v) in enumerate(grads_vars): + if self.clipnorm > 0.0: + g = self.tf_clip_norm(g, self.clipnorm, tf.cast(norm, g.dtype) ) + + a = self.accumulators_dict[ v.name ] + + new_a = self.rho * a + (1. - self.rho) * tf.square(g) + + lr = tf.constant(self.lr, g.dtype) + if self.lr_cos != 0: + lr *= (tf.cos( tf.cast(self.iterations, g.dtype) * (2*3.1415926535/ float(self.lr_cos) ) ) + 1.0) / 2.0 + + v_diff = - lr * g / (tf.sqrt(new_a) + np.finfo( g.dtype.as_numpy_dtype ).resolution ) + if self.lr_dropout != 1.0: + lr_rnd = self.lr_rnds_dict[v.name] + v_diff *= lr_rnd + new_v = v + v_diff + + updates.append (state_ops.assign(a, new_a)) + updates.append (state_ops.assign(v, new_v)) + + return control_flow_ops.group ( *updates, name=self.name+'_updates') +nn.RMSprop = RMSprop \ No newline at end of file diff --git a/core/leras/optimizers/__init__.py b/core/leras/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4f8a7e4c1a35c0da9fbe837a772252db51f27869 --- /dev/null +++ b/core/leras/optimizers/__init__.py @@ -0,0 +1,3 @@ +from .OptimizerBase import * +from .RMSprop import * +from .AdaBelief import * \ No newline at end of file diff --git a/core/mathlib/__init__.py b/core/mathlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5fa13325e22b7ffb03af1c7d6f8bc674cfc427 --- /dev/null +++ b/core/mathlib/__init__.py @@ -0,0 +1,97 @@ +import math + +import cv2 +import numpy as np +import numpy.linalg as npla + +from .umeyama import umeyama + + +def get_power_of_two(x): + i = 0 + while (1 << i) < x: + i += 1 + return i + +def rotationMatrixToEulerAngles(R) : + sy = math.sqrt(R[0,0] * R[0,0] + R[1,0] * R[1,0]) + singular = sy < 1e-6 + if not singular : + x = math.atan2(R[2,1] , R[2,2]) + y = math.atan2(-R[2,0], sy) + z = math.atan2(R[1,0], R[0,0]) + else : + x = math.atan2(-R[1,2], R[1,1]) + y = math.atan2(-R[2,0], sy) + z = 0 + return np.array([x, y, z]) + +def polygon_area(x,y): + return 0.5*np.abs(np.dot(x,np.roll(y,1))-np.dot(y,np.roll(x,1))) + +def rotate_point(origin, point, deg): + """ + Rotate a point counterclockwise by a given angle around a given origin. + + The angle should be given in radians. + """ + ox, oy = origin + px, py = point + + rad = deg * math.pi / 180.0 + qx = ox + math.cos(rad) * (px - ox) - math.sin(rad) * (py - oy) + qy = oy + math.sin(rad) * (px - ox) + math.cos(rad) * (py - oy) + return np.float32([qx, qy]) + +def transform_points(points, mat, invert=False): + if invert: + mat = cv2.invertAffineTransform (mat) + points = np.expand_dims(points, axis=1) + points = cv2.transform(points, mat, points.shape) + points = np.squeeze(points) + return points + + +def transform_mat(mat, res, tx, ty, rotation, scale): + """ + transform mat in local space of res + scale -> translate -> rotate + + tx,ty float + rotation int degrees + scale float + """ + + + lt, rt, lb, ct = transform_points ( np.float32([(0,0),(res,0),(0,res),(res / 2, res/2) ]),mat, True) + + hor_v = (rt-lt).astype(np.float32) + hor_size = npla.norm(hor_v) + hor_v /= hor_size + + ver_v = (lb-lt).astype(np.float32) + ver_size = npla.norm(ver_v) + ver_v /= ver_size + + bt_diag_vec = (rt-ct).astype(np.float32) + half_diag_len = npla.norm(bt_diag_vec) + bt_diag_vec /= half_diag_len + + tb_diag_vec = np.float32( [ -bt_diag_vec[1], bt_diag_vec[0] ] ) + + rt = ct + bt_diag_vec*half_diag_len*scale + lb = ct - bt_diag_vec*half_diag_len*scale + lt = ct - tb_diag_vec*half_diag_len*scale + + rt[0] += tx*hor_size + lb[0] += tx*hor_size + lt[0] += tx*hor_size + rt[1] += ty*ver_size + lb[1] += ty*ver_size + lt[1] += ty*ver_size + + rt = rotate_point(ct, rt, rotation) + lb = rotate_point(ct, lb, rotation) + lt = rotate_point(ct, lt, rotation) + + return cv2.getAffineTransform( np.float32([lt, rt, lb]), np.float32([ [0,0], [res,0], [0,res] ]) ) diff --git a/core/mathlib/umeyama.py b/core/mathlib/umeyama.py new file mode 100644 index 0000000000000000000000000000000000000000..826a88f1ce5d3112817a367dd8784efa4fa71dc6 --- /dev/null +++ b/core/mathlib/umeyama.py @@ -0,0 +1,71 @@ +import numpy as np + +def umeyama(src, dst, estimate_scale): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = np.dot(dst_demean.T, src_demean) / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = np.dot(U, V) + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + d[dim - 1] = s + else: + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d) + else: + scale = 1.0 + + T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T) + T[:dim, :dim] *= scale + + return T diff --git a/core/mplib/MPSharedList.py b/core/mplib/MPSharedList.py new file mode 100644 index 0000000000000000000000000000000000000000..874c56a6f46ed52c73b905d79ba2043c8c570516 --- /dev/null +++ b/core/mplib/MPSharedList.py @@ -0,0 +1,111 @@ +import multiprocessing +import pickle +import struct +from core.joblib import Subprocessor + +class MPSharedList(): + """ + Provides read-only pickled list of constant objects via shared memory aka 'multiprocessing.Array' + Thus no 4GB limit for subprocesses. + + supports list concat via + or sum() + """ + + def __init__(self, obj_list): + if obj_list is None: + self.obj_counts = None + self.table_offsets = None + self.data_offsets = None + self.sh_bs = None + else: + obj_count, table_offset, data_offset, sh_b = MPSharedList.bake_data(obj_list) + + self.obj_counts = [obj_count] + self.table_offsets = [table_offset] + self.data_offsets = [data_offset] + self.sh_bs = [sh_b] + + def __add__(self, o): + if isinstance(o, MPSharedList): + m = MPSharedList(None) + m.obj_counts = self.obj_counts + o.obj_counts + m.table_offsets = self.table_offsets + o.table_offsets + m.data_offsets = self.data_offsets + o.data_offsets + m.sh_bs = self.sh_bs + o.sh_bs + return m + elif isinstance(o, int): + return self + else: + raise ValueError(f"MPSharedList object of class {o.__class__} is not supported for __add__ operator.") + + def __radd__(self, o): + return self+o + + def __len__(self): + return sum(self.obj_counts) + + def __getitem__(self, key): + obj_count = sum(self.obj_counts) + if key < 0: + key = obj_count+key + if key < 0 or key >= obj_count: + raise ValueError("out of range") + + for i in range(len(self.obj_counts)): + + if key < self.obj_counts[i]: + table_offset = self.table_offsets[i] + data_offset = self.data_offsets[i] + sh_b = self.sh_bs[i] + break + key -= self.obj_counts[i] + + sh_b = memoryview(sh_b).cast('B') + + offset_start, offset_end = struct.unpack(' self.no_response_time_sec: + #subprocess busy too long + io.log_info ( '%s doesnt response, terminating it.' % (cli.name) ) + self.on_data_return (cli.host_dict, cli.sent_data ) + cli.kill() + self.clis.remove(cli) + + for cli in self.clis[:]: + if cli.state == 0: + #free state of subprocess, get some data from get_data + data = self.get_data(cli.host_dict) + if data is not None: + #and send it to subprocess + cli.s2c.put ( {'op': 'data', 'data' : data} ) + cli.sent_time = time.time() + cli.sent_data = data + cli.state = 1 + + if all ([cli.state == 0 for cli in self.clis]): + #gracefully terminating subprocesses + for cli in self.clis[:]: + cli.s2c.put ( {'op': 'close'} ) + cli.sent_time = time.time() + + while True: + for cli in self.clis[:]: + terminate_it = False + while not cli.c2s.empty(): + obj = cli.c2s.get() + obj_op = obj['op'] + if obj_op == 'finalized': + terminate_it = True + break + + if (time.time() - cli.sent_time) > 30: + terminate_it = True + + if terminate_it: + cli.state = 2 + cli.kill() + + if all ([cli.state == 2 for cli in self.clis]): + break + + #finalizing host logic + self.q_timer.stop() + self.q_timer = None + self.on_clients_finalized() + diff --git a/core/qtex/QXIconButton.py b/core/qtex/QXIconButton.py new file mode 100644 index 0000000000000000000000000000000000000000..235d149423d666faa68129a8627f0ad39422044d --- /dev/null +++ b/core/qtex/QXIconButton.py @@ -0,0 +1,83 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +from localization import StringsDB +from .QXMainWindow import * + +class QXIconButton(QPushButton): + """ + Custom Icon button that works through keyEvent system, without shortcut of QAction + works only with QXMainWindow as global window class + currently works only with one-key shortcut + """ + + def __init__(self, icon, + tooltip=None, + shortcut=None, + click_func=None, + first_repeat_delay=300, + repeat_delay=20, + ): + + super().__init__(icon, "") + + self.setIcon(icon) + + if shortcut is not None: + tooltip = f"{tooltip} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )" + + self.setToolTip(tooltip) + + + self.seq = QKeySequence(shortcut) if shortcut is not None else None + + QXMainWindow.inst.add_keyPressEvent_listener ( self.on_keyPressEvent ) + QXMainWindow.inst.add_keyReleaseEvent_listener ( self.on_keyReleaseEvent ) + + self.click_func = click_func + self.first_repeat_delay = first_repeat_delay + self.repeat_delay = repeat_delay + self.repeat_timer = None + + self.op_device = None + + self.pressed.connect( lambda : self.action(is_pressed=True) ) + self.released.connect( lambda : self.action(is_pressed=False) ) + + def action(self, is_pressed=None, op_device=None): + if self.click_func is None: + return + + if is_pressed is not None: + if is_pressed: + if self.repeat_timer is None: + self.click_func() + self.repeat_timer = QTimer() + self.repeat_timer.timeout.connect(self.action) + self.repeat_timer.start(self.first_repeat_delay) + else: + if self.repeat_timer is not None: + self.repeat_timer.stop() + self.repeat_timer = None + else: + self.click_func() + if self.repeat_timer is not None: + self.repeat_timer.setInterval(self.repeat_delay) + + def on_keyPressEvent(self, ev): + key = ev.nativeVirtualKey() + if ev.isAutoRepeat(): + return + + if self.seq is not None: + if key == self.seq[0]: + self.action(is_pressed=True) + + def on_keyReleaseEvent(self, ev): + key = ev.nativeVirtualKey() + if ev.isAutoRepeat(): + return + if self.seq is not None: + if key == self.seq[0]: + self.action(is_pressed=False) diff --git a/core/qtex/QXMainWindow.py b/core/qtex/QXMainWindow.py new file mode 100644 index 0000000000000000000000000000000000000000..a50e59730c4dc37f913d556d228671896063598a --- /dev/null +++ b/core/qtex/QXMainWindow.py @@ -0,0 +1,34 @@ +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * + +class QXMainWindow(QWidget): + """ + Custom mainwindow class that provides global single instance and event listeners + """ + inst = None + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if QXMainWindow.inst is not None: + raise Exception("QXMainWindow can only be one.") + QXMainWindow.inst = self + + self.keyPressEvent_listeners = [] + self.keyReleaseEvent_listeners = [] + self.setFocusPolicy(Qt.WheelFocus) + + def add_keyPressEvent_listener(self, func): + self.keyPressEvent_listeners.append (func) + + def add_keyReleaseEvent_listener(self, func): + self.keyReleaseEvent_listeners.append (func) + + def keyPressEvent(self, ev): + super().keyPressEvent(ev) + for func in self.keyPressEvent_listeners: + func(ev) + + def keyReleaseEvent(self, ev): + super().keyReleaseEvent(ev) + for func in self.keyReleaseEvent_listeners: + func(ev) \ No newline at end of file diff --git a/core/qtex/__init__.py b/core/qtex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cb44b52489ff6d05eee1dff8a39bdcba2d58ef1 --- /dev/null +++ b/core/qtex/__init__.py @@ -0,0 +1,3 @@ +from .qtex import * +from .QSubprocessor import * +from .QXIconButton import * \ No newline at end of file diff --git a/core/qtex/qtex.py b/core/qtex/qtex.py new file mode 100644 index 0000000000000000000000000000000000000000..d15e41d76620116cfdc462fa828e92570796570a --- /dev/null +++ b/core/qtex/qtex.py @@ -0,0 +1,80 @@ +import numpy as np +from PyQt5.QtCore import * +from PyQt5.QtGui import * +from PyQt5.QtWidgets import * +from localization import StringsDB + +from .QXMainWindow import * + + +class QActionEx(QAction): + def __init__(self, icon, text, shortcut=None, trigger_func=None, shortcut_in_tooltip=False, is_checkable=False, is_auto_repeat=False ): + super().__init__(icon, text) + if shortcut is not None: + self.setShortcut(shortcut) + if shortcut_in_tooltip: + + self.setToolTip( f"{text} ( {StringsDB['S_HOT_KEY'] }: {shortcut} )") + + if trigger_func is not None: + self.triggered.connect(trigger_func) + if is_checkable: + self.setCheckable(True) + self.setAutoRepeat(is_auto_repeat) + +def QImage_from_np(img): + if img.dtype != np.uint8: + raise ValueError("img should be in np.uint8 format") + + h,w,c = img.shape + if c == 1: + fmt = QImage.Format_Grayscale8 + elif c == 3: + fmt = QImage.Format_BGR888 + elif c == 4: + fmt = QImage.Format_ARGB32 + else: + raise ValueError("unsupported channel count") + + return QImage(img.data, w, h, c*w, fmt ) + +def QImage_to_np(q_img, fmt=QImage.Format_BGR888): + q_img = q_img.convertToFormat(fmt) + + width = q_img.width() + height = q_img.height() + + b = q_img.constBits() + b.setsize(height * width * 3) + arr = np.frombuffer(b, np.uint8).reshape((height, width, 3)) + return arr#[::-1] + +def QPixmap_from_np(img): + return QPixmap.fromImage(QImage_from_np(img)) + +def QPoint_from_np(n): + return QPoint(*n.astype(np.int)) + +def QPoint_to_np(q): + return np.int32( [q.x(), q.y()] ) + +def QSize_to_np(q): + return np.int32( [q.width(), q.height()] ) + +class QDarkPalette(QPalette): + def __init__(self): + super().__init__() + text_color = QColor(200,200,200) + self.setColor(QPalette.Window, QColor(53, 53, 53)) + self.setColor(QPalette.WindowText, text_color ) + self.setColor(QPalette.Base, QColor(25, 25, 25)) + self.setColor(QPalette.AlternateBase, QColor(53, 53, 53)) + self.setColor(QPalette.ToolTipBase, text_color ) + self.setColor(QPalette.ToolTipText, text_color ) + self.setColor(QPalette.Text, text_color ) + self.setColor(QPalette.Button, QColor(53, 53, 53)) + self.setColor(QPalette.ButtonText, Qt.white) + self.setColor(QPalette.BrightText, Qt.red) + self.setColor(QPalette.Link, QColor(42, 130, 218)) + self.setColor(QPalette.Highlight, QColor(42, 130, 218)) + self.setColor(QPalette.HighlightedText, Qt.black) \ No newline at end of file diff --git a/core/randomex.py b/core/randomex.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd18e25edd7139509ac97c1ada6a54e16531cdd --- /dev/null +++ b/core/randomex.py @@ -0,0 +1,16 @@ +import numpy as np + +def random_normal( size=(1,), trunc_val = 2.5, rnd_state=None ): + if rnd_state is None: + rnd_state = np.random + len = np.array(size).prod() + result = np.empty ( (len,) , dtype=np.float32) + + for i in range (len): + while True: + x = rnd_state.normal() + if x >= -trunc_val and x <= trunc_val: + break + result[i] = (x / trunc_val) + + return result.reshape ( size ) \ No newline at end of file diff --git a/core/stdex.py b/core/stdex.py new file mode 100644 index 0000000000000000000000000000000000000000..2f23be99ed0f1a526339e0918550a30a449eabcd --- /dev/null +++ b/core/stdex.py @@ -0,0 +1,36 @@ +import os +import sys + +class suppress_stdout_stderr(object): + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup ( sys.stdout.fileno() ) + self.old_stderr_fileno = os.dup ( sys.stderr.fileno() ) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2 ( self.outnull_file.fileno(), self.old_stdout_fileno_undup ) + os.dup2 ( self.errnull_file.fileno(), self.old_stderr_fileno_undup ) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2 ( self.old_stdout_fileno, self.old_stdout_fileno_undup ) + os.dup2 ( self.old_stderr_fileno, self.old_stderr_fileno_undup ) + + os.close ( self.old_stdout_fileno ) + os.close ( self.old_stderr_fileno ) + + self.outnull_file.close() + self.errnull_file.close() diff --git a/core/structex.py b/core/structex.py new file mode 100644 index 0000000000000000000000000000000000000000..cc63559febb2ca1f30f6ac0d23dbc94a947c3b10 --- /dev/null +++ b/core/structex.py @@ -0,0 +1,5 @@ +import struct + +def struct_unpack(data, counter, fmt): + fmt_size = struct.calcsize(fmt) + return (counter+fmt_size,) + struct.unpack (fmt, data[counter:counter+fmt_size]) diff --git a/doc/DFL_welcome.png b/doc/DFL_welcome.png new file mode 100644 index 0000000000000000000000000000000000000000..2e4e138bbd0327fef3718caa93dda156a4643a44 Binary files /dev/null and b/doc/DFL_welcome.png differ diff --git a/doc/deage_0_1.jpg b/doc/deage_0_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..51e057e1a0e6caf541dcec3aa6e91387d5426881 Binary files /dev/null and b/doc/deage_0_1.jpg differ diff --git a/doc/deage_0_2.jpg b/doc/deage_0_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..996cacd4c6f6a9ad58909a14bbf30b5e6f11037a Binary files /dev/null and b/doc/deage_0_2.jpg differ diff --git a/doc/deepfake_progress.png b/doc/deepfake_progress.png new file mode 100644 index 0000000000000000000000000000000000000000..51a409c7e16932247d489b02f84c883d00aa307c --- /dev/null +++ b/doc/deepfake_progress.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f37a44cf69f12d8c7fc0dc7983dff7095a56a89b480aa2435a3caa9ea10e174 +size 1050296 diff --git a/doc/deepfake_progress_source.psd b/doc/deepfake_progress_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..c1d51481a4064735f68fca8c688b0cd306b04c59 --- /dev/null +++ b/doc/deepfake_progress_source.psd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:29bb0568f0681b5aff6538a7f7bee7441534d6f7b6de1279027922a953652c0c +size 4180607 diff --git a/doc/head_replace_0_1.jpg b/doc/head_replace_0_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9125d91204e9a731cc6db3605f1496e81a5dcb80 Binary files /dev/null and b/doc/head_replace_0_1.jpg differ diff --git a/doc/head_replace_0_2.jpg b/doc/head_replace_0_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..14c23ec70bfa6c5cdc97117ed0a05ae5913f27a5 Binary files /dev/null and b/doc/head_replace_0_2.jpg differ diff --git a/doc/head_replace_1_1.jpg b/doc/head_replace_1_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..464bf504f5b9aa02eb55b7fb6a479021030b7045 Binary files /dev/null and b/doc/head_replace_1_1.jpg differ diff --git a/doc/head_replace_1_2.jpg b/doc/head_replace_1_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..14c845a42b133067d44de2cd0550850421e5a773 Binary files /dev/null and b/doc/head_replace_1_2.jpg differ diff --git a/doc/head_replace_2_1.jpg b/doc/head_replace_2_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a4478938b06cbdea9d608aacd4adb8f0e9258bba Binary files /dev/null and b/doc/head_replace_2_1.jpg differ diff --git a/doc/head_replace_2_2.jpg b/doc/head_replace_2_2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..dc4eaf38b563fc32ea5e0a1b5f236aea6c7b3e17 Binary files /dev/null and b/doc/head_replace_2_2.jpg differ diff --git a/doc/landmarks.jpg b/doc/landmarks.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1b9c275420a2c74c633775ca906c22fc38fb86e3 Binary files /dev/null and b/doc/landmarks.jpg differ diff --git a/doc/landmarks_98.jpg b/doc/landmarks_98.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f2b32abb2f3eaded239a1373aa83e54bd1437bd8 Binary files /dev/null and b/doc/landmarks_98.jpg differ diff --git a/doc/logo_cuda.png b/doc/logo_cuda.png new file mode 100644 index 0000000000000000000000000000000000000000..0b928a645c303597d191538efe215bdfdfedc28d Binary files /dev/null and b/doc/logo_cuda.png differ diff --git a/doc/logo_directx.png b/doc/logo_directx.png new file mode 100644 index 0000000000000000000000000000000000000000..f9fb10ae89ef6ef465703bf46e194802c3420988 Binary files /dev/null and b/doc/logo_directx.png differ diff --git a/doc/logo_python.png b/doc/logo_python.png new file mode 100644 index 0000000000000000000000000000000000000000..f2f1d9057cc516c391bc22c02fe3dafb900023ea Binary files /dev/null and b/doc/logo_python.png differ diff --git a/doc/logo_tensorflow.png b/doc/logo_tensorflow.png new file mode 100644 index 0000000000000000000000000000000000000000..1287842360c1aeedf00c140be840bffb3f9ae05f Binary files /dev/null and b/doc/logo_tensorflow.png differ diff --git a/doc/make_everything_ok.png b/doc/make_everything_ok.png new file mode 100644 index 0000000000000000000000000000000000000000..9a90c0da5b05838a186f4731c596c644cdb9f127 Binary files /dev/null and b/doc/make_everything_ok.png differ diff --git a/doc/meme1.jpg b/doc/meme1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..819d36d43aae19e4eb55f3c45ca30f3c2d3dd9c2 Binary files /dev/null and b/doc/meme1.jpg differ diff --git a/doc/meme2.jpg b/doc/meme2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9899c855b1dec3fec493879450318bfa8224a0af Binary files /dev/null and b/doc/meme2.jpg differ diff --git a/doc/meme3.jpg b/doc/meme3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ee794adef7e990c6954baf018f4e51f5b55e9f3 Binary files /dev/null and b/doc/meme3.jpg differ diff --git a/doc/mini_tutorial.jpg b/doc/mini_tutorial.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2243fd9568172ade037bdfe5a4b92352f717d09f Binary files /dev/null and b/doc/mini_tutorial.jpg differ diff --git a/doc/political_speech1.jpg b/doc/political_speech1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33ae2ab91f91da69e1d92132eee5eeb87060c35c Binary files /dev/null and b/doc/political_speech1.jpg differ diff --git a/doc/political_speech2.jpg b/doc/political_speech2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f170ee254c2833297e4ad433e2a2f4132204c2b0 Binary files /dev/null and b/doc/political_speech2.jpg differ diff --git a/doc/political_speech3.jpg b/doc/political_speech3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7da3a64a78e70359665861fee3e0bca15aaf9899 Binary files /dev/null and b/doc/political_speech3.jpg differ diff --git a/doc/replace_the_face.jpg b/doc/replace_the_face.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55501d06ec925e6da2395d394a9521cca39b8930 Binary files /dev/null and b/doc/replace_the_face.jpg differ diff --git a/doc/tiktok_icon.png b/doc/tiktok_icon.png new file mode 100644 index 0000000000000000000000000000000000000000..63d3e7e0a67061e48beb375ddbb8336c5c3b31d3 Binary files /dev/null and b/doc/tiktok_icon.png differ diff --git a/doc/youtube_icon.png b/doc/youtube_icon.png new file mode 100644 index 0000000000000000000000000000000000000000..dff95d4b3439c9e7b9b5047f5b3561cfe5e3cb4d Binary files /dev/null and b/doc/youtube_icon.png differ diff --git a/doc/~$nual_ru_source.docx b/doc/~$nual_ru_source.docx new file mode 100644 index 0000000000000000000000000000000000000000..35b65e72872b17333ff26278cfbcebcbc63f17c1 Binary files /dev/null and b/doc/~$nual_ru_source.docx differ diff --git a/facelib/2DFAN.npy b/facelib/2DFAN.npy new file mode 100644 index 0000000000000000000000000000000000000000..7d6d0e8d81fe6183c963eea07bb7d615feb98950 --- /dev/null +++ b/facelib/2DFAN.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca2dc7f0b2aa146842e6de2119fedff1142188b7cff5ab702564952d6cba4624 +size 95570245 diff --git a/facelib/3DFAN.npy b/facelib/3DFAN.npy new file mode 100644 index 0000000000000000000000000000000000000000..a710425b553806b77b6f837c90148f628fc383e9 --- /dev/null +++ b/facelib/3DFAN.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b50d2faf0fd4d6503aba9d19365e7aff06f2ea96bac37fc5f8f25a191a0a63a9 +size 95564988 diff --git a/facelib/FANExtractor.py b/facelib/FANExtractor.py new file mode 100644 index 0000000000000000000000000000000000000000..e71f3934092cb9ea8ddba487c93198994c711848 --- /dev/null +++ b/facelib/FANExtractor.py @@ -0,0 +1,280 @@ +import os +import traceback +from pathlib import Path + +import cv2 +import numpy as np +from numpy import linalg as npla + +from facelib import FaceType, LandmarksProcessor +from core.leras import nn + +""" +ported from https://github.com/1adrianb/face-alignment +""" +class FANExtractor(object): + def __init__ (self, landmarks_3D=False, place_model_on_cpu=False): + + model_path = Path(__file__).parent / ( "2DFAN.npy" if not landmarks_3D else "3DFAN.npy") + if not model_path.exists(): + raise Exception("Unable to load FANExtractor model") + + nn.initialize(data_format="NHWC") + tf = nn.tf + + class ConvBlock(nn.ModelBase): + def on_build(self, in_planes, out_planes): + self.in_planes = in_planes + self.out_planes = out_planes + + self.bn1 = nn.BatchNorm2D(in_planes) + self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + self.bn2 = nn.BatchNorm2D(out_planes//2) + self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + self.bn3 = nn.BatchNorm2D(out_planes//4) + self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + if self.in_planes != self.out_planes: + self.down_bn1 = nn.BatchNorm2D(in_planes) + self.down_conv1 = nn.Conv2D (in_planes, out_planes, kernel_size=1, strides=1, padding='VALID', use_bias=False ) + else: + self.down_bn1 = None + self.down_conv1 = None + + def forward(self, input): + x = input + x = self.bn1(x) + x = tf.nn.relu(x) + x = out1 = self.conv1(x) + + x = self.bn2(x) + x = tf.nn.relu(x) + x = out2 = self.conv2(x) + + x = self.bn3(x) + x = tf.nn.relu(x) + x = out3 = self.conv3(x) + + x = tf.concat ([out1, out2, out3], axis=-1) + + if self.in_planes != self.out_planes: + downsample = self.down_bn1(input) + downsample = tf.nn.relu (downsample) + downsample = self.down_conv1 (downsample) + x = x + downsample + else: + x = x + input + + return x + + class HourGlass (nn.ModelBase): + def on_build(self, in_planes, depth): + self.b1 = ConvBlock (in_planes, 256) + self.b2 = ConvBlock (in_planes, 256) + + if depth > 1: + self.b2_plus = HourGlass(256, depth-1) + else: + self.b2_plus = ConvBlock(256, 256) + + self.b3 = ConvBlock(256, 256) + + def forward(self, input): + up1 = self.b1(input) + + low1 = tf.nn.avg_pool(input, [1,2,2,1], [1,2,2,1], 'VALID') + low1 = self.b2 (low1) + + low2 = self.b2_plus(low1) + low3 = self.b3(low2) + + up2 = nn.upsample2d(low3) + + return up1+up2 + + class FAN (nn.ModelBase): + def __init__(self): + super().__init__(name='FAN') + + def on_build(self): + self.conv1 = nn.Conv2D (3, 64, kernel_size=7, strides=2, padding='SAME') + self.bn1 = nn.BatchNorm2D(64) + + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + self.m = [] + self.top_m = [] + self.conv_last = [] + self.bn_end = [] + self.l = [] + self.bl = [] + self.al = [] + for i in range(4): + self.m += [ HourGlass(256, 4) ] + self.top_m += [ ConvBlock(256, 256) ] + + self.conv_last += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ] + self.bn_end += [ nn.BatchNorm2D(256) ] + + self.l += [ nn.Conv2D (256, 68, kernel_size=1, strides=1, padding='VALID') ] + + if i < 4-1: + self.bl += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ] + self.al += [ nn.Conv2D (68, 256, kernel_size=1, strides=1, padding='VALID') ] + + def forward(self, inp) : + x, = inp + x = self.conv1(x) + x = self.bn1(x) + x = tf.nn.relu(x) + + x = self.conv2(x) + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], 'VALID') + x = self.conv3(x) + x = self.conv4(x) + + outputs = [] + previous = x + for i in range(4): + ll = self.m[i] (previous) + ll = self.top_m[i] (ll) + ll = self.conv_last[i] (ll) + ll = self.bn_end[i] (ll) + ll = tf.nn.relu(ll) + tmp_out = self.l[i](ll) + outputs.append(tmp_out) + if i < 4 - 1: + ll = self.bl[i](ll) + previous = previous + ll + self.al[i](tmp_out) + x = outputs[-1] + x = tf.transpose(x, (0,3,1,2) ) + return x + + e = None + if place_model_on_cpu: + e = tf.device("/CPU:0") + + if e is not None: e.__enter__() + self.model = FAN() + self.model.load_weights(str(model_path)) + if e is not None: e.__exit__(None,None,None) + + self.model.build_for_run ([ ( tf.float32, (None,256,256,3) ) ]) + + def extract (self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False): + if len(rects) == 0: + return [] + + if is_bgr: + input_image = input_image[:,:,::-1] + is_bgr = False + + (h, w, ch) = input_image.shape + + landmarks = [] + for (left, top, right, bottom) in rects: + scale = (right - left + bottom - top) / 195.0 + + center = np.array( [ (left + right) / 2.0, (top + bottom) / 2.0] ) + centers = [ center ] + + if multi_sample: + centers += [ center + [-1,-1], + center + [1,-1], + center + [1,1], + center + [-1,1], + ] + + images = [] + ptss = [] + + try: + for c in centers: + images += [ self.crop(input_image, c, scale) ] + + images = np.stack (images) + images = images.astype(np.float32) / 255.0 + + predicted = [] + for i in range( len(images) ): + predicted += [ self.model.run ( [ images[i][None,...] ] )[0] ] + + predicted = np.stack(predicted) + + for i, pred in enumerate(predicted): + ptss += [ self.get_pts_from_predict ( pred, centers[i], scale) ] + pts_img = np.mean ( np.array(ptss), 0 ) + + landmarks.append (pts_img) + except: + landmarks.append (None) + + if second_pass_extractor is not None: + for i, lmrks in enumerate(landmarks): + try: + if lmrks is not None: + image_to_face_mat = LandmarksProcessor.get_transform_mat (lmrks, 256, FaceType.FULL) + face_image = cv2.warpAffine(input_image, image_to_face_mat, (256, 256), cv2.INTER_CUBIC ) + + rects2 = second_pass_extractor.extract(face_image, is_bgr=is_bgr) + if len(rects2) == 1: #dont do second pass if faces != 1 detected in cropped image + lmrks2 = self.extract (face_image, [ rects2[0] ], is_bgr=is_bgr, multi_sample=True)[0] + landmarks[i] = LandmarksProcessor.transform_points (lmrks2, image_to_face_mat, True) + except: + pass + + return landmarks + + def transform(self, point, center, scale, resolution): + pt = np.array ( [point[0], point[1], 1.0] ) + h = 200.0 * scale + m = np.eye(3) + m[0,0] = resolution / h + m[1,1] = resolution / h + m[0,2] = resolution * ( -center[0] / h + 0.5 ) + m[1,2] = resolution * ( -center[1] / h + 0.5 ) + m = np.linalg.inv(m) + return np.matmul (m, pt)[0:2] + + def crop(self, image, center, scale, resolution=256.0): + ul = self.transform([1, 1], center, scale, resolution).astype( np.int ) + br = self.transform([resolution, resolution], center, scale, resolution).astype( np.int ) + + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array([max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array([max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR) + return newImg + + def get_pts_from_predict(self, a, center, scale): + a_ch, a_h, a_w = a.shape + + b = a.reshape ( (a_ch, a_h*a_w) ) + c = b.argmax(1).reshape ( (a_ch, 1) ).repeat(2, axis=1).astype(np.float) + c[:,0] %= a_w + c[:,1] = np.apply_along_axis ( lambda x: np.floor(x / a_w), 0, c[:,1] ) + + for i in range(a_ch): + pX, pY = int(c[i,0]), int(c[i,1]) + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = np.array ( [a[i,pY,pX+1]-a[i,pY,pX-1], a[i,pY+1,pX]-a[i,pY-1,pX]] ) + c[i] += np.sign(diff)*0.25 + + c += 0.5 + + return np.array( [ self.transform (c[i], center, scale, a_w) for i in range(a_ch) ] ) diff --git a/facelib/FaceEnhancer.npy b/facelib/FaceEnhancer.npy new file mode 100644 index 0000000000000000000000000000000000000000..be1ff7e9cedeb4499e3a33330960a4e709fa8712 --- /dev/null +++ b/facelib/FaceEnhancer.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:254958f67c9adfe97a0c9fc7b3c343ba490a1519c01862a50945fa228875476a +size 66227502 diff --git a/facelib/FaceEnhancer.py b/facelib/FaceEnhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc0dd9eaa56a4618ec9ed39c17807bff8e572ce --- /dev/null +++ b/facelib/FaceEnhancer.py @@ -0,0 +1,322 @@ +import operator +from pathlib import Path + +import cv2 +import numpy as np + +from core.leras import nn + +class FaceEnhancer(object): + """ + x4 face enhancer + """ + def __init__(self, place_model_on_cpu=False, run_on_cpu=False): + nn.initialize(data_format="NHWC") + tf = nn.tf + + class FaceEnhancer (nn.ModelBase): + def __init__(self, name='FaceEnhancer'): + super().__init__(name=name) + + def on_build(self): + self.conv1 = nn.Conv2D (3, 64, kernel_size=3, strides=1, padding='SAME') + + self.dense1 = nn.Dense (1, 64, use_bias=False) + self.dense2 = nn.Dense (1, 64, use_bias=False) + + self.e0_conv0 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME') + self.e0_conv1 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME') + + self.e1_conv0 = nn.Conv2D (64, 112, kernel_size=3, strides=1, padding='SAME') + self.e1_conv1 = nn.Conv2D (112, 112, kernel_size=3, strides=1, padding='SAME') + + self.e2_conv0 = nn.Conv2D (112, 192, kernel_size=3, strides=1, padding='SAME') + self.e2_conv1 = nn.Conv2D (192, 192, kernel_size=3, strides=1, padding='SAME') + + self.e3_conv0 = nn.Conv2D (192, 336, kernel_size=3, strides=1, padding='SAME') + self.e3_conv1 = nn.Conv2D (336, 336, kernel_size=3, strides=1, padding='SAME') + + self.e4_conv0 = nn.Conv2D (336, 512, kernel_size=3, strides=1, padding='SAME') + self.e4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.center_conv0 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv2 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv3 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d4_conv0 = nn.Conv2D (1024, 512, kernel_size=3, strides=1, padding='SAME') + self.d4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d3_conv0 = nn.Conv2D (848, 512, kernel_size=3, strides=1, padding='SAME') + self.d3_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d2_conv0 = nn.Conv2D (704, 288, kernel_size=3, strides=1, padding='SAME') + self.d2_conv1 = nn.Conv2D (288, 288, kernel_size=3, strides=1, padding='SAME') + + self.d1_conv0 = nn.Conv2D (400, 160, kernel_size=3, strides=1, padding='SAME') + self.d1_conv1 = nn.Conv2D (160, 160, kernel_size=3, strides=1, padding='SAME') + + self.d0_conv0 = nn.Conv2D (224, 96, kernel_size=3, strides=1, padding='SAME') + self.d0_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + + self.out1x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME') + self.out1x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME') + + self.dec2x_conv0 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + self.dec2x_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + + self.out2x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME') + self.out2x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME') + + self.dec4x_conv0 = nn.Conv2D (96, 72, kernel_size=3, strides=1, padding='SAME') + self.dec4x_conv1 = nn.Conv2D (72, 72, kernel_size=3, strides=1, padding='SAME') + + self.out4x_conv0 = nn.Conv2D (72, 36, kernel_size=3, strides=1, padding='SAME') + self.out4x_conv1 = nn.Conv2D (36, 3 , kernel_size=3, strides=1, padding='SAME') + + def forward(self, inp): + bgr, param, param1 = inp + + x = self.conv1(bgr) + a = self.dense1(param) + a = tf.reshape(a, (-1,1,1,64) ) + + b = self.dense2(param1) + b = tf.reshape(b, (-1,1,1,64) ) + + x = tf.nn.leaky_relu(x+a+b, 0.1) + + x = tf.nn.leaky_relu(self.e0_conv0(x), 0.1) + x = e0 = tf.nn.leaky_relu(self.e0_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e1_conv0(x), 0.1) + x = e1 = tf.nn.leaky_relu(self.e1_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e2_conv0(x), 0.1) + x = e2 = tf.nn.leaky_relu(self.e2_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e3_conv0(x), 0.1) + x = e3 = tf.nn.leaky_relu(self.e3_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e4_conv0(x), 0.1) + x = e4 = tf.nn.leaky_relu(self.e4_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.center_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv1(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv2(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv3(x), 0.1) + + x = tf.concat( [nn.resize2d_bilinear(x), e4], -1 ) + x = tf.nn.leaky_relu(self.d4_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d4_conv1(x), 0.1) + + x = tf.concat( [nn.resize2d_bilinear(x), e3], -1 ) + x = tf.nn.leaky_relu(self.d3_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d3_conv1(x), 0.1) + + x = tf.concat( [nn.resize2d_bilinear(x), e2], -1 ) + x = tf.nn.leaky_relu(self.d2_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d2_conv1(x), 0.1) + + x = tf.concat( [nn.resize2d_bilinear(x), e1], -1 ) + x = tf.nn.leaky_relu(self.d1_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d1_conv1(x), 0.1) + + x = tf.concat( [nn.resize2d_bilinear(x), e0], -1 ) + x = tf.nn.leaky_relu(self.d0_conv0(x), 0.1) + x = d0 = tf.nn.leaky_relu(self.d0_conv1(x), 0.1) + + x = tf.nn.leaky_relu(self.out1x_conv0(x), 0.1) + x = self.out1x_conv1(x) + out1x = bgr + tf.nn.tanh(x) + + x = d0 + x = tf.nn.leaky_relu(self.dec2x_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.dec2x_conv1(x), 0.1) + x = d2x = nn.resize2d_bilinear(x) + + x = tf.nn.leaky_relu(self.out2x_conv0(x), 0.1) + x = self.out2x_conv1(x) + + out2x = nn.resize2d_bilinear(out1x) + tf.nn.tanh(x) + + x = d2x + x = tf.nn.leaky_relu(self.dec4x_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.dec4x_conv1(x), 0.1) + x = d4x = nn.resize2d_bilinear(x) + + x = tf.nn.leaky_relu(self.out4x_conv0(x), 0.1) + x = self.out4x_conv1(x) + + out4x = nn.resize2d_bilinear(out2x) + tf.nn.tanh(x) + + return out4x + + model_path = Path(__file__).parent / "FaceEnhancer.npy" + if not model_path.exists(): + raise Exception("Unable to load FaceEnhancer.npy") + + with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): + self.model = FaceEnhancer() + self.model.load_weights (model_path) + + with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): + self.model.build_for_run ([ (tf.float32, nn.get4Dshape (192,192,3) ), + (tf.float32, (None,1,) ), + (tf.float32, (None,1,) ), + ]) + + def enhance (self, inp_img, is_tanh=False, preserve_size=True): + if not is_tanh: + inp_img = np.clip( inp_img * 2 -1, -1, 1 ) + + param = np.array([0.2]) + param1 = np.array([1.0]) + up_res = 4 + patch_size = 192 + patch_size_half = patch_size // 2 + + ih,iw,ic = inp_img.shape + h,w,c = ih,iw,ic + + th,tw = h*up_res, w*up_res + + t_padding = 0 + b_padding = 0 + l_padding = 0 + r_padding = 0 + + if h < patch_size: + t_padding = (patch_size-h)//2 + b_padding = (patch_size-h) - t_padding + + if w < patch_size: + l_padding = (patch_size-w)//2 + r_padding = (patch_size-w) - l_padding + + if t_padding != 0: + inp_img = np.concatenate ([ np.zeros ( (t_padding,w,c), dtype=np.float32 ), inp_img ], axis=0 ) + h,w,c = inp_img.shape + + if b_padding != 0: + inp_img = np.concatenate ([ inp_img, np.zeros ( (b_padding,w,c), dtype=np.float32 ) ], axis=0 ) + h,w,c = inp_img.shape + + if l_padding != 0: + inp_img = np.concatenate ([ np.zeros ( (h,l_padding,c), dtype=np.float32 ), inp_img ], axis=1 ) + h,w,c = inp_img.shape + + if r_padding != 0: + inp_img = np.concatenate ([ inp_img, np.zeros ( (h,r_padding,c), dtype=np.float32 ) ], axis=1 ) + h,w,c = inp_img.shape + + + i_max = w-patch_size+1 + j_max = h-patch_size+1 + + final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 ) + final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 ) + + x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] ) + x,y = np.meshgrid(x,x) + patch_mask = (x*y)[...,None] + + j=0 + while j < j_max: + i = 0 + while i < i_max: + patch_img = inp_img[j:j+patch_size, i:i+patch_size,:] + x = self.model.run( [ patch_img[None,...], [param], [param1] ] )[0] + final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask + final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask + if i == i_max-1: + break + i = min( i+patch_size_half, i_max-1) + if j == j_max-1: + break + j = min( j+patch_size_half, j_max-1) + + final_img_div[final_img_div==0] = 1.0 + final_img /= final_img_div + + if t_padding+b_padding+l_padding+r_padding != 0: + final_img = final_img [t_padding*up_res:(h-b_padding)*up_res, l_padding*up_res:(w-r_padding)*up_res,:] + + if preserve_size: + final_img = cv2.resize (final_img, (iw,ih), interpolation=cv2.INTER_LANCZOS4) + + if not is_tanh: + final_img = np.clip( final_img/2+0.5, 0, 1 ) + + return final_img + + +""" + + def enhance (self, inp_img, is_tanh=False, preserve_size=True): + if not is_tanh: + inp_img = np.clip( inp_img * 2 -1, -1, 1 ) + + param = np.array([0.2]) + param1 = np.array([1.0]) + up_res = 4 + patch_size = 192 + patch_size_half = patch_size // 2 + + h,w,c = inp_img.shape + + th,tw = h*up_res, w*up_res + + preupscale_rate = 1.0 + + if h < patch_size or w < patch_size: + preupscale_rate = 1.0 / ( max(h,w) / patch_size ) + + if preupscale_rate != 1.0: + inp_img = cv2.resize (inp_img, ( int(w*preupscale_rate), int(h*preupscale_rate) ), interpolation=cv2.INTER_LANCZOS4) + h,w,c = inp_img.shape + + i_max = w-patch_size+1 + j_max = h-patch_size+1 + + final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 ) + final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 ) + + x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] ) + x,y = np.meshgrid(x,x) + patch_mask = (x*y)[...,None] + + j=0 + while j < j_max: + i = 0 + while i < i_max: + patch_img = inp_img[j:j+patch_size, i:i+patch_size,:] + x = self.model.run( [ patch_img[None,...], [param], [param1] ] )[0] + final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask + final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask + if i == i_max-1: + break + i = min( i+patch_size_half, i_max-1) + if j == j_max-1: + break + j = min( j+patch_size_half, j_max-1) + + final_img_div[final_img_div==0] = 1.0 + final_img /= final_img_div + + if preserve_size: + final_img = cv2.resize (final_img, (w,h), interpolation=cv2.INTER_LANCZOS4) + else: + if preupscale_rate != 1.0: + final_img = cv2.resize (final_img, (tw,th), interpolation=cv2.INTER_LANCZOS4) + + if not is_tanh: + final_img = np.clip( final_img/2+0.5, 0, 1 ) + + return final_img +""" \ No newline at end of file diff --git a/facelib/FaceType.py b/facelib/FaceType.py new file mode 100644 index 0000000000000000000000000000000000000000..745cff320ee177bc4865a1c417a30e464f543e78 --- /dev/null +++ b/facelib/FaceType.py @@ -0,0 +1,37 @@ +from enum import IntEnum + +class FaceType(IntEnum): + #enumerating in order "next contains prev" + HALF = 0 + MID_FULL = 1 + FULL = 2 + FULL_NO_ALIGN = 3 + WHOLE_FACE = 4 + HEAD = 10 + HEAD_NO_ALIGN = 20 + + MARK_ONLY = 100, #no align at all, just embedded faceinfo + + @staticmethod + def fromString (s): + r = from_string_dict.get (s.lower()) + if r is None: + raise Exception ('FaceType.fromString value error') + return r + + @staticmethod + def toString (face_type): + return to_string_dict[face_type] + +to_string_dict = { FaceType.HALF : 'half_face', + FaceType.MID_FULL : 'midfull_face', + FaceType.FULL : 'full_face', + FaceType.FULL_NO_ALIGN : 'full_face_no_align', + FaceType.WHOLE_FACE : 'whole_face', + FaceType.HEAD : 'head', + FaceType.HEAD_NO_ALIGN : 'head_no_align', + + FaceType.MARK_ONLY :'mark_only', + } + +from_string_dict = { to_string_dict[x] : x for x in to_string_dict.keys() } \ No newline at end of file diff --git a/facelib/LandmarksProcessor.py b/facelib/LandmarksProcessor.py new file mode 100644 index 0000000000000000000000000000000000000000..8e5d51bfe9d5963852034a95157db683ed6455d4 --- /dev/null +++ b/facelib/LandmarksProcessor.py @@ -0,0 +1,900 @@ +import colorsys +import math +from enum import IntEnum + +import cv2 +import numpy as np +import numpy.linalg as npla + +from core import imagelib +from core import mathlib +from facelib import FaceType +from core.mathlib.umeyama import umeyama + +landmarks_2D = np.array([ +[ 0.000213256, 0.106454 ], #17 +[ 0.0752622, 0.038915 ], #18 +[ 0.18113, 0.0187482 ], #19 +[ 0.29077, 0.0344891 ], #20 +[ 0.393397, 0.0773906 ], #21 +[ 0.586856, 0.0773906 ], #22 +[ 0.689483, 0.0344891 ], #23 +[ 0.799124, 0.0187482 ], #24 +[ 0.904991, 0.038915 ], #25 +[ 0.98004, 0.106454 ], #26 +[ 0.490127, 0.203352 ], #27 +[ 0.490127, 0.307009 ], #28 +[ 0.490127, 0.409805 ], #29 +[ 0.490127, 0.515625 ], #30 +[ 0.36688, 0.587326 ], #31 +[ 0.426036, 0.609345 ], #32 +[ 0.490127, 0.628106 ], #33 +[ 0.554217, 0.609345 ], #34 +[ 0.613373, 0.587326 ], #35 +[ 0.121737, 0.216423 ], #36 +[ 0.187122, 0.178758 ], #37 +[ 0.265825, 0.179852 ], #38 +[ 0.334606, 0.231733 ], #39 +[ 0.260918, 0.245099 ], #40 +[ 0.182743, 0.244077 ], #41 +[ 0.645647, 0.231733 ], #42 +[ 0.714428, 0.179852 ], #43 +[ 0.793132, 0.178758 ], #44 +[ 0.858516, 0.216423 ], #45 +[ 0.79751, 0.244077 ], #46 +[ 0.719335, 0.245099 ], #47 +[ 0.254149, 0.780233 ], #48 +[ 0.340985, 0.745405 ], #49 +[ 0.428858, 0.727388 ], #50 +[ 0.490127, 0.742578 ], #51 +[ 0.551395, 0.727388 ], #52 +[ 0.639268, 0.745405 ], #53 +[ 0.726104, 0.780233 ], #54 +[ 0.642159, 0.864805 ], #55 +[ 0.556721, 0.902192 ], #56 +[ 0.490127, 0.909281 ], #57 +[ 0.423532, 0.902192 ], #58 +[ 0.338094, 0.864805 ], #59 +[ 0.290379, 0.784792 ], #60 +[ 0.428096, 0.778746 ], #61 +[ 0.490127, 0.785343 ], #62 +[ 0.552157, 0.778746 ], #63 +[ 0.689874, 0.784792 ], #64 +[ 0.553364, 0.824182 ], #65 +[ 0.490127, 0.831803 ], #66 +[ 0.42689 , 0.824182 ] #67 +], dtype=np.float32) + + +landmarks_2D_new = np.array([ +[ 0.000213256, 0.106454 ], #17 +[ 0.0752622, 0.038915 ], #18 +[ 0.18113, 0.0187482 ], #19 +[ 0.29077, 0.0344891 ], #20 +[ 0.393397, 0.0773906 ], #21 +[ 0.586856, 0.0773906 ], #22 +[ 0.689483, 0.0344891 ], #23 +[ 0.799124, 0.0187482 ], #24 +[ 0.904991, 0.038915 ], #25 +[ 0.98004, 0.106454 ], #26 +[ 0.490127, 0.203352 ], #27 +[ 0.490127, 0.307009 ], #28 +[ 0.490127, 0.409805 ], #29 +[ 0.490127, 0.515625 ], #30 +[ 0.36688, 0.587326 ], #31 +[ 0.426036, 0.609345 ], #32 +[ 0.490127, 0.628106 ], #33 +[ 0.554217, 0.609345 ], #34 +[ 0.613373, 0.587326 ], #35 +[ 0.121737, 0.216423 ], #36 +[ 0.187122, 0.178758 ], #37 +[ 0.265825, 0.179852 ], #38 +[ 0.334606, 0.231733 ], #39 +[ 0.260918, 0.245099 ], #40 +[ 0.182743, 0.244077 ], #41 +[ 0.645647, 0.231733 ], #42 +[ 0.714428, 0.179852 ], #43 +[ 0.793132, 0.178758 ], #44 +[ 0.858516, 0.216423 ], #45 +[ 0.79751, 0.244077 ], #46 +[ 0.719335, 0.245099 ], #47 +[ 0.254149, 0.780233 ], #48 +[ 0.726104, 0.780233 ], #54 +], dtype=np.float32) + +mouth_center_landmarks_2D = np.array([ + [-4.4202591e-07, 4.4916576e-01], #48 + [ 1.8399176e-01, 3.7537053e-01], #49 + [ 3.7018123e-01, 3.3719531e-01], #50 + [ 5.0000089e-01, 3.6938059e-01], #51 + [ 6.2981832e-01, 3.3719531e-01], #52 + [ 8.1600773e-01, 3.7537053e-01], #53 + [ 1.0000000e+00, 4.4916576e-01], #54 + [ 8.2213330e-01, 6.2836081e-01], #55 + [ 6.4110327e-01, 7.0757812e-01], #56 + [ 5.0000089e-01, 7.2259867e-01], #57 + [ 3.5889623e-01, 7.0757812e-01], #58 + [ 1.7786618e-01, 6.2836081e-01], #59 + [ 7.6765373e-02, 4.5882553e-01], #60 + [ 3.6856663e-01, 4.4601500e-01], #61 + [ 5.0000089e-01, 4.5999300e-01], #62 + [ 6.3143289e-01, 4.4601500e-01], #63 + [ 9.2323411e-01, 4.5882553e-01], #64 + [ 6.3399029e-01, 5.4228687e-01], #65 + [ 5.0000089e-01, 5.5843467e-01], #66 + [ 3.6601129e-01, 5.4228687e-01] #67 +], dtype=np.float32) + +# 68 point landmark definitions +landmarks_68_pt = { "mouth": (48,68), + "right_eyebrow": (17, 22), + "left_eyebrow": (22, 27), + "right_eye": (36, 42), + "left_eye": (42, 48), + "nose": (27, 36), # missed one point + "jaw": (0, 17) } + +landmarks_68_3D = np.array( [ +[-73.393523 , -29.801432 , 47.667532 ], #00 +[-72.775014 , -10.949766 , 45.909403 ], #01 +[-70.533638 , 7.929818 , 44.842580 ], #02 +[-66.850058 , 26.074280 , 43.141114 ], #03 +[-59.790187 , 42.564390 , 38.635298 ], #04 +[-48.368973 , 56.481080 , 30.750622 ], #05 +[-34.121101 , 67.246992 , 18.456453 ], #06 +[-17.875411 , 75.056892 , 3.609035 ], #07 +[0.098749 , 77.061286 , -0.881698 ], #08 +[17.477031 , 74.758448 , 5.181201 ], #09 +[32.648966 , 66.929021 , 19.176563 ], #10 +[46.372358 , 56.311389 , 30.770570 ], #11 +[57.343480 , 42.419126 , 37.628629 ], #12 +[64.388482 , 25.455880 , 40.886309 ], #13 +[68.212038 , 6.990805 , 42.281449 ], #14 +[70.486405 , -11.666193 , 44.142567 ], #15 +[71.375822 , -30.365191 , 47.140426 ], #16 +[-61.119406 , -49.361602 , 14.254422 ], #17 +[-51.287588 , -58.769795 , 7.268147 ], #18 +[-37.804800 , -61.996155 , 0.442051 ], #19 +[-24.022754 , -61.033399 , -6.606501 ], #20 +[-11.635713 , -56.686759 , -11.967398 ], #21 +[12.056636 , -57.391033 , -12.051204 ], #22 +[25.106256 , -61.902186 , -7.315098 ], #23 +[38.338588 , -62.777713 , -1.022953 ], #24 +[51.191007 , -59.302347 , 5.349435 ], #25 +[60.053851 , -50.190255 , 11.615746 ], #26 +[0.653940 , -42.193790 , -13.380835 ], #27 +[0.804809 , -30.993721 , -21.150853 ], #28 +[0.992204 , -19.944596 , -29.284036 ], #29 +[1.226783 , -8.414541 , -36.948060 ], #00 +[-14.772472 , 2.598255 , -20.132003 ], #01 +[-7.180239 , 4.751589 , -23.536684 ], #02 +[0.555920 , 6.562900 , -25.944448 ], #03 +[8.272499 , 4.661005 , -23.695741 ], #04 +[15.214351 , 2.643046 , -20.858157 ], #05 +[-46.047290 , -37.471411 , 7.037989 ], #06 +[-37.674688 , -42.730510 , 3.021217 ], #07 +[-27.883856 , -42.711517 , 1.353629 ], #08 +[-19.648268 , -36.754742 , -0.111088 ], #09 +[-28.272965 , -35.134493 , -0.147273 ], #10 +[-38.082418 , -34.919043 , 1.476612 ], #11 +[19.265868 , -37.032306 , -0.665746 ], #12 +[27.894191 , -43.342445 , 0.247660 ], #13 +[37.437529 , -43.110822 , 1.696435 ], #14 +[45.170805 , -38.086515 , 4.894163 ], #15 +[38.196454 , -35.532024 , 0.282961 ], #16 +[28.764989 , -35.484289 , -1.172675 ], #17 +[-28.916267 , 28.612716 , -2.240310 ], #18 +[-17.533194 , 22.172187 , -15.934335 ], #19 +[-6.684590 , 19.029051 , -22.611355 ], #20 +[0.381001 , 20.721118 , -23.748437 ], #21 +[8.375443 , 19.035460 , -22.721995 ], #22 +[18.876618 , 22.394109 , -15.610679 ], #23 +[28.794412 , 28.079924 , -3.217393 ], #24 +[19.057574 , 36.298248 , -14.987997 ], #25 +[8.956375 , 39.634575 , -22.554245 ], #26 +[0.381549 , 40.395647 , -23.591626 ], #27 +[-7.428895 , 39.836405 , -22.406106 ], #28 +[-18.160634 , 36.677899 , -15.121907 ], #29 +[-24.377490 , 28.677771 , -4.785684 ], #30 +[-6.897633 , 25.475976 , -20.893742 ], #31 +[0.340663 , 26.014269 , -22.220479 ], #32 +[8.444722 , 25.326198 , -21.025520 ], #33 +[24.474473 , 28.323008 , -5.712776 ], #34 +[8.449166 , 30.596216 , -20.671489 ], #35 +[0.205322 , 31.408738 , -21.903670 ], #36 +[-7.198266 , 30.844876 , -20.328022 ] #37 +], dtype=np.float32) + +FaceType_to_padding_remove_align = { + FaceType.HALF: (0.0, False), + FaceType.MID_FULL: (0.0675, False), + FaceType.FULL: (0.2109375, False), + FaceType.FULL_NO_ALIGN: (0.2109375, True), + FaceType.WHOLE_FACE: (0.40, False), + FaceType.HEAD: (0.70, False), + FaceType.HEAD_NO_ALIGN: (0.70, True), +} + +def convert_98_to_68(lmrks): + #jaw + result = [ lmrks[0] ] + for i in range(2,16,2): + result += [ ( lmrks[i] + (lmrks[i-1]+lmrks[i+1])/2 ) / 2 ] + result += [ lmrks[16] ] + for i in range(18,32,2): + result += [ ( lmrks[i] + (lmrks[i-1]+lmrks[i+1])/2 ) / 2 ] + result += [ lmrks[32] ] + + #eyebrows averaging + result += [ lmrks[33], + (lmrks[34]+lmrks[41])/2, + (lmrks[35]+lmrks[40])/2, + (lmrks[36]+lmrks[39])/2, + (lmrks[37]+lmrks[38])/2, + ] + + result += [ (lmrks[42]+lmrks[50])/2, + (lmrks[43]+lmrks[49])/2, + (lmrks[44]+lmrks[48])/2, + (lmrks[45]+lmrks[47])/2, + lmrks[46] + ] + + #nose + result += list ( lmrks[51:60] ) + + #left eye (from our view) + result += [ lmrks[60], + lmrks[61], + lmrks[63], + lmrks[64], + lmrks[65], + lmrks[67] ] + + #right eye + result += [ lmrks[68], + lmrks[69], + lmrks[71], + lmrks[72], + lmrks[73], + lmrks[75] ] + + #mouth + result += list ( lmrks[76:96] ) + + return np.concatenate (result).reshape ( (68,2) ) + +def transform_points(points, mat, invert=False): + if invert: + mat = cv2.invertAffineTransform (mat) + points = np.expand_dims(points, axis=1) + points = cv2.transform(points, mat, points.shape) + points = np.squeeze(points) + return points + +def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): + if not isinstance(image_landmarks, np.ndarray): + image_landmarks = np.array (image_landmarks) + + + # estimate landmarks transform from global space to local aligned space with bounds [0..1] + mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2] + + # get corner points in global space + g_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5) ]) , mat, True) + g_c = g_p[4] + + # calc diagonal vectors between corners in global space + tb_diag_vec = (g_p[2]-g_p[0]).astype(np.float32) + tb_diag_vec /= npla.norm(tb_diag_vec) + bt_diag_vec = (g_p[1]-g_p[3]).astype(np.float32) + bt_diag_vec /= npla.norm(bt_diag_vec) + + # calc modifier of diagonal vectors for scale and padding value + padding, remove_align = FaceType_to_padding_remove_align.get(face_type, 0.0) + mod = (1.0 / scale)* ( npla.norm(g_p[0]-g_p[2])*(padding*np.sqrt(2.0) + 0.5) ) + + if face_type == FaceType.WHOLE_FACE: + # adjust vertical offset for WHOLE_FACE, 7% below in order to cover more forehead + vec = (g_p[0]-g_p[3]).astype(np.float32) + vec_len = npla.norm(vec) + vec /= vec_len + g_c += vec*vec_len*0.07 + + elif face_type == FaceType.HEAD: + # assuming image_landmarks are 3D_Landmarks extracted for HEAD, + # adjust horizontal offset according to estimated yaw + yaw = estimate_averaged_yaw(transform_points (image_landmarks, mat, False)) + + hvec = (g_p[0]-g_p[1]).astype(np.float32) + hvec_len = npla.norm(hvec) + hvec /= hvec_len + + yaw *= np.abs(math.tanh(yaw*2)) # Damp near zero + + g_c -= hvec * (yaw * hvec_len / 2.0) + + # adjust vertical offset for HEAD, 50% below + vvec = (g_p[0]-g_p[3]).astype(np.float32) + vvec_len = npla.norm(vvec) + vvec /= vvec_len + g_c += vvec*vvec_len*0.50 + + # calc 3 points in global space to estimate 2d affine transform + if not remove_align: + l_t = np.array( [ g_c - tb_diag_vec*mod, + g_c + bt_diag_vec*mod, + g_c + tb_diag_vec*mod ] ) + else: + # remove_align - face will be centered in the frame but not aligned + l_t = np.array( [ g_c - tb_diag_vec*mod, + g_c + bt_diag_vec*mod, + g_c + tb_diag_vec*mod, + g_c - bt_diag_vec*mod, + ] ) + + # get area of face square in global space + area = mathlib.polygon_area(l_t[:,0], l_t[:,1] ) + + # calc side of square + side = np.float32(math.sqrt(area) / 2) + + # calc 3 points with unrotated square + l_t = np.array( [ g_c + [-side,-side], + g_c + [ side,-side], + g_c + [ side, side] ] ) + + # calc affine transform from 3 global space points to 3 local space points size of 'output_size' + pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) )) + mat = cv2.getAffineTransform(l_t,pts2) + return mat + +def get_rect_from_landmarks(image_landmarks): + mat = get_transform_mat(image_landmarks, 256, FaceType.FULL_NO_ALIGN) + + g_p = transform_points ( np.float32([(0,0),(255,255) ]) , mat, True) + + (l,t,r,b) = g_p[0][0], g_p[0][1], g_p[1][0], g_p[1][1] + + return (l,t,r,b) + +def expand_eyebrows(lmrks, eyebrows_expand_mod=1.0): + if len(lmrks) != 68: + raise Exception('works only with 68 landmarks') + lmrks = np.array( lmrks.copy(), dtype=np.int ) + + # #nose + ml_pnt = (lmrks[36] + lmrks[0]) // 2 + mr_pnt = (lmrks[16] + lmrks[45]) // 2 + + # mid points between the mid points and eye + ql_pnt = (lmrks[36] + ml_pnt) // 2 + qr_pnt = (lmrks[45] + mr_pnt) // 2 + + # Top of the eye arrays + bot_l = np.array((ql_pnt, lmrks[36], lmrks[37], lmrks[38], lmrks[39])) + bot_r = np.array((lmrks[42], lmrks[43], lmrks[44], lmrks[45], qr_pnt)) + + # Eyebrow arrays + top_l = lmrks[17:22] + top_r = lmrks[22:27] + + # Adjust eyebrow arrays + lmrks[17:22] = top_l + eyebrows_expand_mod * 0.5 * (top_l - bot_l) + lmrks[22:27] = top_r + eyebrows_expand_mod * 0.5 * (top_r - bot_r) + return lmrks + + + + +def get_image_hull_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0 ): + hull_mask = np.zeros(image_shape[0:2]+(1,),dtype=np.float32) + + lmrks = expand_eyebrows(image_landmarks, eyebrows_expand_mod) + + r_jaw = (lmrks[0:9], lmrks[17:18]) + l_jaw = (lmrks[8:17], lmrks[26:27]) + r_cheek = (lmrks[17:20], lmrks[8:9]) + l_cheek = (lmrks[24:27], lmrks[8:9]) + nose_ridge = (lmrks[19:25], lmrks[8:9],) + r_eye = (lmrks[17:22], lmrks[27:28], lmrks[31:36], lmrks[8:9]) + l_eye = (lmrks[22:27], lmrks[27:28], lmrks[31:36], lmrks[8:9]) + nose = (lmrks[27:31], lmrks[31:36]) + parts = [r_jaw, l_jaw, r_cheek, l_cheek, nose_ridge, r_eye, l_eye, nose] + + for item in parts: + merged = np.concatenate(item) + cv2.fillConvexPoly(hull_mask, cv2.convexHull(merged), (1,) ) + + return hull_mask + +def get_image_eye_mask (image_shape, image_landmarks): + if len(image_landmarks) != 68: + raise Exception('get_image_eye_mask works only with 68 landmarks') + + h,w,c = image_shape + + hull_mask = np.zeros( (h,w,1),dtype=np.float32) + + image_landmarks = image_landmarks.astype(np.int) + + cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[36:42]), (1,) ) + cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[42:48]), (1,) ) + + dilate = h // 32 + hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 ) + + blur = h // 16 + blur = blur + (1-blur % 2) + hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0) + hull_mask = hull_mask[...,None] + + return hull_mask + +def get_image_mouth_mask (image_shape, image_landmarks): + if len(image_landmarks) != 68: + raise Exception('get_image_eye_mask works only with 68 landmarks') + + h,w,c = image_shape + + hull_mask = np.zeros( (h,w,1),dtype=np.float32) + + image_landmarks = image_landmarks.astype(np.int) + + cv2.fillConvexPoly( hull_mask, cv2.convexHull( image_landmarks[60:]), (1,) ) + + dilate = h // 32 + hull_mask = cv2.dilate(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(dilate,dilate)), iterations = 1 ) + + blur = h // 16 + blur = blur + (1-blur % 2) + hull_mask = cv2.GaussianBlur(hull_mask, (blur, blur) , 0) + hull_mask = hull_mask[...,None] + + return hull_mask + +def alpha_to_color (img_alpha, color): + if len(img_alpha.shape) == 2: + img_alpha = img_alpha[...,None] + h,w,c = img_alpha.shape + result = np.zeros( (h,w, len(color) ), dtype=np.float32 ) + result[:,:] = color + + return result * img_alpha + + + +def get_cmask (image_shape, lmrks, eyebrows_expand_mod=1.0): + h,w,c = image_shape + + hull = get_image_hull_mask (image_shape, lmrks, eyebrows_expand_mod ) + + result = np.zeros( (h,w,3), dtype=np.float32 ) + + + + def process(w,h, data ): + d = {} + cur_lc = 0 + all_lines = [] + for s, pts_loop_ar in data: + lines = [] + for pts, loop in pts_loop_ar: + pts_len = len(pts) + lines.append ( [ [ pts[i], pts[(i+1) % pts_len ] ] for i in range(pts_len - (0 if loop else 1) ) ] ) + lines = np.concatenate (lines) + + lc = lines.shape[0] + all_lines.append(lines) + d[s] = cur_lc, cur_lc+lc + cur_lc += lc + all_lines = np.concatenate (all_lines, 0) + + #calculate signed distance for all points and lines + line_count = all_lines.shape[0] + pts_count = w*h + + all_lines = np.repeat ( all_lines[None,...], pts_count, axis=0 ).reshape ( (pts_count*line_count,2,2) ) + + pts = np.empty( (h,w,line_count,2), dtype=np.float32 ) + pts[...,1] = np.arange(h)[:,None,None] + pts[...,0] = np.arange(w)[:,None] + pts = pts.reshape ( (h*w*line_count, -1) ) + + a = all_lines[:,0,:] + b = all_lines[:,1,:] + pa = pts-a + ba = b-a + ph = np.clip ( np.einsum('ij,ij->i', pa, ba) / np.einsum('ij,ij->i', ba, ba), 0, 1 ) + dists = npla.norm ( pa - ba*ph[...,None], axis=1).reshape ( (h,w,line_count) ) + + def get_dists(name, thickness=0): + s,e = d[name] + result = dists[...,s:e] + if thickness != 0: + result = np.abs(result)-thickness + return np.min (result, axis=-1) + + return get_dists + + l_eye = lmrks[42:48] + r_eye = lmrks[36:42] + l_brow = lmrks[22:27] + r_brow = lmrks[17:22] + mouth = lmrks[48:60] + + up_nose = np.concatenate( (lmrks[27:31], lmrks[33:34]) ) + down_nose = lmrks[31:36] + nose = np.concatenate ( (up_nose, down_nose) ) + + gdf = process ( w,h, + ( + ('eyes', ((l_eye, True), (r_eye, True)) ), + ('brows', ((l_brow, False), (r_brow,False)) ), + ('up_nose', ((up_nose, False),) ), + ('down_nose', ((down_nose, False),) ), + ('mouth', ((mouth, True),) ), + ) + ) + + eyes_fall_dist = w // 32 + eyes_thickness = max( w // 64, 1 ) + + brows_fall_dist = w // 32 + brows_thickness = max( w // 256, 1 ) + + nose_fall_dist = w / 12 + nose_thickness = max( w // 96, 1 ) + + mouth_fall_dist = w // 32 + mouth_thickness = max( w // 64, 1 ) + + eyes_mask = gdf('eyes',eyes_thickness) + eyes_mask = 1-np.clip( eyes_mask/ eyes_fall_dist, 0, 1) + #eyes_mask = np.clip ( 1- ( np.sqrt( np.maximum(eyes_mask,0) ) / eyes_fall_dist ), 0, 1) + #eyes_mask = np.clip ( 1- ( np.cbrt( np.maximum(eyes_mask,0) ) / eyes_fall_dist ), 0, 1) + + brows_mask = gdf('brows', brows_thickness) + brows_mask = 1-np.clip( brows_mask / brows_fall_dist, 0, 1) + #brows_mask = np.clip ( 1- ( np.sqrt( np.maximum(brows_mask,0) ) / brows_fall_dist ), 0, 1) + + mouth_mask = gdf('mouth', mouth_thickness) + mouth_mask = 1-np.clip( mouth_mask / mouth_fall_dist, 0, 1) + #mouth_mask = np.clip ( 1- ( np.sqrt( np.maximum(mouth_mask,0) ) / mouth_fall_dist ), 0, 1) + + def blend(a,b,k): + x = np.clip ( 0.5+0.5*(b-a)/k, 0.0, 1.0 ) + return (a-b)*x+b - k*x*(1.0-x) + + + #nose_mask = (a-b)*x+b - k*x*(1.0-x) + + #nose_mask = np.minimum (up_nose_mask , down_nose_mask ) + #nose_mask = 1-np.clip( nose_mask / nose_fall_dist, 0, 1) + + nose_mask = blend ( gdf('up_nose', nose_thickness), gdf('down_nose', nose_thickness), nose_thickness*3 ) + nose_mask = 1-np.clip( nose_mask / nose_fall_dist, 0, 1) + + up_nose_mask = gdf('up_nose', nose_thickness) + up_nose_mask = 1-np.clip( up_nose_mask / nose_fall_dist, 0, 1) + #up_nose_mask = np.clip ( 1- ( np.cbrt( np.maximum(up_nose_mask,0) ) / nose_fall_dist ), 0, 1) + + down_nose_mask = gdf('down_nose', nose_thickness) + down_nose_mask = 1-np.clip( down_nose_mask / nose_fall_dist, 0, 1) + #down_nose_mask = np.clip ( 1- ( np.cbrt( np.maximum(down_nose_mask,0) ) / nose_fall_dist ), 0, 1) + + #nose_mask = np.clip( up_nose_mask + down_nose_mask, 0, 1 ) + #nose_mask /= np.max(nose_mask) + #nose_mask = np.maximum (up_nose_mask , down_nose_mask ) + #nose_mask = down_nose_mask + + #nose_mask = np.zeros_like(nose_mask) + + eyes_mask = eyes_mask * (1-mouth_mask) + nose_mask = nose_mask * (1-eyes_mask) + + hull_mask = hull[...,0].copy() + hull_mask = hull_mask * (1-eyes_mask) * (1-brows_mask) * (1-nose_mask) * (1-mouth_mask) + + #eyes_mask = eyes_mask * (1-nose_mask) + + mouth_mask= mouth_mask * (1-nose_mask) + + brows_mask = brows_mask * (1-nose_mask)* (1-eyes_mask ) + + hull_mask = alpha_to_color(hull_mask, (0,1,0) ) + eyes_mask = alpha_to_color(eyes_mask, (1,0,0) ) + brows_mask = alpha_to_color(brows_mask, (0,0,1) ) + nose_mask = alpha_to_color(nose_mask, (0,1,1) ) + mouth_mask = alpha_to_color(mouth_mask, (0,0,1) ) + + #nose_mask = np.maximum( up_nose_mask, down_nose_mask ) + + result = hull_mask + mouth_mask+ nose_mask + brows_mask + eyes_mask + result *= hull + #result = np.clip (result, 0, 1) + return result + +def blur_image_hull_mask (hull_mask): + + maxregion = np.argwhere(hull_mask==1.0) + miny,minx = maxregion.min(axis=0)[:2] + maxy,maxx = maxregion.max(axis=0)[:2] + lenx = maxx - minx; + leny = maxy - miny; + masky = int(minx+(lenx//2)) + maskx = int(miny+(leny//2)) + lowest_len = min (lenx, leny) + ero = int( lowest_len * 0.085 ) + blur = int( lowest_len * 0.10 ) + + hull_mask = cv2.erode(hull_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero,ero)), iterations = 1 ) + hull_mask = cv2.blur(hull_mask, (blur, blur) ) + hull_mask = np.expand_dims (hull_mask,-1) + + return hull_mask + +mirror_idxs = [ + [0,16], + [1,15], + [2,14], + [3,13], + [4,12], + [5,11], + [6,10], + [7,9], + + [17,26], + [18,25], + [19,24], + [20,23], + [21,22], + + [36,45], + [37,44], + [38,43], + [39,42], + [40,47], + [41,46], + + [31,35], + [32,34], + + [50,52], + [49,53], + [48,54], + [59,55], + [58,56], + [67,65], + [60,64], + [61,63] ] + +def mirror_landmarks (landmarks, val): + result = landmarks.copy() + + for idx in mirror_idxs: + result [ idx ] = result [ idx[::-1] ] + + result[:,0] = val - result[:,0] - 1 + return result + +def get_face_struct_mask (image_shape, image_landmarks, eyebrows_expand_mod=1.0, color=(1,) ): + mask = np.zeros(image_shape[0:2]+( len(color),),dtype=np.float32) + lmrks = expand_eyebrows(image_landmarks, eyebrows_expand_mod) + draw_landmarks (mask, image_landmarks, color=color, draw_circles=False, thickness=2) + return mask + +def draw_landmarks (image, image_landmarks, color=(0,255,0), draw_circles=True, thickness=1, transparent_mask=False): + if len(image_landmarks) != 68: + raise Exception('get_image_eye_mask works only with 68 landmarks') + + int_lmrks = np.array(image_landmarks, dtype=np.int) + + jaw = int_lmrks[slice(*landmarks_68_pt["jaw"])] + right_eyebrow = int_lmrks[slice(*landmarks_68_pt["right_eyebrow"])] + left_eyebrow = int_lmrks[slice(*landmarks_68_pt["left_eyebrow"])] + mouth = int_lmrks[slice(*landmarks_68_pt["mouth"])] + right_eye = int_lmrks[slice(*landmarks_68_pt["right_eye"])] + left_eye = int_lmrks[slice(*landmarks_68_pt["left_eye"])] + nose = int_lmrks[slice(*landmarks_68_pt["nose"])] + + # open shapes + cv2.polylines(image, tuple(np.array([v]) for v in ( right_eyebrow, jaw, left_eyebrow, np.concatenate((nose, [nose[-6]])) )), + False, color, thickness=thickness, lineType=cv2.LINE_AA) + # closed shapes + cv2.polylines(image, tuple(np.array([v]) for v in (right_eye, left_eye, mouth)), + True, color, thickness=thickness, lineType=cv2.LINE_AA) + + if draw_circles: + # the rest of the cicles + for x, y in np.concatenate((right_eyebrow, left_eyebrow, mouth, right_eye, left_eye, nose), axis=0): + cv2.circle(image, (x, y), 1, color, 1, lineType=cv2.LINE_AA) + # jaw big circles + for x, y in jaw: + cv2.circle(image, (x, y), 2, color, lineType=cv2.LINE_AA) + + if transparent_mask: + mask = get_image_hull_mask (image.shape, image_landmarks) + image[...] = ( image * (1-mask) + image * mask / 2 )[...] + +def draw_rect_landmarks (image, rect, image_landmarks, face_type, face_size=256, transparent_mask=False, landmarks_color=(0,255,0)): + draw_landmarks(image, image_landmarks, color=landmarks_color, transparent_mask=transparent_mask) + imagelib.draw_rect (image, rect, (255,0,0), 2 ) + + image_to_face_mat = get_transform_mat (image_landmarks, face_size, face_type) + points = transform_points ( [ (0,0), (0,face_size-1), (face_size-1, face_size-1), (face_size-1,0) ], image_to_face_mat, True) + imagelib.draw_polygon (image, points, (0,0,255), 2) + + points = transform_points ( [ ( int(face_size*0.05), 0), ( int(face_size*0.1), int(face_size*0.1) ), ( 0, int(face_size*0.1) ) ], image_to_face_mat, True) + imagelib.draw_polygon (image, points, (0,0,255), 2) + +def calc_face_pitch(landmarks): + if not isinstance(landmarks, np.ndarray): + landmarks = np.array (landmarks) + t = ( (landmarks[6][1]-landmarks[8][1]) + (landmarks[10][1]-landmarks[8][1]) ) / 2.0 + b = landmarks[8][1] + return float(b-t) + +def estimate_averaged_yaw(landmarks): + # Works much better than solvePnP if landmarks from "3DFAN" + if not isinstance(landmarks, np.ndarray): + landmarks = np.array (landmarks) + l = ( (landmarks[27][0]-landmarks[0][0]) + (landmarks[28][0]-landmarks[1][0]) + (landmarks[29][0]-landmarks[2][0]) ) / 3.0 + r = ( (landmarks[16][0]-landmarks[27][0]) + (landmarks[15][0]-landmarks[28][0]) + (landmarks[14][0]-landmarks[29][0]) ) / 3.0 + return float(r-l) + +def estimate_pitch_yaw_roll(aligned_landmarks, size=256): + """ + returns pitch,yaw,roll [-pi/2...+pi/2] + """ + shape = (size,size) + focal_length = shape[1] + camera_center = (shape[1] / 2, shape[0] / 2) + camera_matrix = np.array( + [[focal_length, 0, camera_center[0]], + [0, focal_length, camera_center[1]], + [0, 0, 1]], dtype=np.float32) + + (_, rotation_vector, _) = cv2.solvePnP( + np.concatenate( (landmarks_68_3D[:27], landmarks_68_3D[30:36]) , axis=0) , + np.concatenate( (aligned_landmarks[:27], aligned_landmarks[30:36]) , axis=0).astype(np.float32), + camera_matrix, + np.zeros((4, 1)) ) + + pitch, yaw, roll = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] ) + + half_pi = math.pi / 2.0 + pitch = np.clip ( pitch, -half_pi, half_pi ) + yaw = np.clip ( yaw , -half_pi, half_pi ) + roll = np.clip ( roll, -half_pi, half_pi ) + + return -pitch, yaw, roll + +#if remove_align: +# bbox = transform_points ( [ (0,0), (0,output_size), (output_size, output_size), (output_size,0) ], mat, True) +# #import code +# #code.interact(local=dict(globals(), **locals())) +# area = mathlib.polygon_area(bbox[:,0], bbox[:,1] ) +# side = math.sqrt(area) / 2 +# center = transform_points ( [(output_size/2,output_size/2)], mat, True) +# pts1 = np.float32(( center+[-side,-side], center+[side,-side], center+[side,-side] )) +# pts2 = np.float32([[0,0],[output_size,0],[0,output_size]]) +# mat = cv2.getAffineTransform(pts1,pts2) +#if full_face_align_top and (face_type == FaceType.FULL or face_type == FaceType.FULL_NO_ALIGN): +# #lmrks2 = expand_eyebrows(image_landmarks) +# #lmrks2_ = transform_points( [ lmrks2[19], lmrks2[24] ], mat, False ) +# #y_diff = np.float32( (0,np.min(lmrks2_[:,1])) ) +# #y_diff = transform_points( [ np.float32( (0,0) ), y_diff], mat, True) +# #y_diff = y_diff[1]-y_diff[0] +# +# x_diff = np.float32((0,0)) +# +# lmrks2_ = transform_points( [ image_landmarks[0], image_landmarks[16] ], mat, False ) +# if lmrks2_[0,0] < 0: +# x_diff = lmrks2_[0,0] +# x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True) +# x_diff = x_diff[1]-x_diff[0] +# elif lmrks2_[1,0] >= output_size: +# x_diff = lmrks2_[1,0]-(output_size-1) +# x_diff = transform_points( [ np.float32( (0,0) ), np.float32((x_diff,0)) ], mat, True) +# x_diff = x_diff[1]-x_diff[0] +# +# mat = cv2.getAffineTransform( l_t+y_diff+x_diff ,pts2) + + +""" +def get_averaged_transform_mat (img_landmarks, + img_landmarks_prev, + img_landmarks_next, + average_frame_count, + average_center_frame_count, + output_size, face_type, scale=1.0): + + l_c_list = [] + tb_diag_vec_list = [] + bt_diag_vec_list = [] + mod_list = [] + + count = max(average_frame_count,average_center_frame_count) + for i in range ( -count, count+1, 1 ): + if i < 0: + lmrks = img_landmarks_prev[i] if -i < len(img_landmarks_prev) else None + elif i > 0: + lmrks = img_landmarks_next[i] if i < len(img_landmarks_next) else None + else: + lmrks = img_landmarks + + if lmrks is None: + continue + + l_c, tb_diag_vec, bt_diag_vec, mod = get_transform_mat_data (lmrks, face_type, scale=scale) + + if i >= -average_frame_count and i <= average_frame_count: + tb_diag_vec_list.append(tb_diag_vec) + bt_diag_vec_list.append(bt_diag_vec) + mod_list.append(mod) + + if i >= -average_center_frame_count and i <= average_center_frame_count: + l_c_list.append(l_c) + + tb_diag_vec = np.mean( np.array(tb_diag_vec_list), axis=0 ) + bt_diag_vec = np.mean( np.array(bt_diag_vec_list), axis=0 ) + mod = np.mean( np.array(mod_list), axis=0 ) + l_c = np.mean( np.array(l_c_list), axis=0 ) + + return get_transform_mat_by_data (l_c, tb_diag_vec, bt_diag_vec, mod, output_size, face_type) + + +def get_transform_mat (image_landmarks, output_size, face_type, scale=1.0): + if not isinstance(image_landmarks, np.ndarray): + image_landmarks = np.array (image_landmarks) + + # get face padding value for FaceType + padding, remove_align = FaceType_to_padding_remove_align.get(face_type, 0.0) + + # estimate landmarks transform from global space to local aligned space with bounds [0..1] + mat = umeyama( np.concatenate ( [ image_landmarks[17:49] , image_landmarks[54:55] ] ) , landmarks_2D_new, True)[0:2] + + # get corner points in global space + l_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5)]) , mat, True) + l_c = l_p[4] + + # calc diagonal vectors between corners in global space + tb_diag_vec = (l_p[2]-l_p[0]).astype(np.float32) + tb_diag_vec /= npla.norm(tb_diag_vec) + bt_diag_vec = (l_p[1]-l_p[3]).astype(np.float32) + bt_diag_vec /= npla.norm(bt_diag_vec) + + # calc modifier of diagonal vectors for scale and padding value + mod = (1.0 / scale)* ( npla.norm(l_p[0]-l_p[2])*(padding*np.sqrt(2.0) + 0.5) ) + + # calc 3 points in global space to estimate 2d affine transform + if not remove_align: + l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ), + np.round( l_c + bt_diag_vec*mod ), + np.round( l_c + tb_diag_vec*mod ) ] ) + else: + # remove_align - face will be centered in the frame but not aligned + l_t = np.array( [ np.round( l_c - tb_diag_vec*mod ), + np.round( l_c + bt_diag_vec*mod ), + np.round( l_c + tb_diag_vec*mod ), + np.round( l_c - bt_diag_vec*mod ), + ] ) + + # get area of face square in global space + area = mathlib.polygon_area(l_t[:,0], l_t[:,1] ) + + # calc side of square + side = np.float32(math.sqrt(area) / 2) + + # calc 3 points with unrotated square + l_t = np.array( [ np.round( l_c + [-side,-side] ), + np.round( l_c + [ side,-side] ), + np.round( l_c + [ side, side] ) ] ) + + # calc affine transform from 3 global space points to 3 local space points size of 'output_size' + pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) )) + mat = cv2.getAffineTransform(l_t,pts2) + + return mat +""" \ No newline at end of file diff --git a/facelib/S3FD.npy b/facelib/S3FD.npy new file mode 100644 index 0000000000000000000000000000000000000000..94ba64b2037ee038ab8f7b1a8b19a933b4bcc5c1 --- /dev/null +++ b/facelib/S3FD.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b4894ecfba8e6461eb1a69490f76184620c816605238ff0ba3216f1143a06c29 +size 89840843 diff --git a/facelib/S3FDExtractor.py b/facelib/S3FDExtractor.py new file mode 100644 index 0000000000000000000000000000000000000000..0e743e387ca10579c34522d59e2f4d298390371e --- /dev/null +++ b/facelib/S3FDExtractor.py @@ -0,0 +1,269 @@ +import operator +from pathlib import Path + +import cv2 +import numpy as np + +from core.leras import nn + +class S3FDExtractor(object): + def __init__(self, place_model_on_cpu=False): + nn.initialize(data_format="NHWC") + tf = nn.tf + + model_path = Path(__file__).parent / "S3FD.npy" + if not model_path.exists(): + raise Exception("Unable to load S3FD.npy") + + class L2Norm(nn.LayerBase): + def __init__(self, n_channels, **kwargs): + self.n_channels = n_channels + super().__init__(**kwargs) + + def build_weights(self): + self.weight = tf.get_variable ("weight", (1, 1, 1, self.n_channels), dtype=nn.floatx, initializer=tf.initializers.ones ) + + def get_weights(self): + return [self.weight] + + def __call__(self, inputs): + x = inputs + x = x / (tf.sqrt( tf.reduce_sum( tf.pow(x, 2), axis=-1, keepdims=True ) ) + 1e-10) * self.weight + return x + + class S3FD(nn.ModelBase): + def __init__(self): + super().__init__(name='S3FD') + + def on_build(self): + self.minus = tf.constant([104,117,123], dtype=nn.floatx ) + self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, strides=1, padding='SAME') + self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, strides=1, padding='SAME') + + self.conv2_1 = nn.Conv2D(64, 128, kernel_size=3, strides=1, padding='SAME') + self.conv2_2 = nn.Conv2D(128, 128, kernel_size=3, strides=1, padding='SAME') + + self.conv3_1 = nn.Conv2D(128, 256, kernel_size=3, strides=1, padding='SAME') + self.conv3_2 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME') + self.conv3_3 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME') + + self.conv4_1 = nn.Conv2D(256, 512, kernel_size=3, strides=1, padding='SAME') + self.conv4_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv4_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + + self.conv5_1 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv5_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv5_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + + self.fc6 = nn.Conv2D(512, 1024, kernel_size=3, strides=1, padding=3) + self.fc7 = nn.Conv2D(1024, 1024, kernel_size=1, strides=1, padding='SAME') + + self.conv6_1 = nn.Conv2D(1024, 256, kernel_size=1, strides=1, padding='SAME') + self.conv6_2 = nn.Conv2D(256, 512, kernel_size=3, strides=2, padding='SAME') + + self.conv7_1 = nn.Conv2D(512, 128, kernel_size=1, strides=1, padding='SAME') + self.conv7_2 = nn.Conv2D(128, 256, kernel_size=3, strides=2, padding='SAME') + + self.conv3_3_norm = L2Norm(256) + self.conv4_3_norm = L2Norm(512) + self.conv5_3_norm = L2Norm(512) + + + self.conv3_3_norm_mbox_conf = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + self.conv3_3_norm_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv4_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv4_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv5_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv5_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.fc7_mbox_conf = nn.Conv2D(1024, 2, kernel_size=3, strides=1, padding='SAME') + self.fc7_mbox_loc = nn.Conv2D(1024, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv6_2_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv6_2_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv7_2_mbox_conf = nn.Conv2D(256, 2, kernel_size=3, strides=1, padding='SAME') + self.conv7_2_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + + def forward(self, inp): + x, = inp + x = x - self.minus + x = tf.nn.relu(self.conv1_1(x)) + x = tf.nn.relu(self.conv1_2(x)) + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv2_1(x)) + x = tf.nn.relu(self.conv2_2(x)) + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv3_1(x)) + x = tf.nn.relu(self.conv3_2(x)) + x = tf.nn.relu(self.conv3_3(x)) + f3_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv4_1(x)) + x = tf.nn.relu(self.conv4_2(x)) + x = tf.nn.relu(self.conv4_3(x)) + f4_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv5_1(x)) + x = tf.nn.relu(self.conv5_2(x)) + x = tf.nn.relu(self.conv5_3(x)) + f5_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.fc6(x)) + x = tf.nn.relu(self.fc7(x)) + ffc7 = x + + x = tf.nn.relu(self.conv6_1(x)) + x = tf.nn.relu(self.conv6_2(x)) + f6_2 = x + + x = tf.nn.relu(self.conv7_1(x)) + x = tf.nn.relu(self.conv7_2(x)) + f7_2 = x + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + + cls2 = tf.nn.softmax(self.conv4_3_norm_mbox_conf(f4_3)) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + + cls3 = tf.nn.softmax(self.conv5_3_norm_mbox_conf(f5_3)) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + + cls4 = tf.nn.softmax(self.fc7_mbox_conf(ffc7)) + reg4 = self.fc7_mbox_loc(ffc7) + + cls5 = tf.nn.softmax(self.conv6_2_mbox_conf(f6_2)) + reg5 = self.conv6_2_mbox_loc(f6_2) + + cls6 = tf.nn.softmax(self.conv7_2_mbox_conf(f7_2)) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + bmax = tf.maximum(tf.maximum(cls1[:,:,:,0:1], cls1[:,:,:,1:2]), cls1[:,:,:,2:3]) + + cls1 = tf.concat ([bmax, cls1[:,:,:,3:4] ], axis=-1) + cls1 = tf.nn.softmax(cls1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] + + e = None + if place_model_on_cpu: + e = tf.device("/CPU:0") + + if e is not None: e.__enter__() + self.model = S3FD() + self.model.load_weights (model_path) + if e is not None: e.__exit__(None,None,None) + + self.model.build_for_run ([ ( tf.float32, nn.get4Dshape (None,None,3) ) ]) + + def __enter__(self): + return self + + def __exit__(self, exc_type=None, exc_value=None, traceback=None): + return False #pass exception between __enter__ and __exit__ to outter level + + def extract (self, input_image, is_bgr=True, is_remove_intersects=False): + + if is_bgr: + input_image = input_image[:,:,::-1] + is_bgr = False + + (h, w, ch) = input_image.shape + + d = max(w, h) + scale_to = 640 if d >= 1280 else d / 2 + scale_to = max(64, scale_to) + + input_scale = d / scale_to + input_image = cv2.resize (input_image, ( int(w/input_scale), int(h/input_scale) ), interpolation=cv2.INTER_LINEAR) + + olist = self.model.run ([ input_image[None,...] ] ) + + detected_faces = [] + for ltrb in self.refine (olist): + l,t,r,b = [ x*input_scale for x in ltrb] + bt = b-t + if min(r-l,bt) < 40: #filtering faces < 40pix by any side + continue + b += bt*0.1 #enlarging bottom line a bit for 2DFAN-4, because default is not enough covering a chin + detected_faces.append ( [int(x) for x in (l,t,r,b) ] ) + + #sort by largest area first + detected_faces = [ [(l,t,r,b), (r-l)*(b-t) ] for (l,t,r,b) in detected_faces ] + detected_faces = sorted(detected_faces, key=operator.itemgetter(1), reverse=True ) + detected_faces = [ x[0] for x in detected_faces] + + if is_remove_intersects: + for i in range( len(detected_faces)-1, 0, -1): + l1,t1,r1,b1 = detected_faces[i] + l0,t0,r0,b0 = detected_faces[i-1] + + dx = min(r0, r1) - max(l0, l1) + dy = min(b0, b1) - max(t0, t1) + if (dx>=0) and (dy>=0): + detected_faces.pop(i) + + return detected_faces + + def refine(self, olist): + bboxlist = [] + for i, ((ocls,), (oreg,)) in enumerate ( zip ( olist[::2], olist[1::2] ) ): + stride = 2**(i + 2) # 4,8,16,32,64,128 + s_d2 = stride / 2 + s_m4 = stride * 4 + + for hindex, windex in zip(*np.where(ocls[...,1] > 0.05)): + score = ocls[hindex, windex, 1] + loc = oreg[hindex, windex, :] + priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4]) + priors_2p = priors[2:] + box = np.concatenate((priors[:2] + loc[:2] * 0.1 * priors_2p, + priors_2p * np.exp(loc[2:] * 0.2)) ) + box[:2] -= box[2:] / 2 + box[2:] += box[:2] + + bboxlist.append([*box, score]) + + bboxlist = np.array(bboxlist) + if len(bboxlist) == 0: + bboxlist = np.zeros((1, 5)) + + bboxlist = bboxlist[self.refine_nms(bboxlist, 0.3), :] + bboxlist = [ x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5] + return bboxlist + + def refine_nms(self, dets, thresh): + keep = list() + if len(dets) == 0: + return keep + + x_1, y_1, x_2, y_2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x_2 - x_1 + 1) * (y_2 - y_1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx_1, yy_1 = np.maximum(x_1[i], x_1[order[1:]]), np.maximum(y_1[i], y_1[order[1:]]) + xx_2, yy_2 = np.minimum(x_2[i], x_2[order[1:]]), np.minimum(y_2[i], y_2[order[1:]]) + + width, height = np.maximum(0.0, xx_2 - xx_1 + 1), np.maximum(0.0, yy_2 - yy_1 + 1) + ovr = width * height / (areas[i] + areas[order[1:]] - width * height) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep diff --git a/facelib/XSegNet.py b/facelib/XSegNet.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2bd0885eb02539efe0eb4d0fd7f7c8eb8053c5 --- /dev/null +++ b/facelib/XSegNet.py @@ -0,0 +1,108 @@ +import os +import pickle +from functools import partial +from pathlib import Path + +import cv2 +import numpy as np + +from core.interact import interact as io +from core.leras import nn + + +class XSegNet(object): + VERSION = 1 + + def __init__ (self, name, + resolution=256, + load_weights=True, + weights_file_root=None, + training=False, + place_model_on_cpu=False, + run_on_cpu=False, + optimizer=None, + data_format="NHWC", + raise_on_no_model_files=False): + + self.resolution = resolution + self.weights_file_root = Path(weights_file_root) if weights_file_root is not None else Path(__file__).parent + + nn.initialize(data_format=data_format) + tf = nn.tf + + model_name = f'{name}_{resolution}' + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.input_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,3) ) + self.target_t = tf.placeholder (nn.floatx, nn.get4Dshape(resolution,resolution,1) ) + + # Initializing model classes + with tf.device ('/CPU:0' if place_model_on_cpu else nn.tf_default_device_name): + self.model = nn.XSeg(3, 32, 1, name=name) + self.model_weights = self.model.get_weights() + if training: + if optimizer is None: + raise ValueError("Optimizer should be provided for training mode.") + self.opt = optimizer + self.opt.initialize_variables (self.model_weights, vars_on_cpu=place_model_on_cpu) + self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ] + + + self.model_filename_list += [ [self.model, f'{model_name}.npy'] ] + + if not training: + with tf.device ('/CPU:0' if run_on_cpu else nn.tf_default_device_name): + _, pred = self.model(self.input_t) + + def net_run(input_np): + return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] + self.net_run = net_run + + self.initialized = True + # Loading/initializing all models/optimizers weights + for model, filename in self.model_filename_list: + do_init = not load_weights + + if not do_init: + model_file_path = self.weights_file_root / filename + do_init = not model.load_weights( model_file_path ) + if do_init: + if raise_on_no_model_files: + raise Exception(f'{model_file_path} does not exists.') + if not training: + self.initialized = False + break + + if do_init: + model.init_weights() + + def get_resolution(self): + return self.resolution + + def flow(self, x, pretrain=False): + return self.model(x, pretrain=pretrain) + + def get_weights(self): + return self.model_weights + + def save_weights(self): + for model, filename in io.progress_bar_generator(self.model_filename_list, "Saving", leave=False): + model.save_weights( self.weights_file_root / filename ) + + def extract (self, input_image): + if not self.initialized: + return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype ) + + input_shape_len = len(input_image.shape) + if input_shape_len == 3: + input_image = input_image[None,...] + + result = np.clip ( self.net_run(input_image), 0, 1.0 ) + result[result < 0.1] = 0 #get rid of noise + + if input_shape_len == 3: + result = result[0] + + return result \ No newline at end of file diff --git a/facelib/__init__.py b/facelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e46ca518c59fd87cd9bdb97b4429e47a60fd8cf7 --- /dev/null +++ b/facelib/__init__.py @@ -0,0 +1,5 @@ +from .FaceType import FaceType +from .S3FDExtractor import S3FDExtractor +from .FANExtractor import FANExtractor +from .FaceEnhancer import FaceEnhancer +from .XSegNet import XSegNet \ No newline at end of file diff --git a/localization/__init__.py b/localization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd8c6e515d22f0e7aef97af01ef14c99e8024ca --- /dev/null +++ b/localization/__init__.py @@ -0,0 +1,2 @@ +from .localization import StringsDB, system_language, get_default_ttf_font_name + diff --git a/localization/localization.py b/localization/localization.py new file mode 100644 index 0000000000000000000000000000000000000000..3df7bbd270e0355938d2689f3861dd3a55c72603 --- /dev/null +++ b/localization/localization.py @@ -0,0 +1,42 @@ +import sys +import locale + +system_locale = locale.getdefaultlocale()[0] +# system_locale may be nil +system_language = system_locale[0:2] if system_locale is not None else "en" +if system_language not in ['en','ru','zh']: + system_language = 'en' + +windows_font_name_map = { + 'en' : 'cour', + 'ru' : 'cour', + 'zh' : 'simsun_01' +} + +darwin_font_name_map = { + 'en' : 'cour', + 'ru' : 'cour', + 'zh' : 'Apple LiSung Light' +} + +linux_font_name_map = { + 'en' : 'cour', + 'ru' : 'cour', + 'zh' : 'cour' +} + +def get_default_ttf_font_name(): + platform = sys.platform + if platform[0:3] == 'win': return windows_font_name_map.get(system_language, 'cour') + elif platform == 'darwin': return darwin_font_name_map.get(system_language, 'cour') + else: return linux_font_name_map.get(system_language, 'cour') + +SID_HOT_KEY = 1 + +if system_language == 'en': + StringsDB = {'S_HOT_KEY' : 'hot key'} +elif system_language == 'ru': + StringsDB = {'S_HOT_KEY' : 'горячая клавиша'} +elif system_language == 'zh': + StringsDB = {'S_HOT_KEY' : '热键'} + \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..493c7edd60dde6ae1d1673383ef9f2e069f57eb8 --- /dev/null +++ b/main.py @@ -0,0 +1,353 @@ +if __name__ == "__main__": + # Fix for linux + import multiprocessing + multiprocessing.set_start_method("spawn") + + from core.leras import nn + nn.initialize_main_env() + import os + import sys + import time + import argparse + + from core import pathex + from core import osex + from pathlib import Path + from core.interact import interact as io + + if sys.version_info[0] < 3 or (sys.version_info[0] == 3 and sys.version_info[1] < 6): + raise Exception("This program requires at least Python 3.6") + + class fixPathAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, os.path.abspath(os.path.expanduser(values))) + + exit_code = 0 + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers() + + def process_extract(arguments): + osex.set_process_lowest_prio() + from mainscripts import Extractor + Extractor.main( detector = arguments.detector, + input_path = Path(arguments.input_dir), + output_path = Path(arguments.output_dir), + output_debug = arguments.output_debug, + manual_fix = arguments.manual_fix, + manual_output_debug_fix = arguments.manual_output_debug_fix, + manual_window_size = arguments.manual_window_size, + face_type = arguments.face_type, + max_faces_from_image = arguments.max_faces_from_image, + image_size = arguments.image_size, + jpeg_quality = arguments.jpeg_quality, + cpu_only = arguments.cpu_only, + force_gpu_idxs = [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None, + ) + + p = subparsers.add_parser( "extract", help="Extract the faces from a pictures.") + p.add_argument('--detector', dest="detector", choices=['s3fd','manual'], default=None, help="Type of detector.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") + p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the extracted files will be stored.") + p.add_argument('--output-debug', action="store_true", dest="output_debug", default=None, help="Writes debug images to _debug\ directory.") + p.add_argument('--no-output-debug', action="store_false", dest="output_debug", default=None, help="Don't writes debug images to _debug\ directory.") + p.add_argument('--face-type', dest="face_type", choices=['half_face', 'full_face', 'whole_face', 'head', 'mark_only'], default=None) + p.add_argument('--max-faces-from-image', type=int, dest="max_faces_from_image", default=None, help="Max faces from image.") + p.add_argument('--image-size', type=int, dest="image_size", default=None, help="Output image size.") + p.add_argument('--jpeg-quality', type=int, dest="jpeg_quality", default=None, help="Jpeg quality.") + p.add_argument('--manual-fix', action="store_true", dest="manual_fix", default=False, help="Enables manual extract only frames where faces were not recognized.") + p.add_argument('--manual-output-debug-fix', action="store_true", dest="manual_output_debug_fix", default=False, help="Performs manual reextract input-dir frames which were deleted from [output_dir]_debug\ dir.") + p.add_argument('--manual-window-size', type=int, dest="manual_window_size", default=1368, help="Manual fix window size. Default: 1368.") + p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Extract on CPU..") + p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + + p.set_defaults (func=process_extract) + + def process_sort(arguments): + osex.set_process_lowest_prio() + from mainscripts import Sorter + Sorter.main (input_path=Path(arguments.input_dir), sort_by_method=arguments.sort_by_method) + + p = subparsers.add_parser( "sort", help="Sort faces in a directory.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") + p.add_argument('--by', dest="sort_by_method", default=None, choices=("blur", "motion-blur", "face-yaw", "face-pitch", "face-source-rect-size", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final-by-blur", "final-by-size", "absdiff"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." ) + p.set_defaults (func=process_sort) + + def process_util(arguments): + osex.set_process_lowest_prio() + from mainscripts import Util + + if arguments.add_landmarks_debug_images: + Util.add_landmarks_debug_images (input_path=arguments.input_dir) + + if arguments.recover_original_aligned_filename: + Util.recover_original_aligned_filename (input_path=arguments.input_dir) + + if arguments.save_faceset_metadata: + Util.save_faceset_metadata_folder (input_path=arguments.input_dir) + + if arguments.restore_faceset_metadata: + Util.restore_faceset_metadata_folder (input_path=arguments.input_dir) + + if arguments.pack_faceset: + io.log_info ("Performing faceset packing...\r\n") + from samplelib import PackedFaceset + PackedFaceset.pack( Path(arguments.input_dir) ) + + if arguments.unpack_faceset: + io.log_info ("Performing faceset unpacking...\r\n") + from samplelib import PackedFaceset + PackedFaceset.unpack( Path(arguments.input_dir) ) + + p = subparsers.add_parser( "util", help="Utilities.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") + p.add_argument('--add-landmarks-debug-images', action="store_true", dest="add_landmarks_debug_images", default=False, help="Add landmarks debug image for aligned faces.") + p.add_argument('--recover-original-aligned-filename', action="store_true", dest="recover_original_aligned_filename", default=False, help="Recover original aligned filename.") + p.add_argument('--save-faceset-metadata', action="store_true", dest="save_faceset_metadata", default=False, help="Save faceset metadata to file.") + p.add_argument('--restore-faceset-metadata', action="store_true", dest="restore_faceset_metadata", default=False, help="Restore faceset metadata to file. Image filenames must be the same as used with save.") + p.add_argument('--pack-faceset', action="store_true", dest="pack_faceset", default=False, help="") + p.add_argument('--unpack-faceset', action="store_true", dest="unpack_faceset", default=False, help="") + + p.set_defaults (func=process_util) + + def process_train(arguments): + osex.set_process_lowest_prio() + + + kwargs = {'model_class_name' : arguments.model_name, + 'saved_models_path' : Path(arguments.model_dir), + 'training_data_src_path' : Path(arguments.training_data_src_dir), + 'training_data_dst_path' : Path(arguments.training_data_dst_dir), + 'pretraining_data_path' : Path(arguments.pretraining_data_dir) if arguments.pretraining_data_dir is not None else None, + 'pretrained_model_path' : Path(arguments.pretrained_model_dir) if arguments.pretrained_model_dir is not None else None, + 'no_preview' : arguments.no_preview, + 'force_model_name' : arguments.force_model_name, + 'force_gpu_idxs' : [ int(x) for x in arguments.force_gpu_idxs.split(',') ] if arguments.force_gpu_idxs is not None else None, + 'cpu_only' : arguments.cpu_only, + 'silent_start' : arguments.silent_start, + 'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ], + 'debug' : arguments.debug, + } + from mainscripts import Trainer + Trainer.main(**kwargs) + + p = subparsers.add_parser( "train", help="Trainer") + p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.") + p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.") + p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None, help="Optional dir of extracted faceset that will be used in pretraining mode.") + p.add_argument('--pretrained-model-dir', action=fixPathAction, dest="pretrained_model_dir", default=None, help="Optional dir of pretrain model files. (Currently only for Quick96).") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.") + p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.") + p.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.") + p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.") + p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.") + p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.") + p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.") + + p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+') + p.set_defaults (func=process_train) + + def process_exportdfm(arguments): + osex.set_process_lowest_prio() + from mainscripts import ExportDFM + ExportDFM.main(model_class_name = arguments.model_name, saved_models_path = Path(arguments.model_dir)) + + p = subparsers.add_parser( "exportdfm", help="Export model to use in DeepFaceLive.") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Saved models dir.") + p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.") + p.set_defaults (func=process_exportdfm) + + def process_merge(arguments): + osex.set_process_lowest_prio() + from mainscripts import Merger + Merger.main ( model_class_name = arguments.model_name, + saved_models_path = Path(arguments.model_dir), + force_model_name = arguments.force_model_name, + input_path = Path(arguments.input_dir), + output_path = Path(arguments.output_dir), + output_mask_path = Path(arguments.output_mask_dir), + aligned_path = Path(arguments.aligned_dir) if arguments.aligned_dir is not None else None, + force_gpu_idxs = arguments.force_gpu_idxs, + cpu_only = arguments.cpu_only) + + p = subparsers.add_parser( "merge", help="Merger") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") + p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the merged files will be stored.") + p.add_argument('--output-mask-dir', required=True, action=fixPathAction, dest="output_mask_dir", help="Output mask directory. This is where the mask files will be stored.") + p.add_argument('--aligned-dir', action=fixPathAction, dest="aligned_dir", default=None, help="Aligned directory. This is where the extracted of dst faces stored.") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.") + p.add_argument('--model', required=True, dest="model_name", choices=pathex.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Model class name.") + p.add_argument('--force-model-name', dest="force_model_name", default=None, help="Forcing to choose model name from model/ folder.") + p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Merge on CPU.") + p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + p.set_defaults(func=process_merge) + + videoed_parser = subparsers.add_parser( "videoed", help="Video processing.").add_subparsers() + + def process_videoed_extract_video(arguments): + osex.set_process_lowest_prio() + from mainscripts import VideoEd + VideoEd.extract_video (arguments.input_file, arguments.output_dir, arguments.output_ext, arguments.fps) + p = videoed_parser.add_parser( "extract-video", help="Extract images from video file.") + p.add_argument('--input-file', required=True, action=fixPathAction, dest="input_file", help="Input file to be processed. Specify .*-extension to find first file.") + p.add_argument('--output-dir', required=True, action=fixPathAction, dest="output_dir", help="Output directory. This is where the extracted images will be stored.") + p.add_argument('--output-ext', dest="output_ext", default=None, help="Image format (extension) of output files.") + p.add_argument('--fps', type=int, dest="fps", default=None, help="How many frames of every second of the video will be extracted. 0 - full fps.") + p.set_defaults(func=process_videoed_extract_video) + + def process_videoed_cut_video(arguments): + osex.set_process_lowest_prio() + from mainscripts import VideoEd + VideoEd.cut_video (arguments.input_file, + arguments.from_time, + arguments.to_time, + arguments.audio_track_id, + arguments.bitrate) + p = videoed_parser.add_parser( "cut-video", help="Cut video file.") + p.add_argument('--input-file', required=True, action=fixPathAction, dest="input_file", help="Input file to be processed. Specify .*-extension to find first file.") + p.add_argument('--from-time', dest="from_time", default=None, help="From time, for example 00:00:00.000") + p.add_argument('--to-time', dest="to_time", default=None, help="To time, for example 00:00:00.000") + p.add_argument('--audio-track-id', type=int, dest="audio_track_id", default=None, help="Specify audio track id.") + p.add_argument('--bitrate', type=int, dest="bitrate", default=None, help="Bitrate of output file in Megabits.") + p.set_defaults(func=process_videoed_cut_video) + + def process_videoed_denoise_image_sequence(arguments): + osex.set_process_lowest_prio() + from mainscripts import VideoEd + VideoEd.denoise_image_sequence (arguments.input_dir, arguments.factor) + p = videoed_parser.add_parser( "denoise-image-sequence", help="Denoise sequence of images, keeping sharp edges. Helps to remove pixel shake from the predicted face.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory to be processed.") + p.add_argument('--factor', type=int, dest="factor", default=None, help="Denoise factor (1-20).") + p.set_defaults(func=process_videoed_denoise_image_sequence) + + def process_videoed_video_from_sequence(arguments): + osex.set_process_lowest_prio() + from mainscripts import VideoEd + VideoEd.video_from_sequence (input_dir = arguments.input_dir, + output_file = arguments.output_file, + reference_file = arguments.reference_file, + ext = arguments.ext, + fps = arguments.fps, + bitrate = arguments.bitrate, + include_audio = arguments.include_audio, + lossless = arguments.lossless) + + p = videoed_parser.add_parser( "video-from-sequence", help="Make video from image sequence.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input file to be processed. Specify .*-extension to find first file.") + p.add_argument('--output-file', required=True, action=fixPathAction, dest="output_file", help="Input file to be processed. Specify .*-extension to find first file.") + p.add_argument('--reference-file', action=fixPathAction, dest="reference_file", help="Reference file used to determine proper FPS and transfer audio from it. Specify .*-extension to find first file.") + p.add_argument('--ext', dest="ext", default='png', help="Image format (extension) of input files.") + p.add_argument('--fps', type=int, dest="fps", default=None, help="FPS of output file. Overwritten by reference-file.") + p.add_argument('--bitrate', type=int, dest="bitrate", default=None, help="Bitrate of output file in Megabits.") + p.add_argument('--include-audio', action="store_true", dest="include_audio", default=False, help="Include audio from reference file.") + p.add_argument('--lossless', action="store_true", dest="lossless", default=False, help="PNG codec.") + + p.set_defaults(func=process_videoed_video_from_sequence) + + facesettool_parser = subparsers.add_parser( "facesettool", help="Faceset tools.").add_subparsers() + + def process_faceset_enhancer(arguments): + osex.set_process_lowest_prio() + from mainscripts import FacesetEnhancer + FacesetEnhancer.process_folder ( Path(arguments.input_dir), + cpu_only=arguments.cpu_only, + force_gpu_idxs=arguments.force_gpu_idxs + ) + + p = facesettool_parser.add_parser ("enhance", help="Enhance details in DFL faceset.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Process on CPU.") + p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.") + + p.set_defaults(func=process_faceset_enhancer) + + + p = facesettool_parser.add_parser ("resize", help="Resize DFL faceset.") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory of aligned faces.") + + def process_faceset_resizer(arguments): + osex.set_process_lowest_prio() + from mainscripts import FacesetResizer + FacesetResizer.process_folder ( Path(arguments.input_dir) ) + p.set_defaults(func=process_faceset_resizer) + + def process_dev_test(arguments): + osex.set_process_lowest_prio() + from mainscripts import dev_misc + dev_misc.dev_test( arguments.input_dir ) + + p = subparsers.add_parser( "dev_test", help="") + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_dev_test) + + # ========== XSeg + xseg_parser = subparsers.add_parser( "xseg", help="XSeg tools.").add_subparsers() + + p = xseg_parser.add_parser( "editor", help="XSeg editor.") + + def process_xsegeditor(arguments): + osex.set_process_lowest_prio() + from XSegEditor import XSegEditor + global exit_code + exit_code = XSegEditor.start (Path(arguments.input_dir)) + + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + + p.set_defaults (func=process_xsegeditor) + + p = xseg_parser.add_parser( "apply", help="Apply trained XSeg model to the extracted faces.") + + def process_xsegapply(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.apply_xseg (Path(arguments.input_dir), Path(arguments.model_dir)) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir") + p.set_defaults (func=process_xsegapply) + + + p = xseg_parser.add_parser( "remove", help="Remove applied XSeg masks from the extracted faces.") + def process_xsegremove(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.remove_xseg (Path(arguments.input_dir) ) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegremove) + + + p = xseg_parser.add_parser( "remove_labels", help="Remove XSeg labels from the extracted faces.") + def process_xsegremovelabels(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.remove_xseg_labels (Path(arguments.input_dir) ) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegremovelabels) + + + p = xseg_parser.add_parser( "fetch", help="Copies faces containing XSeg polygons in _xseg dir.") + + def process_xsegfetch(arguments): + osex.set_process_lowest_prio() + from mainscripts import XSegUtil + XSegUtil.fetch_xseg (Path(arguments.input_dir) ) + p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir") + p.set_defaults (func=process_xsegfetch) + + def bad_args(arguments): + parser.print_help() + exit(0) + parser.set_defaults(func=bad_args) + + arguments = parser.parse_args() + arguments.func(arguments) + + if exit_code == 0: + print ("Done.") + + exit(exit_code) + +''' +import code +code.interact(local=dict(globals(), **locals())) +''' diff --git a/mainscripts/ExportDFM.py b/mainscripts/ExportDFM.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7d64e34205af14ad203eef8d6f5be09e0d879d --- /dev/null +++ b/mainscripts/ExportDFM.py @@ -0,0 +1,22 @@ +import os +import sys +import traceback +import queue +import threading +import time +import numpy as np +import itertools +from pathlib import Path +from core import pathex +from core import imagelib +import cv2 +import models +from core.interact import interact as io + + +def main(model_class_name, saved_models_path): + model = models.import_model(model_class_name)( + is_exporting=True, + saved_models_path=saved_models_path, + cpu_only=True) + model.export_dfm () diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..365804fced3ad3130604ba3778feaba2ac01bc18 --- /dev/null +++ b/mainscripts/Extractor.py @@ -0,0 +1,845 @@ +import traceback +import math +import multiprocessing +import operator +import os +import shutil +import sys +import time +from pathlib import Path + +import cv2 +import numpy as np +from numpy import linalg as npla + +import facelib +from core import imagelib +from core import mathlib +from facelib import FaceType, LandmarksProcessor +from core.interact import interact as io +from core.joblib import Subprocessor +from core.leras import nn +from core import pathex +from core.cv2ex import * +from DFLIMG import * + +DEBUG = False + +class ExtractSubprocessor(Subprocessor): + class Data(object): + def __init__(self, filepath=None, rects=None, landmarks = None, landmarks_accurate=True, manual=False, force_output_path=None, final_output_files = None): + self.filepath = filepath + self.rects = rects or [] + self.rects_rotation = 0 + self.landmarks_accurate = landmarks_accurate + self.manual = manual + self.landmarks = landmarks or [] + self.force_output_path = force_output_path + self.final_output_files = final_output_files or [] + self.faces_detected = 0 + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + self.type = client_dict['type'] + self.image_size = client_dict['image_size'] + self.jpeg_quality = client_dict['jpeg_quality'] + self.face_type = client_dict['face_type'] + self.max_faces_from_image = client_dict['max_faces_from_image'] + self.device_idx = client_dict['device_idx'] + self.cpu_only = client_dict['device_type'] == 'CPU' + self.final_output_path = client_dict['final_output_path'] + self.output_debug_path = client_dict['output_debug_path'] + + #transfer and set stdin in order to work code.interact in debug subprocess + stdin_fd = client_dict['stdin_fd'] + if stdin_fd is not None and DEBUG: + sys.stdin = os.fdopen(stdin_fd) + + if self.cpu_only: + device_config = nn.DeviceConfig.CPU() + place_model_on_cpu = True + else: + device_config = nn.DeviceConfig.GPUIndexes ([self.device_idx]) + place_model_on_cpu = device_config.devices[0].total_mem_gb < 4 + + if self.type == 'all' or 'rects' in self.type or 'landmarks' in self.type: + nn.initialize (device_config) + + self.log_info (f"Running on {client_dict['device_name'] }") + + if self.type == 'all' or self.type == 'rects-s3fd' or 'landmarks' in self.type: + self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu) + + if self.type == 'all' or 'landmarks' in self.type: + # for head type, extract "3D landmarks" + self.landmarks_extractor = facelib.FANExtractor(landmarks_3D=self.face_type >= FaceType.HEAD, + place_model_on_cpu=place_model_on_cpu) + + self.cached_image = (None, None) + + #override + def process_data(self, data): + if 'landmarks' in self.type and len(data.rects) == 0: + return data + + filepath = data.filepath + cached_filepath, image = self.cached_image + if cached_filepath != filepath: + image = cv2_imread( filepath ) + if image is None: + self.log_err (f'Failed to open {filepath}, reason: cv2_imread() fail.') + return data + image = imagelib.normalize_channels(image, 3) + image = imagelib.cut_odd_image(image) + self.cached_image = ( filepath, image ) + + h, w, c = image.shape + + if 'rects' in self.type or self.type == 'all': + data = ExtractSubprocessor.Cli.rects_stage (data=data, + image=image, + max_faces_from_image=self.max_faces_from_image, + rects_extractor=self.rects_extractor, + ) + + if 'landmarks' in self.type or self.type == 'all': + data = ExtractSubprocessor.Cli.landmarks_stage (data=data, + image=image, + landmarks_extractor=self.landmarks_extractor, + rects_extractor=self.rects_extractor, + ) + + if self.type == 'final' or self.type == 'all': + data = ExtractSubprocessor.Cli.final_stage(data=data, + image=image, + face_type=self.face_type, + image_size=self.image_size, + jpeg_quality=self.jpeg_quality, + output_debug_path=self.output_debug_path, + final_output_path=self.final_output_path, + ) + return data + + @staticmethod + def rects_stage(data, + image, + max_faces_from_image, + rects_extractor, + ): + h,w,c = image.shape + if min(h,w) < 128: + # Image is too small + data.rects = [] + else: + for rot in ([0, 90, 270, 180]): + if rot == 0: + rotated_image = image + elif rot == 90: + rotated_image = image.swapaxes( 0,1 )[:,::-1,:] + elif rot == 180: + rotated_image = image[::-1,::-1,:] + elif rot == 270: + rotated_image = image.swapaxes( 0,1 )[::-1,:,:] + rects = data.rects = rects_extractor.extract (rotated_image, is_bgr=True) + if len(rects) != 0: + data.rects_rotation = rot + break + if max_faces_from_image is not None and \ + max_faces_from_image > 0 and \ + len(data.rects) > 0: + data.rects = data.rects[0:max_faces_from_image] + return data + + + @staticmethod + def landmarks_stage(data, + image, + landmarks_extractor, + rects_extractor, + ): + h, w, ch = image.shape + + if data.rects_rotation == 0: + rotated_image = image + elif data.rects_rotation == 90: + rotated_image = image.swapaxes( 0,1 )[:,::-1,:] + elif data.rects_rotation == 180: + rotated_image = image[::-1,::-1,:] + elif data.rects_rotation == 270: + rotated_image = image.swapaxes( 0,1 )[::-1,:,:] + + data.landmarks = landmarks_extractor.extract (rotated_image, data.rects, rects_extractor if (data.landmarks_accurate) else None, is_bgr=True) + if data.rects_rotation != 0: + for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)): + new_rect, new_lmrks = rect, lmrks + (l,t,r,b) = rect + if data.rects_rotation == 90: + new_rect = ( t, h-l, b, h-r) + if lmrks is not None: + new_lmrks = lmrks[:,::-1].copy() + new_lmrks[:,1] = h - new_lmrks[:,1] + elif data.rects_rotation == 180: + if lmrks is not None: + new_rect = ( w-l, h-t, w-r, h-b) + new_lmrks = lmrks.copy() + new_lmrks[:,0] = w - new_lmrks[:,0] + new_lmrks[:,1] = h - new_lmrks[:,1] + elif data.rects_rotation == 270: + new_rect = ( w-b, l, w-t, r ) + if lmrks is not None: + new_lmrks = lmrks[:,::-1].copy() + new_lmrks[:,0] = w - new_lmrks[:,0] + data.rects[i], data.landmarks[i] = new_rect, new_lmrks + + return data + + @staticmethod + def final_stage(data, + image, + face_type, + image_size, + jpeg_quality, + output_debug_path=None, + final_output_path=None, + ): + data.final_output_files = [] + filepath = data.filepath + rects = data.rects + landmarks = data.landmarks + + if output_debug_path is not None: + debug_image = image.copy() + + face_idx = 0 + for rect, image_landmarks in zip( rects, landmarks ): + if image_landmarks is None: + continue + + rect = np.array(rect) + + if face_type == FaceType.MARK_ONLY: + image_to_face_mat = None + face_image = image + face_image_landmarks = image_landmarks + else: + image_to_face_mat = LandmarksProcessor.get_transform_mat (image_landmarks, image_size, face_type) + + face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4) + face_image_landmarks = LandmarksProcessor.transform_points (image_landmarks, image_to_face_mat) + + landmarks_bbox = LandmarksProcessor.transform_points ( [ (0,0), (0,image_size-1), (image_size-1, image_size-1), (image_size-1,0) ], image_to_face_mat, True) + + rect_area = mathlib.polygon_area(np.array(rect[[0,2,2,0]]).astype(np.float32), np.array(rect[[1,1,3,3]]).astype(np.float32)) + landmarks_area = mathlib.polygon_area(landmarks_bbox[:,0].astype(np.float32), landmarks_bbox[:,1].astype(np.float32) ) + + if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4*rect_area: #get rid of faces which umeyama-landmark-area > 4*detector-rect-area + continue + + if output_debug_path is not None: + LandmarksProcessor.draw_rect_landmarks (debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True) + + output_path = final_output_path + if data.force_output_path is not None: + output_path = data.force_output_path + + output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg" + cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality ] ) + + dflimg = DFLJPG.load(output_filepath) + dflimg.set_face_type(FaceType.toString(face_type)) + dflimg.set_landmarks(face_image_landmarks.tolist()) + dflimg.set_source_filename(filepath.name) + dflimg.set_source_rect(rect) + dflimg.set_source_landmarks(image_landmarks.tolist()) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.save() + + data.final_output_files.append (output_filepath) + face_idx += 1 + data.faces_detected = face_idx + + if output_debug_path is not None: + cv2_imwrite( output_debug_path / (filepath.stem+'.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50] ) + + return data + + #overridable + def get_data_name (self, data): + #return string identificator of your data + return data.filepath + + @staticmethod + def get_devices_for_config (type, device_config): + devices = device_config.devices + cpu_only = len(devices) == 0 + + if 'rects' in type or \ + 'landmarks' in type or \ + 'all' in type: + + if not cpu_only: + if type == 'landmarks-manual': + devices = [devices.get_best_device()] + + result = [] + + for device in devices: + count = 1 + + if count == 1: + result += [ (device.index, 'GPU', device.name, device.total_mem_gb) ] + else: + for i in range(count): + result += [ (device.index, 'GPU', f"{device.name} #{i}", device.total_mem_gb) ] + + return result + else: + if type == 'landmarks-manual': + return [ (0, 'CPU', 'CPU', 0 ) ] + else: + return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in range( min(8, multiprocessing.cpu_count() // 2) ) ] + + elif type == 'final': + return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0]) ] + + def __init__(self, input_data, type, image_size=None, jpeg_quality=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0, final_output_path=None, device_config=None): + if type == 'landmarks-manual': + for x in input_data: + x.manual = True + + self.input_data = input_data + + self.type = type + self.image_size = image_size + self.jpeg_quality = jpeg_quality + self.face_type = face_type + self.output_debug_path = output_debug_path + self.final_output_path = final_output_path + self.manual_window_size = manual_window_size + self.max_faces_from_image = max_faces_from_image + self.result = [] + + self.devices = ExtractSubprocessor.get_devices_for_config(self.type, device_config) + + super().__init__('Extractor', ExtractSubprocessor.Cli, + 999999 if type == 'landmarks-manual' or DEBUG else 120) + + #override + def on_clients_initialized(self): + if self.type == 'landmarks-manual': + self.wnd_name = 'Manual pass' + io.named_window(self.wnd_name) + io.capture_mouse(self.wnd_name) + io.capture_keys(self.wnd_name) + + self.cache_original_image = (None, None) + self.cache_image = (None, None) + self.cache_text_lines_img = (None, None) + self.hide_help = False + self.landmarks_accurate = True + self.force_landmarks = False + + self.landmarks = None + self.x = 0 + self.y = 0 + self.rect_size = 100 + self.rect_locked = False + self.extract_needed = True + + self.image = None + self.image_filepath = None + + io.progress_bar (None, len (self.input_data)) + + #override + def on_clients_finalized(self): + if self.type == 'landmarks-manual': + io.destroy_all_windows() + + io.progress_bar_close() + + #override + def process_info_generator(self): + base_dict = {'type' : self.type, + 'image_size': self.image_size, + 'jpeg_quality' : self.jpeg_quality, + 'face_type': self.face_type, + 'max_faces_from_image':self.max_faces_from_image, + 'output_debug_path': self.output_debug_path, + 'final_output_path': self.final_output_path, + 'stdin_fd': sys.stdin.fileno() } + + + for (device_idx, device_type, device_name, device_total_vram_gb) in self.devices: + client_dict = base_dict.copy() + client_dict['device_idx'] = device_idx + client_dict['device_name'] = device_name + client_dict['device_type'] = device_type + yield client_dict['device_name'], {}, client_dict + + #override + def get_data(self, host_dict): + if self.type == 'landmarks-manual': + need_remark_face = False + while len (self.input_data) > 0: + data = self.input_data[0] + filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks + is_frame_done = False + + if self.image_filepath != filepath: + self.image_filepath = filepath + if self.cache_original_image[0] == filepath: + self.original_image = self.cache_original_image[1] + else: + self.original_image = imagelib.normalize_channels( cv2_imread( filepath ), 3 ) + + self.cache_original_image = (filepath, self.original_image ) + + (h,w,c) = self.original_image.shape + self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / ( h * (16.0/9.0) ) + + if self.cache_image[0] == (h,w,c) + (self.view_scale,filepath): + self.image = self.cache_image[1] + else: + self.image = cv2.resize (self.original_image, ( int(w*self.view_scale), int(h*self.view_scale) ), interpolation=cv2.INTER_LINEAR) + self.cache_image = ( (h,w,c) + (self.view_scale,filepath), self.image ) + + (h,w,c) = self.image.shape + + sh = (0,0, w, min(100, h) ) + if self.cache_text_lines_img[0] == sh: + self.text_lines_img = self.cache_text_lines_img[1] + else: + self.text_lines_img = (imagelib.get_draw_text_lines ( self.image, sh, + [ '[L Mouse click] - lock/unlock selection. [Mouse wheel] - change rect', + '[R Mouse Click] - manual face rectangle', + '[Enter] / [Space] - confirm / skip frame', + '[,] [.]- prev frame, next frame. [Q] - skip remaining frames', + '[a] - accuracy on/off (more fps)', + '[h] - hide this help' + ], (1, 1, 1) )*255).astype(np.uint8) + + self.cache_text_lines_img = (sh, self.text_lines_img) + + if need_remark_face: # need remark image from input data that already has a marked face? + need_remark_face = False + if len(data_rects) != 0: # If there was already a face then lock the rectangle to it until the mouse is clicked + self.rect = data_rects.pop() + self.landmarks = data_landmarks.pop() + data_rects.clear() + data_landmarks.clear() + + self.rect_locked = True + self.rect_size = ( self.rect[2] - self.rect[0] ) / 2 + self.x = ( self.rect[0] + self.rect[2] ) / 2 + self.y = ( self.rect[1] + self.rect[3] ) / 2 + self.redraw() + + if len(data_rects) == 0: + (h,w,c) = self.image.shape + while True: + io.process_messages(0.0001) + + if not self.force_landmarks: + new_x = self.x + new_y = self.y + + new_rect_size = self.rect_size + + mouse_events = io.get_mouse_events(self.wnd_name) + for ev in mouse_events: + (x, y, ev, flags) = ev + if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked: + mod = 1 if flags > 0 else -1 + diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10) + new_rect_size = max (5, new_rect_size + diff*mod) + elif ev == io.EVENT_LBUTTONDOWN: + if self.force_landmarks: + self.x = new_x + self.y = new_y + self.force_landmarks = False + self.rect_locked = True + self.redraw() + else: + self.rect_locked = not self.rect_locked + self.extract_needed = True + elif ev == io.EVENT_RBUTTONDOWN: + self.force_landmarks = not self.force_landmarks + if self.force_landmarks: + self.rect_locked = False + elif not self.rect_locked: + new_x = np.clip (x, 0, w-1) / self.view_scale + new_y = np.clip (y, 0, h-1) / self.view_scale + + key_events = io.get_key_events(self.wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + + if key == ord('\r') or key == ord('\n'): + #confirm frame + is_frame_done = True + data_rects.append (self.rect) + data_landmarks.append (self.landmarks) + break + elif key == ord(' '): + #confirm skip frame + is_frame_done = True + break + elif key == ord(',') and len(self.result) > 0: + #go prev frame + + if self.rect_locked: + self.rect_locked = False + # Only save the face if the rect is still locked + data_rects.append (self.rect) + data_landmarks.append (self.landmarks) + + + self.input_data.insert(0, self.result.pop() ) + io.progress_bar_inc(-1) + need_remark_face = True + + break + elif key == ord('.'): + #go next frame + + if self.rect_locked: + self.rect_locked = False + # Only save the face if the rect is still locked + data_rects.append (self.rect) + data_landmarks.append (self.landmarks) + + need_remark_face = True + is_frame_done = True + break + elif key == ord('q'): + #skip remaining + + if self.rect_locked: + self.rect_locked = False + data_rects.append (self.rect) + data_landmarks.append (self.landmarks) + + while len(self.input_data) > 0: + self.result.append( self.input_data.pop(0) ) + io.progress_bar_inc(1) + + break + + elif key == ord('h'): + self.hide_help = not self.hide_help + break + elif key == ord('a'): + self.landmarks_accurate = not self.landmarks_accurate + break + + if self.force_landmarks: + pt2 = np.float32([new_x, new_y]) + pt1 = np.float32([self.x, self.y]) + + pt_vec_len = npla.norm(pt2-pt1) + pt_vec = pt2-pt1 + if pt_vec_len != 0: + pt_vec /= pt_vec_len + + self.rect_size = pt_vec_len + self.rect = ( int(self.x-self.rect_size), + int(self.y-self.rect_size), + int(self.x+self.rect_size), + int(self.y+self.rect_size) ) + + if pt_vec_len > 0: + lmrks = np.concatenate ( (np.zeros ((17,2), np.float32), LandmarksProcessor.landmarks_2D), axis=0 ) + lmrks -= lmrks[30:31,:] + mat = cv2.getRotationMatrix2D( (0, 0), -np.arctan2( pt_vec[1], pt_vec[0] )*180/math.pi , pt_vec_len) + mat[:, 2] += (self.x, self.y) + self.landmarks = LandmarksProcessor.transform_points(lmrks, mat ) + + + self.redraw() + + elif self.x != new_x or \ + self.y != new_y or \ + self.rect_size != new_rect_size or \ + self.extract_needed: + self.x = new_x + self.y = new_y + self.rect_size = new_rect_size + self.rect = ( int(self.x-self.rect_size), + int(self.y-self.rect_size), + int(self.x+self.rect_size), + int(self.y+self.rect_size) ) + + return ExtractSubprocessor.Data (filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate) + + else: + is_frame_done = True + + if is_frame_done: + self.result.append ( data ) + self.input_data.pop(0) + io.progress_bar_inc(1) + self.extract_needed = True + self.rect_locked = False + else: + if len (self.input_data) > 0: + return self.input_data.pop(0) + + return None + + #override + def on_data_return (self, host_dict, data): + if not self.type != 'landmarks-manual': + self.input_data.insert(0, data) + + def redraw(self): + (h,w,c) = self.image.shape + + if not self.hide_help: + image = cv2.addWeighted (self.image,1.0,self.text_lines_img,1.0,0) + else: + image = self.image.copy() + + view_rect = (np.array(self.rect) * self.view_scale).astype(np.int).tolist() + view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int).tolist() + + if self.rect_size <= 40: + scaled_rect_size = h // 3 if w > h else w // 3 + + p1 = (self.x - self.rect_size, self.y - self.rect_size) + p2 = (self.x + self.rect_size, self.y - self.rect_size) + p3 = (self.x - self.rect_size, self.y + self.rect_size) + + wh = h if h < w else w + np1 = (w / 2 - wh / 4, h / 2 - wh / 4) + np2 = (w / 2 + wh / 4, h / 2 - wh / 4) + np3 = (w / 2 - wh / 4, h / 2 + wh / 4) + + mat = cv2.getAffineTransform( np.float32([p1,p2,p3])*self.view_scale, np.float32([np1,np2,np3]) ) + image = cv2.warpAffine(image, mat,(w,h) ) + view_landmarks = LandmarksProcessor.transform_points (view_landmarks, mat) + + landmarks_color = (255,255,0) if self.rect_locked else (0,255,0) + LandmarksProcessor.draw_rect_landmarks (image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color) + self.extract_needed = False + + io.show_image (self.wnd_name, image) + + + #override + def on_result (self, host_dict, data, result): + if self.type == 'landmarks-manual': + filepath, landmarks = result.filepath, result.landmarks + + if len(landmarks) != 0 and landmarks[0] is not None: + self.landmarks = landmarks[0] + + self.redraw() + else: + self.result.append ( result ) + io.progress_bar_inc(1) + + + + #override + def get_result(self): + return self.result + + +class DeletedFilesSearcherSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + #override + def on_initialize(self, client_dict): + self.debug_paths_stems = client_dict['debug_paths_stems'] + return None + + #override + def process_data(self, data): + input_path_stem = Path(data[0]).stem + return any ( [ input_path_stem == d_stem for d_stem in self.debug_paths_stems] ) + + #override + def get_data_name (self, data): + #return string identificator of your data + return data[0] + + #override + def __init__(self, input_paths, debug_paths ): + self.input_paths = input_paths + self.debug_paths_stems = [ Path(d).stem for d in debug_paths] + self.result = [] + super().__init__('DeletedFilesSearcherSubprocessor', DeletedFilesSearcherSubprocessor.Cli, 60) + + #override + def process_info_generator(self): + for i in range(min(multiprocessing.cpu_count(), 8)): + yield 'CPU%d' % (i), {}, {'debug_paths_stems' : self.debug_paths_stems} + + #override + def on_clients_initialized(self): + io.progress_bar ("Searching deleted files", len (self.input_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def get_data(self, host_dict): + if len (self.input_paths) > 0: + return [self.input_paths.pop(0)] + return None + + #override + def on_data_return (self, host_dict, data): + self.input_paths.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + if result == False: + self.result.append( data[0] ) + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.result + +def main(detector=None, + input_path=None, + output_path=None, + output_debug=None, + manual_fix=False, + manual_output_debug_fix=False, + manual_window_size=1368, + face_type='full_face', + max_faces_from_image=None, + image_size=None, + jpeg_quality=None, + cpu_only = False, + force_gpu_idxs = None, + ): + + if not input_path.exists(): + io.log_err ('Input directory not found. Please ensure it exists.') + return + + if not output_path.exists(): + output_path.mkdir(parents=True, exist_ok=True) + + if face_type is not None: + face_type = FaceType.fromString(face_type) + + if face_type is None: + if manual_output_debug_fix: + files = pathex.get_image_paths(output_path) + if len(files) != 0: + dflimg = DFLIMG.load(Path(files[0])) + if dflimg is not None and dflimg.has_data(): + face_type = FaceType.fromString ( dflimg.get_face_type() ) + + input_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info) + output_images_paths = pathex.get_image_paths(output_path) + output_debug_path = output_path.parent / (output_path.name + '_debug') + + continue_extraction = False + if not manual_output_debug_fix and len(output_images_paths) > 0: + if len(output_images_paths) > 128: + continue_extraction = io.input_bool ("Continue extraction?", True, help_message="Extraction can be continued, but you must specify the same options again.") + + if len(output_images_paths) > 128 and continue_extraction: + try: + input_image_paths = input_image_paths[ [ Path(x).stem for x in input_image_paths ].index ( Path(output_images_paths[-128]).stem.split('_')[0] ) : ] + except: + io.log_err("Error in fetching the last index. Extraction cannot be continued.") + return + elif input_path != output_path: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") + for filename in output_images_paths: + Path(filename).unlink() + + device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector=='manual', suggest_all_gpu=True) ) \ + if not cpu_only else nn.DeviceConfig.CPU() + + if face_type is None: + face_type = io.input_str ("Face type", 'wf', ['f','wf','head'], help_message="Full face / whole face / head. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() + face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + if max_faces_from_image is None: + max_faces_from_image = io.input_int(f"Max number of faces from image", 0, help_message="If you extract a src faceset that has frames with a large number of faces, it is advisable to set max faces to 3 to speed up extraction. 0 - unlimited") + + if image_size is None: + image_size = io.input_int(f"Image size", 512 if face_type < FaceType.HEAD else 768, valid_range=[256,2048], help_message="Output image size. The higher image size, the worse face-enhancer works. Use higher than 512 value only if the source image is sharp enough and the face does not need to be enhanced.") + + if jpeg_quality is None: + jpeg_quality = io.input_int(f"Jpeg quality", 90, valid_range=[1,100], help_message="Jpeg quality. The higher jpeg quality the larger the output file size.") + + if detector is None: + io.log_info ("Choose detector type.") + io.log_info ("[0] S3FD") + io.log_info ("[1] manual") + detector = {0:'s3fd', 1:'manual'}[ io.input_int("", 0, [0,1]) ] + + + if output_debug is None: + output_debug = io.input_bool (f"Write debug images to {output_debug_path.name}?", False) + + if output_debug: + output_debug_path.mkdir(parents=True, exist_ok=True) + + if manual_output_debug_fix: + if not output_debug_path.exists(): + io.log_err(f'{output_debug_path} not found. Re-extract faces with "Write debug images" option.') + return + else: + detector = 'manual' + io.log_info('Performing re-extract frames which were deleted from _debug directory.') + + input_image_paths = DeletedFilesSearcherSubprocessor (input_image_paths, pathex.get_image_paths(output_debug_path) ).run() + input_image_paths = sorted (input_image_paths) + io.log_info('Found %d images.' % (len(input_image_paths))) + else: + if not continue_extraction and output_debug_path.exists(): + for filename in pathex.get_image_paths(output_debug_path): + Path(filename).unlink() + + images_found = len(input_image_paths) + faces_detected = 0 + if images_found != 0: + if detector == 'manual': + io.log_info ('Performing manual extract...') + data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() + + io.log_info ('Performing 3rd pass...') + data = ExtractSubprocessor (data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() + + else: + io.log_info ('Extracting faces...') + data = ExtractSubprocessor ([ ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths ], + 'all', + image_size, + jpeg_quality, + face_type, + output_debug_path if output_debug else None, + max_faces_from_image=max_faces_from_image, + final_output_path=output_path, + device_config=device_config).run() + + faces_detected += sum([d.faces_detected for d in data]) + + if manual_fix: + if all ( np.array ( [ d.faces_detected > 0 for d in data] ) == True ): + io.log_info ('All faces are detected, manual fix not needed.') + else: + fix_data = [ ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0 ] + io.log_info ('Performing manual fix for %d images...' % (len(fix_data)) ) + fix_data = ExtractSubprocessor (fix_data, 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run() + fix_data = ExtractSubprocessor (fix_data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path, device_config=device_config).run() + faces_detected += sum([d.faces_detected for d in fix_data]) + + + io.log_info ('-------------------------') + io.log_info ('Images found: %d' % (images_found) ) + io.log_info ('Faces detected: %d' % (faces_detected) ) + io.log_info ('-------------------------') diff --git a/mainscripts/FacesetEnhancer.py b/mainscripts/FacesetEnhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..3de9cea45a3a2f5388e7e5b31dfc78472a14f383 --- /dev/null +++ b/mainscripts/FacesetEnhancer.py @@ -0,0 +1,156 @@ +import multiprocessing +import shutil + +from DFLIMG import * +from core.interact import interact as io +from core.joblib import Subprocessor +from core.leras import nn +from core import pathex +from core.cv2ex import * + + +class FacesetEnhancerSubprocessor(Subprocessor): + + #override + def __init__(self, image_paths, output_dirpath, device_config): + self.image_paths = image_paths + self.output_dirpath = output_dirpath + self.result = [] + self.nn_initialize_mp_lock = multiprocessing.Lock() + self.devices = FacesetEnhancerSubprocessor.get_devices_for_config(device_config) + + super().__init__('FacesetEnhancer', FacesetEnhancerSubprocessor.Cli, 600) + + #override + def on_clients_initialized(self): + io.progress_bar (None, len (self.image_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + base_dict = {'output_dirpath':self.output_dirpath, + 'nn_initialize_mp_lock': self.nn_initialize_mp_lock,} + + for (device_idx, device_type, device_name, device_total_vram_gb) in self.devices: + client_dict = base_dict.copy() + client_dict['device_idx'] = device_idx + client_dict['device_name'] = device_name + client_dict['device_type'] = device_type + yield client_dict['device_name'], {}, client_dict + + #override + def get_data(self, host_dict): + if len (self.image_paths) > 0: + return self.image_paths.pop(0) + + #override + def on_data_return (self, host_dict, data): + self.image_paths.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + io.progress_bar_inc(1) + if result[0] == 1: + self.result +=[ (result[1], result[2]) ] + + #override + def get_result(self): + return self.result + + @staticmethod + def get_devices_for_config (device_config): + devices = device_config.devices + cpu_only = len(devices) == 0 + + if not cpu_only: + return [ (device.index, 'GPU', device.name, device.total_mem_gb) for device in devices ] + else: + return [ (i, 'CPU', 'CPU%d' % (i), 0 ) for i in range( min(8, multiprocessing.cpu_count() // 2) ) ] + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + device_idx = client_dict['device_idx'] + cpu_only = client_dict['device_type'] == 'CPU' + self.output_dirpath = client_dict['output_dirpath'] + nn_initialize_mp_lock = client_dict['nn_initialize_mp_lock'] + + if cpu_only: + device_config = nn.DeviceConfig.CPU() + device_vram = 99 + else: + device_config = nn.DeviceConfig.GPUIndexes ([device_idx]) + device_vram = device_config.devices[0].total_mem_gb + + nn.initialize (device_config) + + intro_str = 'Running on %s.' % (client_dict['device_name']) + + self.log_info (intro_str) + + from facelib import FaceEnhancer + self.fe = FaceEnhancer( place_model_on_cpu=(device_vram<=2 or cpu_only), run_on_cpu=cpu_only ) + + #override + def process_data(self, filepath): + try: + dflimg = DFLIMG.load (filepath) + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") + else: + dfl_dict = dflimg.get_dict() + + img = cv2_imread(filepath).astype(np.float32) / 255.0 + img = self.fe.enhance(img) + img = np.clip (img*255, 0, 255).astype(np.uint8) + + output_filepath = self.output_dirpath / filepath.name + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + dflimg.save() + + return (1, filepath, output_filepath) + except: + self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}") + + return (0, filepath, None) + +def process_folder ( dirpath, cpu_only=False, force_gpu_idxs=None ): + device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_all_gpu=True) ) \ + if not cpu_only else nn.DeviceConfig.CPU() + + output_dirpath = dirpath.parent / (dirpath.name + '_enhanced') + output_dirpath.mkdir (exist_ok=True, parents=True) + + dirpath_parts = '/'.join( dirpath.parts[-2:]) + output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] ) + io.log_info (f"Enhancing faceset in {dirpath_parts}") + io.log_info ( f"Processing to {output_dirpath_parts}") + + output_images_paths = pathex.get_image_paths(output_dirpath) + if len(output_images_paths) > 0: + for filename in output_images_paths: + Path(filename).unlink() + + image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )] + result = FacesetEnhancerSubprocessor ( image_paths, output_dirpath, device_config=device_config).run() + + is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True) + if is_merge: + io.log_info (f"Copying processed files to {dirpath_parts}") + + for (filepath, output_filepath) in result: + try: + shutil.copy (output_filepath, filepath) + except: + pass + + io.log_info (f"Removing {output_dirpath_parts}") + shutil.rmtree(output_dirpath) diff --git a/mainscripts/FacesetResizer.py b/mainscripts/FacesetResizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcd1b86123dfd07d43d0f798b273a5a9a1fca45 --- /dev/null +++ b/mainscripts/FacesetResizer.py @@ -0,0 +1,209 @@ +import multiprocessing +import shutil + +import cv2 +from core import pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import Subprocessor +from DFLIMG import * +from facelib import FaceType, LandmarksProcessor + + +class FacesetResizerSubprocessor(Subprocessor): + + #override + def __init__(self, image_paths, output_dirpath, image_size, face_type=None): + self.image_paths = image_paths + self.output_dirpath = output_dirpath + self.image_size = image_size + self.face_type = face_type + self.result = [] + + super().__init__('FacesetResizer', FacesetResizerSubprocessor.Cli, 600) + + #override + def on_clients_initialized(self): + io.progress_bar (None, len (self.image_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + base_dict = {'output_dirpath':self.output_dirpath, 'image_size':self.image_size, 'face_type':self.face_type} + + for device_idx in range( min(8, multiprocessing.cpu_count()) ): + client_dict = base_dict.copy() + device_name = f'CPU #{device_idx}' + client_dict['device_name'] = device_name + yield device_name, {}, client_dict + + #override + def get_data(self, host_dict): + if len (self.image_paths) > 0: + return self.image_paths.pop(0) + + #override + def on_data_return (self, host_dict, data): + self.image_paths.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + io.progress_bar_inc(1) + if result[0] == 1: + self.result +=[ (result[1], result[2]) ] + + #override + def get_result(self): + return self.result + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + self.output_dirpath = client_dict['output_dirpath'] + self.image_size = client_dict['image_size'] + self.face_type = client_dict['face_type'] + self.log_info (f"Running on { client_dict['device_name'] }") + + #override + def process_data(self, filepath): + try: + dflimg = DFLIMG.load (filepath) + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") + else: + img = cv2_imread(filepath) + h,w = img.shape[:2] + if h != w: + raise Exception(f'w != h in {filepath}') + + image_size = self.image_size + face_type = self.face_type + output_filepath = self.output_dirpath / filepath.name + + if face_type is not None: + lmrks = dflimg.get_landmarks() + mat = LandmarksProcessor.get_transform_mat(lmrks, image_size, face_type) + + img = cv2.warpAffine(img, mat, (image_size, image_size), flags=cv2.INTER_LANCZOS4 ) + img = np.clip(img, 0, 255).astype(np.uint8) + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + dfl_dict = dflimg.get_dict() + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + + xseg_mask = dflimg.get_xseg_mask() + if xseg_mask is not None: + xseg_res = 256 + + xseg_lmrks = lmrks.copy() + xseg_lmrks *= (xseg_res / w) + xseg_mat = LandmarksProcessor.get_transform_mat(xseg_lmrks, xseg_res, face_type) + + xseg_mask = cv2.warpAffine(xseg_mask, xseg_mat, (xseg_res, xseg_res), flags=cv2.INTER_LANCZOS4 ) + xseg_mask[xseg_mask < 0.5] = 0 + xseg_mask[xseg_mask >= 0.5] = 1 + + dflimg.set_xseg_mask(xseg_mask) + + seg_ie_polys = dflimg.get_seg_ie_polys() + + for poly in seg_ie_polys.get_polys(): + poly_pts = poly.get_pts() + poly_pts = LandmarksProcessor.transform_points(poly_pts, mat) + poly.set_points(poly_pts) + + dflimg.set_seg_ie_polys(seg_ie_polys) + + lmrks = LandmarksProcessor.transform_points(lmrks, mat) + dflimg.set_landmarks(lmrks) + + image_to_face_mat = dflimg.get_image_to_face_mat() + if image_to_face_mat is not None: + image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type ) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.set_face_type( FaceType.toString(face_type) ) + dflimg.save() + + else: + dfl_dict = dflimg.get_dict() + + scale = w / image_size + + img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_LANCZOS4) + + cv2_imwrite ( str(output_filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + dflimg = DFLIMG.load (output_filepath) + dflimg.set_dict(dfl_dict) + + lmrks = dflimg.get_landmarks() + lmrks /= scale + dflimg.set_landmarks(lmrks) + + seg_ie_polys = dflimg.get_seg_ie_polys() + seg_ie_polys.mult_points( 1.0 / scale) + dflimg.set_seg_ie_polys(seg_ie_polys) + + image_to_face_mat = dflimg.get_image_to_face_mat() + + if image_to_face_mat is not None: + face_type = FaceType.fromString ( dflimg.get_face_type() ) + image_to_face_mat = LandmarksProcessor.get_transform_mat ( dflimg.get_source_landmarks(), image_size, face_type ) + dflimg.set_image_to_face_mat(image_to_face_mat) + dflimg.save() + + return (1, filepath, output_filepath) + except: + self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}") + + return (0, filepath, None) + +def process_folder ( dirpath): + + image_size = io.input_int(f"New image size", 512, valid_range=[128,2048]) + + face_type = io.input_str ("Change face type", 'same', ['h','mf','f','wf','head','same']).lower() + if face_type == 'same': + face_type = None + else: + face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + + output_dirpath = dirpath.parent / (dirpath.name + '_resized') + output_dirpath.mkdir (exist_ok=True, parents=True) + + dirpath_parts = '/'.join( dirpath.parts[-2:]) + output_dirpath_parts = '/'.join( output_dirpath.parts[-2:] ) + io.log_info (f"Resizing faceset in {dirpath_parts}") + io.log_info ( f"Processing to {output_dirpath_parts}") + + output_images_paths = pathex.get_image_paths(output_dirpath) + if len(output_images_paths) > 0: + for filename in output_images_paths: + Path(filename).unlink() + + image_paths = [Path(x) for x in pathex.get_image_paths( dirpath )] + result = FacesetResizerSubprocessor ( image_paths, output_dirpath, image_size, face_type).run() + + is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True) + if is_merge: + io.log_info (f"Copying processed files to {dirpath_parts}") + + for (filepath, output_filepath) in result: + try: + shutil.copy (output_filepath, filepath) + except: + pass + + io.log_info (f"Removing {output_dirpath_parts}") + shutil.rmtree(output_dirpath) diff --git a/mainscripts/Merger.py b/mainscripts/Merger.py new file mode 100644 index 0000000000000000000000000000000000000000..0703dc1dbc358ff6f2b24b0febc5a19fd1e60638 --- /dev/null +++ b/mainscripts/Merger.py @@ -0,0 +1,281 @@ +import math +import multiprocessing +import traceback +from pathlib import Path + +import numpy as np +import numpy.linalg as npla + +import samplelib +from core import pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import MPClassFuncOnDemand, MPFunc +from core.leras import nn +from DFLIMG import DFLIMG +from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet +from merger import FrameInfo, InteractiveMergerSubprocessor, MergerConfig + + +def main (model_class_name=None, + saved_models_path=None, + training_data_src_path=None, + force_model_name=None, + input_path=None, + output_path=None, + output_mask_path=None, + aligned_path=None, + force_gpu_idxs=None, + cpu_only=None): + io.log_info ("Running merger.\r\n") + + try: + if not input_path.exists(): + io.log_err('Input directory not found. Please ensure it exists.') + return + + if not output_path.exists(): + output_path.mkdir(parents=True, exist_ok=True) + + if not output_mask_path.exists(): + output_mask_path.mkdir(parents=True, exist_ok=True) + + if not saved_models_path.exists(): + io.log_err('Model directory not found. Please ensure it exists.') + return + + # Initialize model + import models + model = models.import_model(model_class_name)(is_training=False, + saved_models_path=saved_models_path, + force_gpu_idxs=force_gpu_idxs, + force_model_name=force_model_name, + cpu_only=cpu_only) + + predictor_func, predictor_input_shape, cfg = model.get_MergerConfig() + + # Preparing MP functions + predictor_func = MPFunc(predictor_func) + + run_on_cpu = len(nn.getCurrentDeviceConfig().devices) == 0 + xseg_256_extract_func = MPClassFuncOnDemand(XSegNet, 'extract', + name='XSeg', + resolution=256, + weights_file_root=saved_models_path, + place_model_on_cpu=True, + run_on_cpu=run_on_cpu) + + face_enhancer_func = MPClassFuncOnDemand(FaceEnhancer, 'enhance', + place_model_on_cpu=True, + run_on_cpu=run_on_cpu) + + is_interactive = io.input_bool ("Use interactive merger?", True) if not io.is_colab() else False + + if not is_interactive: + cfg.ask_settings() + + subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()), + valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." ) + + input_path_image_paths = pathex.get_image_paths(input_path) + + if cfg.type == MergerConfig.TYPE_MASKED: + if not aligned_path.exists(): + io.log_err('Aligned directory not found. Please ensure it exists.') + return + + packed_samples = None + try: + packed_samples = samplelib.PackedFaceset.load(aligned_path) + except: + io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(aligned_path)}, {traceback.format_exc()}") + + + if packed_samples is not None: + io.log_info ("Using packed faceset.") + def generator(): + for sample in io.progress_bar_generator( packed_samples, "Collecting alignments"): + filepath = Path(sample.filename) + yield filepath, DFLIMG.load(filepath, loader_func=lambda x: sample.read_raw_file() ) + else: + def generator(): + for filepath in io.progress_bar_generator( pathex.get_image_paths(aligned_path), "Collecting alignments"): + filepath = Path(filepath) + yield filepath, DFLIMG.load(filepath) + + alignments = {} + multiple_faces_detected = False + + for filepath, dflimg in generator(): + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + continue + + source_filename = dflimg.get_source_filename() + if source_filename is None: + continue + + source_filepath = Path(source_filename) + source_filename_stem = source_filepath.stem + + if source_filename_stem not in alignments.keys(): + alignments[ source_filename_stem ] = [] + + alignments_ar = alignments[ source_filename_stem ] + alignments_ar.append ( (dflimg.get_source_landmarks(), filepath, source_filepath ) ) + + if len(alignments_ar) > 1: + multiple_faces_detected = True + + if multiple_faces_detected: + io.log_info ("") + io.log_info ("Warning: multiple faces detected. Only one alignment file should refer one source file.") + io.log_info ("") + + for a_key in list(alignments.keys()): + a_ar = alignments[a_key] + if len(a_ar) > 1: + for _, filepath, source_filepath in a_ar: + io.log_info (f"alignment {filepath.name} refers to {source_filepath.name} ") + io.log_info ("") + + alignments[a_key] = [ a[0] for a in a_ar] + + if multiple_faces_detected: + io.log_info ("It is strongly recommended to process the faces separatelly.") + io.log_info ("Use 'recover original filename' to determine the exact duplicates.") + io.log_info ("") + + frames = [ InteractiveMergerSubprocessor.Frame( frame_info=FrameInfo(filepath=Path(p), + landmarks_list=alignments.get(Path(p).stem, None) + ) + ) + for p in input_path_image_paths ] + + if multiple_faces_detected: + io.log_info ("Warning: multiple faces detected. Motion blur will not be used.") + io.log_info ("") + else: + s = 256 + local_pts = [ (s//2-1, s//2-1), (s//2-1,0) ] #center+up + frames_len = len(frames) + for i in io.progress_bar_generator( range(len(frames)) , "Computing motion vectors"): + fi_prev = frames[max(0, i-1)].frame_info + fi = frames[i].frame_info + fi_next = frames[min(i+1, frames_len-1)].frame_info + if len(fi_prev.landmarks_list) == 0 or \ + len(fi.landmarks_list) == 0 or \ + len(fi_next.landmarks_list) == 0: + continue + + mat_prev = LandmarksProcessor.get_transform_mat ( fi_prev.landmarks_list[0], s, face_type=FaceType.FULL) + mat = LandmarksProcessor.get_transform_mat ( fi.landmarks_list[0] , s, face_type=FaceType.FULL) + mat_next = LandmarksProcessor.get_transform_mat ( fi_next.landmarks_list[0], s, face_type=FaceType.FULL) + + pts_prev = LandmarksProcessor.transform_points (local_pts, mat_prev, True) + pts = LandmarksProcessor.transform_points (local_pts, mat, True) + pts_next = LandmarksProcessor.transform_points (local_pts, mat_next, True) + + prev_vector = pts[0]-pts_prev[0] + next_vector = pts_next[0]-pts[0] + + motion_vector = pts_next[0] - pts_prev[0] + fi.motion_power = npla.norm(motion_vector) + + motion_vector = motion_vector / fi.motion_power if fi.motion_power != 0 else np.array([0,0],dtype=np.float32) + + fi.motion_deg = -math.atan2(motion_vector[1],motion_vector[0])*180 / math.pi + + + if len(frames) == 0: + io.log_info ("No frames to merge in input_dir.") + else: + if False: + pass + else: + InteractiveMergerSubprocessor ( + is_interactive = is_interactive, + merger_session_filepath = model.get_strpath_storage_for_file('merger_session.dat'), + predictor_func = predictor_func, + predictor_input_shape = predictor_input_shape, + face_enhancer_func = face_enhancer_func, + xseg_256_extract_func = xseg_256_extract_func, + merger_config = cfg, + frames = frames, + frames_root_path = input_path, + output_path = output_path, + output_mask_path = output_mask_path, + model_iter = model.get_iter(), + subprocess_count = subprocess_count, + ).run() + + model.finalize() + + except Exception as e: + print ( traceback.format_exc() ) + + +""" +elif cfg.type == MergerConfig.TYPE_FACE_AVATAR: +filesdata = [] +for filepath in io.progress_bar_generator(input_path_image_paths, "Collecting info"): + filepath = Path(filepath) + + dflimg = DFLIMG.x(filepath) + if dflimg is None: + io.log_err ("%s is not a dfl image file" % (filepath.name) ) + continue + filesdata += [ ( FrameInfo(filepath=filepath, landmarks_list=[dflimg.get_landmarks()] ), dflimg.get_source_filename() ) ] + +filesdata = sorted(filesdata, key=operator.itemgetter(1)) #sort by source_filename +frames = [] +filesdata_len = len(filesdata) +for i in range(len(filesdata)): + frame_info = filesdata[i][0] + + prev_temporal_frame_infos = [] + next_temporal_frame_infos = [] + + for t in range (cfg.temporal_face_count): + prev_frame_info = filesdata[ max(i -t, 0) ][0] + next_frame_info = filesdata[ min(i +t, filesdata_len-1 )][0] + + prev_temporal_frame_infos.insert (0, prev_frame_info ) + next_temporal_frame_infos.append ( next_frame_info ) + + frames.append ( InteractiveMergerSubprocessor.Frame(prev_temporal_frame_infos=prev_temporal_frame_infos, + frame_info=frame_info, + next_temporal_frame_infos=next_temporal_frame_infos) ) +""" + +#interpolate landmarks +#from facelib import LandmarksProcessor +#from facelib import FaceType +#a = sorted(alignments.keys()) +#a_len = len(a) +# +#box_pts = 3 +#box = np.ones(box_pts)/box_pts +#for i in range( a_len ): +# if i >= box_pts and i <= a_len-box_pts-1: +# af0 = alignments[ a[i] ][0] ##first face +# m0 = LandmarksProcessor.get_transform_mat (af0, 256, face_type=FaceType.FULL) +# +# points = [] +# +# for j in range(-box_pts, box_pts+1): +# af = alignments[ a[i+j] ][0] ##first face +# m = LandmarksProcessor.get_transform_mat (af, 256, face_type=FaceType.FULL) +# p = LandmarksProcessor.transform_points (af, m) +# points.append (p) +# +# points = np.array(points) +# points_len = len(points) +# t_points = np.transpose(points, [1,0,2]) +# +# p1 = np.array ( [ int(np.convolve(x[:,0], box, mode='same')[points_len//2]) for x in t_points ] ) +# p2 = np.array ( [ int(np.convolve(x[:,1], box, mode='same')[points_len//2]) for x in t_points ] ) +# +# new_points = np.concatenate( [np.expand_dims(p1,-1),np.expand_dims(p2,-1)], -1 ) +# +# alignments[ a[i] ][0] = LandmarksProcessor.transform_points (new_points, m0, True).astype(np.int32) diff --git a/mainscripts/Sorter.py b/mainscripts/Sorter.py new file mode 100644 index 0000000000000000000000000000000000000000..39eec5e1180297f89a793cd798a5d4625a95cbfe --- /dev/null +++ b/mainscripts/Sorter.py @@ -0,0 +1,937 @@ +import math +import multiprocessing +import operator +import os +import sys +import tempfile +from functools import cmp_to_key +from pathlib import Path + +import cv2 +import numpy as np +from numpy import linalg as npla + +from core import imagelib, mathlib, pathex +from core.cv2ex import * +from core.imagelib import estimate_sharpness +from core.interact import interact as io +from core.joblib import Subprocessor +from core.leras import nn +from DFLIMG import * +from facelib import LandmarksProcessor + + +class BlurEstimatorSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + def on_initialize(self, client_dict): + self.estimate_motion_blur = client_dict['estimate_motion_blur'] + + #override + def process_data(self, data): + filepath = Path( data[0] ) + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") + return [ str(filepath), 0 ] + else: + image = cv2_imread( str(filepath) ) + + face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks()) + image = (image*face_mask).astype(np.uint8) + + + if self.estimate_motion_blur: + value = cv2.Laplacian(image, cv2.CV_64F, ksize=11).var() + else: + value = estimate_sharpness(image) + + return [ str(filepath), value ] + + + #override + def get_data_name (self, data): + #return string identificator of your data + return data[0] + + #override + def __init__(self, input_data, estimate_motion_blur=False ): + self.input_data = input_data + self.estimate_motion_blur = estimate_motion_blur + self.img_list = [] + self.trash_img_list = [] + super().__init__('BlurEstimator', BlurEstimatorSubprocessor.Cli, 60) + + #override + def on_clients_initialized(self): + io.progress_bar ("", len (self.input_data)) + + #override + def on_clients_finalized(self): + io.progress_bar_close () + + #override + def process_info_generator(self): + cpu_count = multiprocessing.cpu_count() + io.log_info(f'Running on {cpu_count} CPUs') + + for i in range(cpu_count): + yield 'CPU%d' % (i), {}, {'estimate_motion_blur':self.estimate_motion_blur} + + #override + def get_data(self, host_dict): + if len (self.input_data) > 0: + return self.input_data.pop(0) + + return None + + #override + def on_data_return (self, host_dict, data): + self.input_data.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + if result[1] == 0: + self.trash_img_list.append ( result ) + else: + self.img_list.append ( result ) + + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.img_list, self.trash_img_list + + +def sort_by_blur(input_path): + io.log_info ("Sorting by blur...") + + img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ] + img_list, trash_img_list = BlurEstimatorSubprocessor (img_list).run() + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + +def sort_by_motion_blur(input_path): + io.log_info ("Sorting by motion blur...") + + img_list = [ (filename,[]) for filename in pathex.get_image_paths(input_path) ] + img_list, trash_img_list = BlurEstimatorSubprocessor (img_list, estimate_motion_blur=True).run() + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + +def sort_by_face_yaw(input_path): + io.log_info ("Sorting by face yaw...") + img_list = [] + trash_img_list = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + trash_img_list.append ( [str(filepath)] ) + continue + + pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] ) + + img_list.append( [str(filepath), yaw ] ) + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + +def sort_by_face_pitch(input_path): + io.log_info ("Sorting by face pitch...") + img_list = [] + trash_img_list = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + trash_img_list.append ( [str(filepath)] ) + continue + + pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] ) + + img_list.append( [str(filepath), pitch ] ) + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + +def sort_by_face_source_rect_size(input_path): + io.log_info ("Sorting by face rect size...") + img_list = [] + trash_img_list = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + trash_img_list.append ( [str(filepath)] ) + continue + + source_rect = dflimg.get_source_rect() + rect_area = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32)) + + img_list.append( [str(filepath), rect_area ] ) + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + return img_list, trash_img_list + + + +class HistSsimSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + #override + def process_data(self, data): + img_list = [] + for x in data: + img = cv2_imread(x) + img_list.append ([x, cv2.calcHist([img], [0], None, [256], [0, 256]), + cv2.calcHist([img], [1], None, [256], [0, 256]), + cv2.calcHist([img], [2], None, [256], [0, 256]) + ]) + + img_list_len = len(img_list) + for i in range(img_list_len-1): + min_score = float("inf") + j_min_score = i+1 + for j in range(i+1,len(img_list)): + score = cv2.compareHist(img_list[i][1], img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + \ + cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + \ + cv2.compareHist(img_list[i][3], img_list[j][3], cv2.HISTCMP_BHATTACHARYYA) + if score < min_score: + min_score = score + j_min_score = j + img_list[i+1], img_list[j_min_score] = img_list[j_min_score], img_list[i+1] + + self.progress_bar_inc(1) + + return img_list + + #override + def get_data_name (self, data): + return "Bunch of images" + + #override + def __init__(self, img_list ): + self.img_list = img_list + self.img_list_len = len(img_list) + + slice_count = 20000 + sliced_count = self.img_list_len // slice_count + + if sliced_count > 12: + sliced_count = 11.9 + slice_count = int(self.img_list_len / sliced_count) + sliced_count = self.img_list_len // slice_count + + self.img_chunks_list = [ self.img_list[i*slice_count : (i+1)*slice_count] for i in range(sliced_count) ] + \ + [ self.img_list[sliced_count*slice_count:] ] + + self.result = [] + super().__init__('HistSsim', HistSsimSubprocessor.Cli, 0) + + #override + def process_info_generator(self): + cpu_count = len(self.img_chunks_list) + io.log_info(f'Running on {cpu_count} threads') + for i in range(cpu_count): + yield 'CPU%d' % (i), {'i':i}, {} + + #override + def on_clients_initialized(self): + io.progress_bar ("Sorting", len(self.img_list)) + io.progress_bar_inc(len(self.img_chunks_list)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def get_data(self, host_dict): + if len (self.img_chunks_list) > 0: + return self.img_chunks_list.pop(0) + return None + + #override + def on_data_return (self, host_dict, data): + raise Exception("Fail to process data. Decrease number of images and try again.") + + #override + def on_result (self, host_dict, data, result): + self.result += result + return 0 + + #override + def get_result(self): + return self.result + +def sort_by_hist(input_path): + io.log_info ("Sorting by histogram similarity...") + img_list = HistSsimSubprocessor(pathex.get_image_paths(input_path)).run() + return img_list, [] + +class HistDissimSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + #override + def on_initialize(self, client_dict): + self.img_list = client_dict['img_list'] + self.img_list_len = len(self.img_list) + + #override + def process_data(self, data): + i = data[0] + score_total = 0 + for j in range( 0, self.img_list_len): + if i == j: + continue + score_total += cv2.compareHist(self.img_list[i][1], self.img_list[j][1], cv2.HISTCMP_BHATTACHARYYA) + + return score_total + + #override + def get_data_name (self, data): + #return string identificator of your data + return self.img_list[data[0]][0] + + #override + def __init__(self, img_list ): + self.img_list = img_list + self.img_list_range = [i for i in range(0, len(img_list) )] + self.result = [] + super().__init__('HistDissim', HistDissimSubprocessor.Cli, 60) + + #override + def on_clients_initialized(self): + io.progress_bar ("Sorting", len (self.img_list) ) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + cpu_count = min(multiprocessing.cpu_count(), 8) + io.log_info(f'Running on {cpu_count} CPUs') + for i in range(cpu_count): + yield 'CPU%d' % (i), {}, {'img_list' : self.img_list} + + #override + def get_data(self, host_dict): + if len (self.img_list_range) > 0: + return [self.img_list_range.pop(0)] + + return None + + #override + def on_data_return (self, host_dict, data): + self.img_list_range.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + self.img_list[data[0]][2] = result + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.img_list + +def sort_by_hist_dissim(input_path): + io.log_info ("Sorting by histogram dissimilarity...") + + img_list = [] + trash_img_list = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + image = cv2_imread(str(filepath)) + + if dflimg is not None and dflimg.has_data(): + face_mask = LandmarksProcessor.get_image_hull_mask (image.shape, dflimg.get_landmarks()) + image = (image*face_mask).astype(np.uint8) + + img_list.append ([str(filepath), cv2.calcHist([cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)], [0], None, [256], [0, 256]), 0 ]) + + img_list = HistDissimSubprocessor(img_list).run() + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(2), reverse=True) + + return img_list, trash_img_list + +def sort_by_brightness(input_path): + io.log_info ("Sorting by brightness...") + img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,2].flatten() )] for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading") ] + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + return img_list, [] + +def sort_by_hue(input_path): + io.log_info ("Sorting by hue...") + img_list = [ [x, np.mean ( cv2.cvtColor(cv2_imread(x), cv2.COLOR_BGR2HSV)[...,0].flatten() )] for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading") ] + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + return img_list, [] + +def sort_by_black(input_path): + io.log_info ("Sorting by amount of black pixels...") + + img_list = [] + for x in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + img = cv2_imread(x) + img_list.append ([x, img[(img == 0)].size ]) + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=False) + + return img_list, [] + +def sort_by_origname(input_path): + io.log_info ("Sort by original filename...") + + img_list = [] + trash_img_list = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Loading"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + trash_img_list.append( [str(filepath)] ) + continue + + img_list.append( [str(filepath), dflimg.get_source_filename()] ) + + io.log_info ("Sorting...") + img_list = sorted(img_list, key=operator.itemgetter(1)) + return img_list, trash_img_list + +def sort_by_oneface_in_image(input_path): + io.log_info ("Sort by one face in images...") + image_paths = pathex.get_image_paths(input_path) + a = np.array ([ ( int(x[0]), int(x[1]) ) \ + for x in [ Path(filepath).stem.split('_') for filepath in image_paths ] if len(x) == 2 + ]) + if len(a) > 0: + idxs = np.ndarray.flatten ( np.argwhere ( a[:,1] != 0 ) ) + idxs = np.unique ( a[idxs][:,0] ) + idxs = np.ndarray.flatten ( np.argwhere ( np.array([ x[0] in idxs for x in a ]) == True ) ) + if len(idxs) > 0: + io.log_info ("Found %d images." % (len(idxs)) ) + img_list = [ (path,) for i,path in enumerate(image_paths) if i not in idxs ] + trash_img_list = [ (image_paths[x],) for x in idxs ] + return img_list, trash_img_list + + io.log_info ("Nothing found. Possible recover original filenames first.") + return [], [] + +class FinalLoaderSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + #override + def on_initialize(self, client_dict): + self.faster = client_dict['faster'] + + #override + def process_data(self, data): + filepath = Path(data[0]) + + try: + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + self.log_err (f"{filepath.name} is not a dfl image file") + return [ 1, [str(filepath)] ] + + bgr = cv2_imread(str(filepath)) + if bgr is None: + raise Exception ("Unable to load %s" % (filepath.name) ) + + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + if self.faster: + source_rect = dflimg.get_source_rect() + sharpness = mathlib.polygon_area(np.array(source_rect[[0,2,2,0]]).astype(np.float32), np.array(source_rect[[1,1,3,3]]).astype(np.float32)) + else: + face_mask = LandmarksProcessor.get_image_hull_mask (gray.shape, dflimg.get_landmarks()) + sharpness = estimate_sharpness( (gray[...,None]*face_mask).astype(np.uint8) ) + + pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll ( dflimg.get_landmarks(), size=dflimg.get_shape()[1] ) + + hist = cv2.calcHist([gray], [0], None, [256], [0, 256]) + except Exception as e: + self.log_err (e) + return [ 1, [str(filepath)] ] + + return [ 0, [str(filepath), sharpness, hist, yaw, pitch ] ] + + #override + def get_data_name (self, data): + #return string identificator of your data + return data[0] + + #override + def __init__(self, img_list, faster ): + self.img_list = img_list + + self.faster = faster + self.result = [] + self.result_trash = [] + + super().__init__('FinalLoader', FinalLoaderSubprocessor.Cli, 60) + + #override + def on_clients_initialized(self): + io.progress_bar ("Loading", len (self.img_list)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + cpu_count = min(multiprocessing.cpu_count(), 8) + io.log_info(f'Running on {cpu_count} CPUs') + + for i in range(cpu_count): + yield 'CPU%d' % (i), {}, {'faster': self.faster} + + #override + def get_data(self, host_dict): + if len (self.img_list) > 0: + return [self.img_list.pop(0)] + + return None + + #override + def on_data_return (self, host_dict, data): + self.img_list.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + if result[0] == 0: + self.result.append (result[1]) + else: + self.result_trash.append (result[1]) + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.result, self.result_trash + +class FinalHistDissimSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + #override + def process_data(self, data): + idx, pitch_yaw_img_list = data + + for p in range ( len(pitch_yaw_img_list) ): + + img_list = pitch_yaw_img_list[p] + if img_list is not None: + for i in range( len(img_list) ): + score_total = 0 + for j in range( len(img_list) ): + if i == j: + continue + score_total += cv2.compareHist(img_list[i][2], img_list[j][2], cv2.HISTCMP_BHATTACHARYYA) + img_list[i][3] = score_total + + pitch_yaw_img_list[p] = sorted(img_list, key=operator.itemgetter(3), reverse=True) + + return idx, pitch_yaw_img_list + + #override + def get_data_name (self, data): + return "Bunch of images" + + #override + def __init__(self, pitch_yaw_sample_list ): + self.pitch_yaw_sample_list = pitch_yaw_sample_list + self.pitch_yaw_sample_list_len = len(pitch_yaw_sample_list) + + self.pitch_yaw_sample_list_idxs = [ i for i in range(self.pitch_yaw_sample_list_len) if self.pitch_yaw_sample_list[i] is not None ] + self.result = [ None for _ in range(self.pitch_yaw_sample_list_len) ] + super().__init__('FinalHistDissimSubprocessor', FinalHistDissimSubprocessor.Cli) + + #override + def process_info_generator(self): + cpu_count = min(multiprocessing.cpu_count(), 8) + io.log_info(f'Running on {cpu_count} CPUs') + for i in range(cpu_count): + yield 'CPU%d' % (i), {}, {} + + #override + def on_clients_initialized(self): + io.progress_bar ("Sort by hist-dissim", len(self.pitch_yaw_sample_list_idxs) ) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def get_data(self, host_dict): + if len (self.pitch_yaw_sample_list_idxs) > 0: + idx = self.pitch_yaw_sample_list_idxs.pop(0) + + return idx, self.pitch_yaw_sample_list[idx] + return None + + #override + def on_data_return (self, host_dict, data): + self.pitch_yaw_sample_list_idxs.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + idx, yaws_sample_list = data + self.result[idx] = yaws_sample_list + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.result + +def sort_best_faster(input_path): + return sort_best(input_path, faster=True) + +def sort_best(input_path, faster=False): + target_count = io.input_int ("Target number of faces?", 2000) + + io.log_info ("Performing sort by best faces.") + if faster: + io.log_info("Using faster algorithm. Faces will be sorted by source-rect-area instead of blur.") + + img_list, trash_img_list = FinalLoaderSubprocessor( pathex.get_image_paths(input_path), faster ).run() + final_img_list = [] + + grads = 128 + imgs_per_grad = round (target_count / grads) + + #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2 + grads_space = np.linspace (-1.2, 1.2,grads) + + yaws_sample_list = [None]*grads + for g in io.progress_bar_generator ( range(grads), "Sort by yaw"): + yaw = grads_space[g] + next_yaw = grads_space[g+1] if g < grads-1 else yaw + + yaw_samples = [] + for img in img_list: + s_yaw = -img[3] + if (g == 0 and s_yaw < next_yaw) or \ + (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \ + (g == grads-1 and s_yaw >= yaw): + yaw_samples += [ img ] + if len(yaw_samples) > 0: + yaws_sample_list[g] = yaw_samples + + total_lack = 0 + for g in io.progress_bar_generator ( range(grads), ""): + img_list = yaws_sample_list[g] + img_list_len = len(img_list) if img_list is not None else 0 + + lack = imgs_per_grad - img_list_len + total_lack += max(lack, 0) + + imgs_per_grad += total_lack // grads + + + sharpned_imgs_per_grad = imgs_per_grad*10 + for g in io.progress_bar_generator ( range (grads), "Sort by blur"): + img_list = yaws_sample_list[g] + if img_list is None: + continue + + img_list = sorted(img_list, key=operator.itemgetter(1), reverse=True) + + if len(img_list) > sharpned_imgs_per_grad: + trash_img_list += img_list[sharpned_imgs_per_grad:] + img_list = img_list[0:sharpned_imgs_per_grad] + + yaws_sample_list[g] = img_list + + + yaw_pitch_sample_list = [None]*grads + pitch_grads = imgs_per_grad + + for g in io.progress_bar_generator ( range (grads), "Sort by pitch"): + img_list = yaws_sample_list[g] + if img_list is None: + continue + + pitch_sample_list = [None]*pitch_grads + + grads_space = np.linspace (-math.pi / 2,math.pi / 2, pitch_grads ) + + for pg in range (pitch_grads): + + pitch = grads_space[pg] + next_pitch = grads_space[pg+1] if pg < pitch_grads-1 else pitch + + pitch_samples = [] + for img in img_list: + s_pitch = img[4] + if (pg == 0 and s_pitch < next_pitch) or \ + (pg < pitch_grads-1 and s_pitch >= pitch and s_pitch < next_pitch) or \ + (pg == pitch_grads-1 and s_pitch >= pitch): + pitch_samples += [ img ] + + if len(pitch_samples) > 0: + pitch_sample_list[pg] = pitch_samples + yaw_pitch_sample_list[g] = pitch_sample_list + + yaw_pitch_sample_list = FinalHistDissimSubprocessor(yaw_pitch_sample_list).run() + + for g in io.progress_bar_generator (range (grads), "Fetching the best"): + pitch_sample_list = yaw_pitch_sample_list[g] + if pitch_sample_list is None: + continue + + n = imgs_per_grad + + while n > 0: + n_prev = n + for pg in range(pitch_grads): + img_list = pitch_sample_list[pg] + if img_list is None: + continue + final_img_list += [ img_list.pop(0) ] + if len(img_list) == 0: + pitch_sample_list[pg] = None + n -= 1 + if n == 0: + break + if n_prev == n: + break + + for pg in range(pitch_grads): + img_list = pitch_sample_list[pg] + if img_list is None: + continue + trash_img_list += img_list + + return final_img_list, trash_img_list + +""" +def sort_by_vggface(input_path): + io.log_info ("Sorting by face similarity using VGGFace model...") + + model = VGGFace() + + final_img_list = [] + trash_img_list = [] + + image_paths = pathex.get_image_paths(input_path) + img_list = [ (x,) for x in image_paths ] + img_list_len = len(img_list) + img_list_range = [*range(img_list_len)] + + feats = [None]*img_list_len + for i in io.progress_bar_generator(img_list_range, "Loading"): + img = cv2_imread( img_list[i][0] ).astype(np.float32) + img = imagelib.normalize_channels (img, 3) + img = cv2.resize (img, (224,224) ) + img = img[..., ::-1] + img[..., 0] -= 93.5940 + img[..., 1] -= 104.7624 + img[..., 2] -= 129.1863 + feats[i] = model.predict( img[None,...] )[0] + + tmp = np.zeros( (img_list_len,) ) + float_inf = float("inf") + for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ): + i_feat = feats[i] + + for j in img_list_range: + tmp[j] = npla.norm(i_feat-feats[j]) if j >= i+1 else float_inf + + idx = np.argmin(tmp) + + img_list[i+1], img_list[idx] = img_list[idx], img_list[i+1] + feats[i+1], feats[idx] = feats[idx], feats[i+1] + + return img_list, trash_img_list +""" + +def sort_by_absdiff(input_path): + io.log_info ("Sorting by absolute difference...") + + is_sim = io.input_bool ("Sort by similar?", True, help_message="Otherwise sort by dissimilar.") + + from core.leras import nn + + device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True) + nn.initialize( device_config=device_config, data_format="NHWC" ) + tf = nn.tf + + image_paths = pathex.get_image_paths(input_path) + image_paths_len = len(image_paths) + + batch_size = 512 + batch_size_remain = image_paths_len % batch_size + + i_t = tf.placeholder (tf.float32, (None,None,None,None) ) + j_t = tf.placeholder (tf.float32, (None,None,None,None) ) + + outputs_full = [] + outputs_remain = [] + + for i in range(batch_size): + diff_t = tf.reduce_sum( tf.abs(i_t-j_t[i]), axis=[1,2,3] ) + outputs_full.append(diff_t) + if i < batch_size_remain: + outputs_remain.append(diff_t) + + def func_bs_full(i,j): + return nn.tf_sess.run (outputs_full, feed_dict={i_t:i,j_t:j}) + + def func_bs_remain(i,j): + return nn.tf_sess.run (outputs_remain, feed_dict={i_t:i,j_t:j}) + + import h5py + db_file_path = Path(tempfile.gettempdir()) / 'sort_cache.hdf5' + db_file = h5py.File( str(db_file_path), "w") + db = db_file.create_dataset("results", (image_paths_len,image_paths_len), compression="gzip") + + pg_len = image_paths_len // batch_size + if batch_size_remain != 0: + pg_len += 1 + + pg_len = int( ( pg_len*pg_len - pg_len ) / 2 + pg_len ) + + io.progress_bar ("Computing", pg_len) + j=0 + while j < image_paths_len: + j_images = [ cv2_imread(x) for x in image_paths[j:j+batch_size] ] + j_images_len = len(j_images) + + func = func_bs_remain if image_paths_len-j < batch_size else func_bs_full + + i=0 + while i < image_paths_len: + if i >= j: + i_images = [ cv2_imread(x) for x in image_paths[i:i+batch_size] ] + i_images_len = len(i_images) + result = func (i_images,j_images) + db[j:j+j_images_len,i:i+i_images_len] = np.array(result) + io.progress_bar_inc(1) + + i += batch_size + db_file.flush() + j += batch_size + + io.progress_bar_close() + + next_id = 0 + sorted = [next_id] + for i in io.progress_bar_generator ( range(image_paths_len-1), "Sorting" ): + id_ar = np.concatenate ( [ db[:next_id,next_id], db[next_id,next_id:] ] ) + id_ar = np.argsort(id_ar) + + + next_id = np.setdiff1d(id_ar, sorted, True)[ 0 if is_sim else -1] + sorted += [next_id] + db_file.close() + db_file_path.unlink() + + img_list = [ (image_paths[x],) for x in sorted] + return img_list, [] + +def final_process(input_path, img_list, trash_img_list): + if len(trash_img_list) != 0: + parent_input_path = input_path.parent + trash_path = parent_input_path / (input_path.stem + '_trash') + trash_path.mkdir (exist_ok=True) + + io.log_info ("Trashing %d items to %s" % ( len(trash_img_list), str(trash_path) ) ) + + for filename in pathex.get_image_paths(trash_path): + Path(filename).unlink() + + for i in io.progress_bar_generator( range(len(trash_img_list)), "Moving trash", leave=False): + src = Path (trash_img_list[i][0]) + dst = trash_path / src.name + try: + src.rename (dst) + except: + io.log_info ('fail to trashing %s' % (src.name) ) + + io.log_info ("") + + if len(img_list) != 0: + for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming", leave=False): + src = Path (img_list[i][0]) + dst = input_path / ('%.5d_%s' % (i, src.name )) + try: + src.rename (dst) + except: + io.log_info ('fail to rename %s' % (src.name) ) + + for i in io.progress_bar_generator( [*range(len(img_list))], "Renaming"): + src = Path (img_list[i][0]) + src = input_path / ('%.5d_%s' % (i, src.name)) + dst = input_path / ('%.5d%s' % (i, src.suffix)) + try: + src.rename (dst) + except: + io.log_info ('fail to rename %s' % (src.name) ) + +sort_func_methods = { + 'blur': ("blur", sort_by_blur), + 'motion-blur': ("motion_blur", sort_by_motion_blur), + 'face-yaw': ("face yaw direction", sort_by_face_yaw), + 'face-pitch': ("face pitch direction", sort_by_face_pitch), + 'face-source-rect-size' : ("face rect size in source image", sort_by_face_source_rect_size), + 'hist': ("histogram similarity", sort_by_hist), + 'hist-dissim': ("histogram dissimilarity", sort_by_hist_dissim), + 'brightness': ("brightness", sort_by_brightness), + 'hue': ("hue", sort_by_hue), + 'black': ("amount of black pixels", sort_by_black), + 'origname': ("original filename", sort_by_origname), + 'oneface': ("one face in image", sort_by_oneface_in_image), + 'absdiff': ("absolute pixel difference", sort_by_absdiff), + 'final': ("best faces", sort_best), + 'final-fast': ("best faces faster", sort_best_faster), +} + +def main (input_path, sort_by_method=None): + io.log_info ("Running sort tool.\r\n") + + if sort_by_method is None: + io.log_info(f"Choose sorting method:") + + key_list = list(sort_func_methods.keys()) + for i, key in enumerate(key_list): + desc, func = sort_func_methods[key] + io.log_info(f"[{i}] {desc}") + + io.log_info("") + id = io.input_int("", 5, valid_list=[*range(len(key_list))] ) + + sort_by_method = key_list[id] + else: + sort_by_method = sort_by_method.lower() + + desc, func = sort_func_methods[sort_by_method] + img_list, trash_img_list = func(input_path) + + final_process (input_path, img_list, trash_img_list) diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..df74ca37805cfa3e6b9843c8d01e4d7dd83b8293 --- /dev/null +++ b/mainscripts/Trainer.py @@ -0,0 +1,360 @@ +import os +import sys +import traceback +import queue +import threading +import time +import numpy as np +import itertools +from pathlib import Path +from core import pathex +from core import imagelib +import cv2 +import models +from core.interact import interact as io + +def trainerThread (s2c, c2s, e, + model_class_name = None, + saved_models_path = None, + training_data_src_path = None, + training_data_dst_path = None, + pretraining_data_path = None, + pretrained_model_path = None, + no_preview=False, + force_model_name=None, + force_gpu_idxs=None, + cpu_only=None, + silent_start=False, + execute_programs = None, + debug=False, + **kwargs): + while True: + try: + start_time = time.time() + + save_interval_min = 25 + + if not training_data_src_path.exists(): + training_data_src_path.mkdir(exist_ok=True, parents=True) + + if not training_data_dst_path.exists(): + training_data_dst_path.mkdir(exist_ok=True, parents=True) + + if not saved_models_path.exists(): + saved_models_path.mkdir(exist_ok=True, parents=True) + + model = models.import_model(model_class_name)( + is_training=True, + saved_models_path=saved_models_path, + training_data_src_path=training_data_src_path, + training_data_dst_path=training_data_dst_path, + pretraining_data_path=pretraining_data_path, + pretrained_model_path=pretrained_model_path, + no_preview=no_preview, + force_model_name=force_model_name, + force_gpu_idxs=force_gpu_idxs, + cpu_only=cpu_only, + silent_start=silent_start, + debug=debug) + + is_reached_goal = model.is_reached_iter_goal() + + shared_state = { 'after_save' : False } + loss_string = "" + save_iter = model.get_iter() + def model_save(): + if not debug and not is_reached_goal: + io.log_info ("Saving....", end='\r') + model.save() + shared_state['after_save'] = True + + def model_backup(): + if not debug and not is_reached_goal: + model.create_backup() + + def send_preview(): + if not debug: + previews = model.get_previews() + c2s.put ( {'op':'show', 'previews': previews, 'iter':model.get_iter(), 'loss_history': model.get_loss_history().copy() } ) + else: + previews = [( 'debug, press update for new', model.debug_one_iter())] + c2s.put ( {'op':'show', 'previews': previews} ) + e.set() #Set the GUI Thread as Ready + + if model.get_target_iter() != 0: + if is_reached_goal: + io.log_info('Model already trained to target iteration. You can use preview.') + else: + io.log_info('Starting. Target iteration: %d. Press "Enter" to stop training and save model.' % ( model.get_target_iter() ) ) + else: + io.log_info('Starting. Press "Enter" to stop training and save model.') + + last_save_time = time.time() + + execute_programs = [ [x[0], x[1], time.time() ] for x in execute_programs ] + + for i in itertools.count(0,1): + if not debug: + cur_time = time.time() + + for x in execute_programs: + prog_time, prog, last_time = x + exec_prog = False + if prog_time > 0 and (cur_time - start_time) >= prog_time: + x[0] = 0 + exec_prog = True + elif prog_time < 0 and (cur_time - last_time) >= -prog_time: + x[2] = cur_time + exec_prog = True + + if exec_prog: + try: + exec(prog) + except Exception as e: + print("Unable to execute program: %s" % (prog) ) + + if not is_reached_goal: + + if model.get_iter() == 0: + io.log_info("") + io.log_info("Trying to do the first iteration. If an error occurs, reduce the model parameters.") + io.log_info("") + + if sys.platform[0:3] == 'win': + io.log_info("!!!") + io.log_info("Windows 10 users IMPORTANT notice. You should set this setting in order to work correctly.") + io.log_info("https://i.imgur.com/B7cmDCB.jpg") + io.log_info("!!!") + + iter, iter_time = model.train_one_iter() + + loss_history = model.get_loss_history() + time_str = time.strftime("[%H:%M:%S]") + if iter_time >= 10: + loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, iter, '{:0.4f}'.format(iter_time) ) + else: + loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, iter, int(iter_time*1000) ) + + if shared_state['after_save']: + shared_state['after_save'] = False + + mean_loss = np.mean ( loss_history[save_iter:iter], axis=0) + + for loss_value in mean_loss: + loss_string += "[%.4f]" % (loss_value) + + io.log_info (loss_string) + + save_iter = iter + else: + for loss_value in loss_history[-1]: + loss_string += "[%.4f]" % (loss_value) + + if io.is_colab(): + io.log_info ('\r' + loss_string, end='') + else: + io.log_info (loss_string, end='\r') + + if model.get_iter() == 1: + model_save() + + if model.get_target_iter() != 0 and model.is_reached_iter_goal(): + io.log_info ('Reached target iteration.') + model_save() + is_reached_goal = True + io.log_info ('You can use preview now.') + + need_save = False + while time.time() - last_save_time >= save_interval_min*60: + last_save_time += save_interval_min*60 + need_save = True + + if not is_reached_goal and need_save: + model_save() + send_preview() + + if i==0: + if is_reached_goal: + model.pass_one_iter() + send_preview() + + if debug: + time.sleep(0.005) + + while not s2c.empty(): + input = s2c.get() + op = input['op'] + if op == 'save': + model_save() + elif op == 'backup': + model_backup() + elif op == 'preview': + if is_reached_goal: + model.pass_one_iter() + send_preview() + elif op == 'close': + model_save() + i = -1 + break + + if i == -1: + break + + + + model.finalize() + + except Exception as e: + print ('Error: %s' % (str(e))) + traceback.print_exc() + break + c2s.put ( {'op':'close'} ) + + + +def main(**kwargs): + io.log_info ("Running trainer.\r\n") + + no_preview = kwargs.get('no_preview', False) + + s2c = queue.Queue() + c2s = queue.Queue() + + e = threading.Event() + thread = threading.Thread(target=trainerThread, args=(s2c, c2s, e), kwargs=kwargs ) + thread.start() + + e.wait() #Wait for inital load to occur. + + if no_preview: + while True: + if not c2s.empty(): + input = c2s.get() + op = input.get('op','') + if op == 'close': + break + try: + io.process_messages(0.1) + except KeyboardInterrupt: + s2c.put ( {'op': 'close'} ) + else: + wnd_name = "Training preview" + io.named_window(wnd_name) + io.capture_keys(wnd_name) + + previews = None + loss_history = None + selected_preview = 0 + update_preview = False + is_showing = False + is_waiting_preview = False + show_last_history_iters_count = 0 + iter = 0 + while True: + if not c2s.empty(): + input = c2s.get() + op = input['op'] + if op == 'show': + is_waiting_preview = False + loss_history = input['loss_history'] if 'loss_history' in input.keys() else None + previews = input['previews'] if 'previews' in input.keys() else None + iter = input['iter'] if 'iter' in input.keys() else 0 + if previews is not None: + max_w = 0 + max_h = 0 + for (preview_name, preview_rgb) in previews: + (h, w, c) = preview_rgb.shape + max_h = max (max_h, h) + max_w = max (max_w, w) + + max_size = 800 + if max_h > max_size: + max_w = int( max_w / (max_h / max_size) ) + max_h = max_size + + #make all previews size equal + for preview in previews[:]: + (preview_name, preview_rgb) = preview + (h, w, c) = preview_rgb.shape + if h != max_h or w != max_w: + previews.remove(preview) + previews.append ( (preview_name, cv2.resize(preview_rgb, (max_w, max_h))) ) + selected_preview = selected_preview % len(previews) + update_preview = True + elif op == 'close': + break + + if update_preview: + update_preview = False + + selected_preview_name = previews[selected_preview][0] + selected_preview_rgb = previews[selected_preview][1] + (h,w,c) = selected_preview_rgb.shape + + # HEAD + head_lines = [ + '[s]:save [b]:backup [enter]:exit', + '[p]:update [space]:next preview [l]:change history range', + 'Preview: "%s" [%d/%d]' % (selected_preview_name,selected_preview+1, len(previews) ) + ] + head_line_height = 15 + head_height = len(head_lines) * head_line_height + head = np.ones ( (head_height,w,c) ) * 0.1 + + for i in range(0, len(head_lines)): + t = i*head_line_height + b = (i+1)*head_line_height + head[t:b, 0:w] += imagelib.get_text_image ( (head_line_height,w,c) , head_lines[i], color=[0.8]*c ) + + final = head + + if loss_history is not None: + if show_last_history_iters_count == 0: + loss_history_to_show = loss_history + else: + loss_history_to_show = loss_history[-show_last_history_iters_count:] + + lh_img = models.ModelBase.get_loss_history_preview(loss_history_to_show, iter, w, c) + final = np.concatenate ( [final, lh_img], axis=0 ) + + final = np.concatenate ( [final, selected_preview_rgb], axis=0 ) + final = np.clip(final, 0, 1) + + io.show_image( wnd_name, (final*255).astype(np.uint8) ) + is_showing = True + + key_events = io.get_key_events(wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + + if key == ord('\n') or key == ord('\r'): + s2c.put ( {'op': 'close'} ) + elif key == ord('s'): + s2c.put ( {'op': 'save'} ) + elif key == ord('b'): + s2c.put ( {'op': 'backup'} ) + elif key == ord('p'): + if not is_waiting_preview: + is_waiting_preview = True + s2c.put ( {'op': 'preview'} ) + elif key == ord('l'): + if show_last_history_iters_count == 0: + show_last_history_iters_count = 5000 + elif show_last_history_iters_count == 5000: + show_last_history_iters_count = 10000 + elif show_last_history_iters_count == 10000: + show_last_history_iters_count = 50000 + elif show_last_history_iters_count == 50000: + show_last_history_iters_count = 100000 + elif show_last_history_iters_count == 100000: + show_last_history_iters_count = 0 + update_preview = True + elif key == ord(' '): + selected_preview = (selected_preview + 1) % len(previews) + update_preview = True + + try: + io.process_messages(0.1) + except KeyboardInterrupt: + s2c.put ( {'op': 'close'} ) + + io.destroy_all_windows() \ No newline at end of file diff --git a/mainscripts/Util.py b/mainscripts/Util.py new file mode 100644 index 0000000000000000000000000000000000000000..4a51e537075f32dfd923849ea9c10cf2544d1b31 --- /dev/null +++ b/mainscripts/Util.py @@ -0,0 +1,161 @@ +import pickle +from pathlib import Path + +import cv2 + +from DFLIMG import * +from facelib import LandmarksProcessor, FaceType +from core.interact import interact as io +from core import pathex +from core.cv2ex import * + + +def save_faceset_metadata_folder(input_path): + input_path = Path(input_path) + + metadata_filepath = input_path / 'meta.dat' + + io.log_info (f"Saving metadata to {str(metadata_filepath)}\r\n") + + d = {} + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"): + filepath = Path(filepath) + dflimg = DFLIMG.load (filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f"{filepath} is not a dfl image file") + continue + + dfl_dict = dflimg.get_dict() + d[filepath.name] = ( dflimg.get_shape(), dfl_dict ) + + try: + with open(metadata_filepath, "wb") as f: + f.write ( pickle.dumps(d) ) + except: + raise Exception( 'cannot save %s' % (filename) ) + + io.log_info("Now you can edit images.") + io.log_info("!!! Keep same filenames in the folder.") + io.log_info("You can change size of images, restoring process will downscale back to original size.") + io.log_info("After that, use restore metadata.") + +def restore_faceset_metadata_folder(input_path): + input_path = Path(input_path) + + metadata_filepath = input_path / 'meta.dat' + io.log_info (f"Restoring metadata from {str(metadata_filepath)}.\r\n") + + if not metadata_filepath.exists(): + io.log_err(f"Unable to find {str(metadata_filepath)}.") + + try: + with open(metadata_filepath, "rb") as f: + d = pickle.loads(f.read()) + except: + raise FileNotFoundError(filename) + + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path, image_extensions=['.jpg'], return_Path_class=True), "Processing"): + saved_data = d.get(filepath.name, None) + if saved_data is None: + io.log_info(f"No saved metadata for {filepath}") + continue + + shape, dfl_dict = saved_data + + img = cv2_imread (filepath) + if img.shape != shape: + img = cv2.resize (img, (shape[1], shape[0]), interpolation=cv2.INTER_LANCZOS4 ) + + cv2_imwrite (str(filepath), img, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + + if filepath.suffix == '.jpg': + dflimg = DFLJPG.load(filepath) + dflimg.set_dict(dfl_dict) + dflimg.save() + else: + continue + + metadata_filepath.unlink() + +def add_landmarks_debug_images(input_path): + io.log_info ("Adding landmarks debug images...") + + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"): + filepath = Path(filepath) + + img = cv2_imread(str(filepath)) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + continue + + if img is not None: + face_landmarks = dflimg.get_landmarks() + face_type = FaceType.fromString ( dflimg.get_face_type() ) + + if face_type == FaceType.MARK_ONLY: + rect = dflimg.get_source_rect() + LandmarksProcessor.draw_rect_landmarks(img, rect, face_landmarks, FaceType.FULL ) + else: + LandmarksProcessor.draw_landmarks(img, face_landmarks, transparent_mask=True ) + + + + output_file = '{}{}'.format( str(Path(str(input_path)) / filepath.stem), '_debug.jpg') + cv2_imwrite(output_file, img, [int(cv2.IMWRITE_JPEG_QUALITY), 50] ) + +def recover_original_aligned_filename(input_path): + io.log_info ("Recovering original aligned filename...") + + files = [] + for filepath in io.progress_bar_generator( pathex.get_image_paths(input_path), "Processing"): + filepath = Path(filepath) + + dflimg = DFLIMG.load (filepath) + + if dflimg is None or not dflimg.has_data(): + io.log_err (f"{filepath.name} is not a dfl image file") + continue + + files += [ [filepath, None, dflimg.get_source_filename(), False] ] + + files_len = len(files) + for i in io.progress_bar_generator( range(files_len), "Sorting" ): + fp, _, sf, converted = files[i] + + if converted: + continue + + sf_stem = Path(sf).stem + + files[i][1] = fp.parent / ( sf_stem + '_0' + fp.suffix ) + files[i][3] = True + c = 1 + + for j in range(i+1, files_len): + fp_j, _, sf_j, converted_j = files[j] + if converted_j: + continue + + if sf_j == sf: + files[j][1] = fp_j.parent / ( sf_stem + ('_%d' % (c)) + fp_j.suffix ) + files[j][3] = True + c += 1 + + for file in io.progress_bar_generator( files, "Renaming", leave=False ): + fs, _, _, _ = file + dst = fs.parent / ( fs.stem + '_tmp' + fs.suffix ) + try: + fs.rename (dst) + except: + io.log_err ('fail to rename %s' % (fs.name) ) + + for file in io.progress_bar_generator( files, "Renaming" ): + fs, fd, _, _ = file + fs = fs.parent / ( fs.stem + '_tmp' + fs.suffix ) + try: + fs.rename (fd) + except: + io.log_err ('fail to rename %s' % (fs.name) ) diff --git a/mainscripts/VideoEd.py b/mainscripts/VideoEd.py new file mode 100644 index 0000000000000000000000000000000000000000..f9fcedd13c03b88ad166b48b7a857a04f5a78ca1 --- /dev/null +++ b/mainscripts/VideoEd.py @@ -0,0 +1,270 @@ +import subprocess +import numpy as np +import ffmpeg +from pathlib import Path +from core import pathex +from core.interact import interact as io + +def extract_video(input_file, output_dir, output_ext=None, fps=None): + input_file_path = Path(input_file) + output_path = Path(output_dir) + + if not output_path.exists(): + output_path.mkdir(exist_ok=True) + + + if input_file_path.suffix == '.*': + input_file_path = pathex.get_first_file_by_stem (input_file_path.parent, input_file_path.stem) + else: + if not input_file_path.exists(): + input_file_path = None + + if input_file_path is None: + io.log_err("input_file not found.") + return + + if fps is None: + fps = io.input_int ("Enter FPS", 0, help_message="How many frames of every second of the video will be extracted. 0 - full fps") + + if output_ext is None: + output_ext = io.input_str ("Output image format", "png", ["png","jpg"], help_message="png is lossless, but extraction is x10 slower for HDD, requires x10 more disk space than jpg.") + + for filename in pathex.get_image_paths (output_path, ['.'+output_ext]): + Path(filename).unlink() + + job = ffmpeg.input(str(input_file_path)) + + kwargs = {'pix_fmt': 'rgb24'} + if fps != 0: + kwargs.update ({'r':str(fps)}) + + if output_ext == 'jpg': + kwargs.update ({'q:v':'2'}) #highest quality for jpg + + job = job.output( str (output_path / ('%5d.'+output_ext)), **kwargs ) + + try: + job = job.run() + except: + io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) ) + +def cut_video ( input_file, from_time=None, to_time=None, audio_track_id=None, bitrate=None): + input_file_path = Path(input_file) + if input_file_path is None: + io.log_err("input_file not found.") + return + + output_file_path = input_file_path.parent / (input_file_path.stem + "_cut" + input_file_path.suffix) + + if from_time is None: + from_time = io.input_str ("From time", "00:00:00.000") + + if to_time is None: + to_time = io.input_str ("To time", "00:00:00.000") + + if audio_track_id is None: + audio_track_id = io.input_int ("Specify audio track id.", 0) + + if bitrate is None: + bitrate = max (1, io.input_int ("Bitrate of output file in MB/s", 25) ) + + kwargs = {"c:v": "libx264", + "b:v": "%dM" %(bitrate), + "pix_fmt": "yuv420p", + } + + job = ffmpeg.input(str(input_file_path), ss=from_time, to=to_time) + + job_v = job['v:0'] + job_a = job['a:' + str(audio_track_id) + '?' ] + + job = ffmpeg.output(job_v, job_a, str(output_file_path), **kwargs).overwrite_output() + + try: + job = job.run() + except: + io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) ) + +def denoise_image_sequence( input_dir, ext=None, factor=None ): + input_path = Path(input_dir) + + if not input_path.exists(): + io.log_err("input_dir not found.") + return + + image_paths = [ Path(filepath) for filepath in pathex.get_image_paths(input_path) ] + + # Check extension of all images + image_paths_suffix = None + for filepath in image_paths: + if image_paths_suffix is None: + image_paths_suffix = filepath.suffix + else: + if filepath.suffix != image_paths_suffix: + io.log_err(f"All images in {input_path.name} should be with the same extension.") + return + + if factor is None: + factor = np.clip ( io.input_int ("Denoise factor?", 7, add_info="1-20"), 1, 20 ) + + # Rename to temporary filenames + for i,filepath in io.progress_bar_generator( enumerate(image_paths), "Renaming", leave=False): + src = filepath + dst = filepath.parent / ( f'{i+1:06}_{filepath.name}' ) + try: + src.rename (dst) + except: + io.log_error ('fail to rename %s' % (src.name) ) + return + + # Rename to sequental filenames + for i,filepath in io.progress_bar_generator( enumerate(image_paths), "Renaming", leave=False): + + src = filepath.parent / ( f'{i+1:06}_{filepath.name}' ) + dst = filepath.parent / ( f'{i+1:06}{filepath.suffix}' ) + try: + src.rename (dst) + except: + io.log_error ('fail to rename %s' % (src.name) ) + return + + # Process image sequence in ffmpeg + kwargs = {} + if image_paths_suffix == '.jpg': + kwargs.update ({'q:v':'2'}) + + job = ( ffmpeg + .input(str ( input_path / ('%6d'+image_paths_suffix) ) ) + .filter("hqdn3d", factor, factor, 5,5) + .output(str ( input_path / ('%6d'+image_paths_suffix) ), **kwargs ) + ) + + try: + job = job.run() + except: + io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) ) + + # Rename to temporary filenames + for i,filepath in io.progress_bar_generator( enumerate(image_paths), "Renaming", leave=False): + src = filepath.parent / ( f'{i+1:06}{filepath.suffix}' ) + dst = filepath.parent / ( f'{i+1:06}_{filepath.name}' ) + try: + src.rename (dst) + except: + io.log_error ('fail to rename %s' % (src.name) ) + return + + # Rename to initial filenames + for i,filepath in io.progress_bar_generator( enumerate(image_paths), "Renaming", leave=False): + src = filepath.parent / ( f'{i+1:06}_{filepath.name}' ) + dst = filepath + + try: + src.rename (dst) + except: + io.log_error ('fail to rename %s' % (src.name) ) + return + +def video_from_sequence( input_dir, output_file, reference_file=None, ext=None, fps=None, bitrate=None, include_audio=False, lossless=None ): + input_path = Path(input_dir) + output_file_path = Path(output_file) + reference_file_path = Path(reference_file) if reference_file is not None else None + + if not input_path.exists(): + io.log_err("input_dir not found.") + return + + if not output_file_path.parent.exists(): + output_file_path.parent.mkdir(parents=True, exist_ok=True) + return + + out_ext = output_file_path.suffix + + if ext is None: + ext = io.input_str ("Input image format (extension)", "png") + + if lossless is None: + lossless = io.input_bool ("Use lossless codec", False) + + video_id = None + audio_id = None + ref_in_a = None + if reference_file_path is not None: + if reference_file_path.suffix == '.*': + reference_file_path = pathex.get_first_file_by_stem (reference_file_path.parent, reference_file_path.stem) + else: + if not reference_file_path.exists(): + reference_file_path = None + + if reference_file_path is None: + io.log_err("reference_file not found.") + return + + #probing reference file + probe = ffmpeg.probe (str(reference_file_path)) + + #getting first video and audio streams id with fps + for stream in probe['streams']: + if video_id is None and stream['codec_type'] == 'video': + video_id = stream['index'] + fps = stream['r_frame_rate'] + + if audio_id is None and stream['codec_type'] == 'audio': + audio_id = stream['index'] + + if audio_id is not None: + #has audio track + ref_in_a = ffmpeg.input (str(reference_file_path))[str(audio_id)] + + if fps is None: + #if fps not specified and not overwritten by reference-file + fps = max (1, io.input_int ("Enter FPS", 25) ) + + if not lossless and bitrate is None: + bitrate = max (1, io.input_int ("Bitrate of output file in MB/s", 16) ) + + input_image_paths = pathex.get_image_paths(input_path) + + i_in = ffmpeg.input('pipe:', format='image2pipe', r=fps) + + output_args = [i_in] + + if include_audio and ref_in_a is not None: + output_args += [ref_in_a] + + output_args += [str (output_file_path)] + + output_kwargs = {} + + if lossless: + output_kwargs.update ({"c:v": "libx264", + "crf": "0", + "pix_fmt": "yuv420p", + }) + else: + output_kwargs.update ({"c:v": "libx264", + "b:v": "%dM" %(bitrate), + "pix_fmt": "yuv420p", + }) + + if include_audio and ref_in_a is not None: + output_kwargs.update ({"c:a": "aac", + "b:a": "192k", + "ar" : "48000", + "strict": "experimental" + }) + + job = ( ffmpeg.output(*output_args, **output_kwargs).overwrite_output() ) + + try: + job_run = job.run_async(pipe_stdin=True) + + for image_path in input_image_paths: + with open (image_path, "rb") as f: + image_bytes = f.read() + job_run.stdin.write (image_bytes) + + job_run.stdin.close() + job_run.wait() + except: + io.log_err ("ffmpeg fail, job commandline:" + str(job.compile()) ) diff --git a/mainscripts/XSegUtil.py b/mainscripts/XSegUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..c75a14a32a2b56e1875b20e78098f85903849de5 --- /dev/null +++ b/mainscripts/XSegUtil.py @@ -0,0 +1,187 @@ +import json +import shutil +import traceback +from pathlib import Path + +import numpy as np + +from core import pathex +from core.cv2ex import * +from core.interact import interact as io +from core.leras import nn +from DFLIMG import * +from facelib import XSegNet, LandmarksProcessor, FaceType +import pickle + +def apply_xseg(input_path, model_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + if not model_path.exists(): + raise ValueError(f'{model_path} not found. Please ensure it exists.') + + face_type = None + + model_dat = model_path / 'XSeg_data.dat' + if model_dat.exists(): + dat = pickle.loads( model_dat.read_bytes() ) + dat_options = dat.get('options', None) + if dat_options is not None: + face_type = dat_options.get('face_type', None) + + + + if face_type is None: + face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower() + if face_type == 'same': + face_type = None + + if face_type is not None: + face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[face_type] + + io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.') + + device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True) + nn.initialize(device_config) + + + + xseg = XSegNet(name='XSeg', + load_weights=True, + weights_file_root=model_path, + data_format=nn.data_format, + raise_on_no_model_files=True) + xseg_res = xseg.get_resolution() + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + img = cv2_imread(filepath).astype(np.float32) / 255.0 + h,w,c = img.shape + + img_face_type = FaceType.fromString( dflimg.get_face_type() ) + if face_type is not None and img_face_type != face_type: + lmrks = dflimg.get_source_landmarks() + + fmat = LandmarksProcessor.get_transform_mat(lmrks, w, face_type) + imat = LandmarksProcessor.get_transform_mat(lmrks, w, img_face_type) + + g_p = LandmarksProcessor.transform_points (np.float32([(0,0),(w,0),(0,w) ]), fmat, True) + g_p2 = LandmarksProcessor.transform_points (g_p, imat) + + mat = cv2.getAffineTransform( g_p2, np.float32([(0,0),(w,0),(0,w) ]) ) + + img = cv2.warpAffine(img, mat, (w, w), cv2.INTER_LANCZOS4) + img = cv2.resize(img, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4) + else: + if w != xseg_res: + img = cv2.resize( img, (xseg_res,xseg_res), interpolation=cv2.INTER_LANCZOS4 ) + + if len(img.shape) == 2: + img = img[...,None] + + mask = xseg.extract(img) + + if face_type is not None and img_face_type != face_type: + mask = cv2.resize(mask, (w, w), interpolation=cv2.INTER_LANCZOS4) + mask = cv2.warpAffine( mask, mat, (w,w), np.zeros( (h,w,c), dtype=np.float), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4) + mask = cv2.resize(mask, (xseg_res, xseg_res), interpolation=cv2.INTER_LANCZOS4) + mask[mask < 0.5]=0 + mask[mask >= 0.5]=1 + dflimg.set_xseg_mask(mask) + dflimg.save() + + + +def fetch_xseg(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + output_path = input_path.parent / (input_path.name + '_xseg') + output_path.mkdir(exist_ok=True, parents=True) + + io.log_info(f'Copying faces containing XSeg polygons to {output_path.name}/ folder.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + + files_copied = [] + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + ie_polys = dflimg.get_seg_ie_polys() + + if ie_polys.has_polys(): + files_copied.append(filepath) + shutil.copy ( str(filepath), str(output_path / filepath.name) ) + + io.log_info(f'Files copied: {len(files_copied)}') + + is_delete = io.input_bool (f"\r\nDelete original files?", True) + if is_delete: + for filepath in files_copied: + Path(filepath).unlink() + + +def remove_xseg(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + io.log_info(f'Processing folder {input_path}') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : APPLIED XSEG MASKS WILL BE REMOVED FROM THE FRAMES !!!') + io.input_str('Press enter to continue.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + files_processed = 0 + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + if dflimg.has_xseg_mask(): + dflimg.set_xseg_mask(None) + dflimg.save() + files_processed += 1 + io.log_info(f'Files processed: {files_processed}') + +def remove_xseg_labels(input_path): + if not input_path.exists(): + raise ValueError(f'{input_path} not found. Please ensure it exists.') + + io.log_info(f'Processing folder {input_path}') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.log_info('!!! WARNING : LABELED XSEG POLYGONS WILL BE REMOVED FROM THE FRAMES !!!') + io.input_str('Press enter to continue.') + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + files_processed = 0 + for filepath in io.progress_bar_generator(images_paths, "Processing"): + dflimg = DFLIMG.load(filepath) + if dflimg is None or not dflimg.has_data(): + io.log_info(f'{filepath} is not a DFLIMG') + continue + + if dflimg.has_seg_ie_polys(): + dflimg.set_seg_ie_polys(None) + dflimg.save() + files_processed += 1 + + io.log_info(f'Files processed: {files_processed}') \ No newline at end of file diff --git a/mainscripts/dev_misc.py b/mainscripts/dev_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..93a4359afe00b8cd34ce76d19d094e8a5a06b1eb --- /dev/null +++ b/mainscripts/dev_misc.py @@ -0,0 +1,594 @@ +import traceback +import json +import multiprocessing +import shutil +from pathlib import Path +import cv2 +import numpy as np + +from core import imagelib, pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import Subprocessor +from core.leras import nn +from DFLIMG import * +from facelib import FaceType, LandmarksProcessor +from . import Extractor, Sorter +from .Extractor import ExtractSubprocessor + + +def extract_vggface2_dataset(input_dir, device_args={} ): + multi_gpu = device_args.get('multi_gpu', False) + cpu_only = device_args.get('cpu_only', False) + + input_path = Path(input_dir) + if not input_path.exists(): + raise ValueError('Input directory not found. Please ensure it exists.') + + bb_csv = input_path / 'loose_bb_train.csv' + if not bb_csv.exists(): + raise ValueError('loose_bb_train.csv found. Please ensure it exists.') + + bb_lines = bb_csv.read_text().split('\n') + bb_lines.pop(0) + + bb_dict = {} + for line in bb_lines: + name, l, t, w, h = line.split(',') + name = name[1:-1] + l, t, w, h = [ int(x) for x in (l, t, w, h) ] + bb_dict[name] = (l,t,w, h) + + + output_path = input_path.parent / (input_path.name + '_out') + + dir_names = pathex.get_all_dir_names(input_path) + + if not output_path.exists(): + output_path.mkdir(parents=True, exist_ok=True) + + data = [] + for dir_name in io.progress_bar_generator(dir_names, "Collecting"): + cur_input_path = input_path / dir_name + cur_output_path = output_path / dir_name + + if not cur_output_path.exists(): + cur_output_path.mkdir(parents=True, exist_ok=True) + + input_path_image_paths = pathex.get_image_paths(cur_input_path) + + for filename in input_path_image_paths: + filename_path = Path(filename) + + name = filename_path.parent.name + '/' + filename_path.stem + if name not in bb_dict: + continue + + l,t,w,h = bb_dict[name] + if min(w,h) < 128: + continue + + data += [ ExtractSubprocessor.Data(filename=filename,rects=[ (l,t,l+w,t+h) ], landmarks_accurate=False, force_output_path=cur_output_path ) ] + + face_type = FaceType.fromString('full_face') + + io.log_info ('Performing 2nd pass...') + data = ExtractSubprocessor (data, 'landmarks', 256, face_type, debug_dir=None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False).run() + + io.log_info ('Performing 3rd pass...') + ExtractSubprocessor (data, 'final', 256, face_type, debug_dir=None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=None).run() + + +""" + import code + code.interact(local=dict(globals(), **locals())) + + data_len = len(data) + i = 0 + while i < data_len-1: + i_name = Path(data[i].filename).parent.name + + sub_data = [] + + for j in range (i, data_len): + j_name = Path(data[j].filename).parent.name + if i_name == j_name: + sub_data += [ data[j] ] + else: + break + i = j + + cur_output_path = output_path / i_name + + io.log_info (f"Processing: {str(cur_output_path)}, {i}/{data_len} ") + + if not cur_output_path.exists(): + cur_output_path.mkdir(parents=True, exist_ok=True) + + + + + + + + + for dir_name in dir_names: + + cur_input_path = input_path / dir_name + cur_output_path = output_path / dir_name + + input_path_image_paths = pathex.get_image_paths(cur_input_path) + l = len(input_path_image_paths) + #if l < 250 or l > 350: + # continue + + io.log_info (f"Processing: {str(cur_input_path)} ") + + if not cur_output_path.exists(): + cur_output_path.mkdir(parents=True, exist_ok=True) + + + data = [] + for filename in input_path_image_paths: + filename_path = Path(filename) + + name = filename_path.parent.name + '/' + filename_path.stem + if name not in bb_dict: + continue + + bb = bb_dict[name] + l,t,w,h = bb + if min(w,h) < 128: + continue + + data += [ ExtractSubprocessor.Data(filename=filename,rects=[ (l,t,l+w,t+h) ], landmarks_accurate=False ) ] + + + + io.log_info ('Performing 2nd pass...') + data = ExtractSubprocessor (data, 'landmarks', 256, face_type, debug_dir=None, multi_gpu=False, cpu_only=False, manual=False).run() + + io.log_info ('Performing 3rd pass...') + data = ExtractSubprocessor (data, 'final', 256, face_type, debug_dir=None, multi_gpu=False, cpu_only=False, manual=False, final_output_path=cur_output_path).run() + + + io.log_info (f"Sorting: {str(cur_output_path)} ") + Sorter.main (input_path=str(cur_output_path), sort_by_method='hist') + + import code + code.interact(local=dict(globals(), **locals())) + + #try: + # io.log_info (f"Removing: {str(cur_input_path)} ") + # shutil.rmtree(cur_input_path) + #except: + # io.log_info (f"unable to remove: {str(cur_input_path)} ") + + + + +def extract_vggface2_dataset(input_dir, device_args={} ): + multi_gpu = device_args.get('multi_gpu', False) + cpu_only = device_args.get('cpu_only', False) + + input_path = Path(input_dir) + if not input_path.exists(): + raise ValueError('Input directory not found. Please ensure it exists.') + + output_path = input_path.parent / (input_path.name + '_out') + + dir_names = pathex.get_all_dir_names(input_path) + + if not output_path.exists(): + output_path.mkdir(parents=True, exist_ok=True) + + + + for dir_name in dir_names: + + cur_input_path = input_path / dir_name + cur_output_path = output_path / dir_name + + l = len(pathex.get_image_paths(cur_input_path)) + if l < 250 or l > 350: + continue + + io.log_info (f"Processing: {str(cur_input_path)} ") + + if not cur_output_path.exists(): + cur_output_path.mkdir(parents=True, exist_ok=True) + + Extractor.main( str(cur_input_path), + str(cur_output_path), + detector='s3fd', + image_size=256, + face_type='full_face', + max_faces_from_image=1, + device_args=device_args ) + + io.log_info (f"Sorting: {str(cur_input_path)} ") + Sorter.main (input_path=str(cur_output_path), sort_by_method='hist') + + try: + io.log_info (f"Removing: {str(cur_input_path)} ") + shutil.rmtree(cur_input_path) + except: + io.log_info (f"unable to remove: {str(cur_input_path)} ") + +""" + +#unused in end user workflow +def dev_test_68(input_dir ): + # process 68 landmarks dataset with .pts files + input_path = Path(input_dir) + if not input_path.exists(): + raise ValueError('input_dir not found. Please ensure it exists.') + + output_path = input_path.parent / (input_path.name+'_aligned') + + io.log_info(f'Output dir is % {output_path}') + + if output_path.exists(): + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) > 0: + io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False ) + for filename in output_images_paths: + Path(filename).unlink() + else: + output_path.mkdir(parents=True, exist_ok=True) + + images_paths = pathex.get_image_paths(input_path) + + for filepath in io.progress_bar_generator(images_paths, "Processing"): + filepath = Path(filepath) + + + pts_filepath = filepath.parent / (filepath.stem+'.pts') + if pts_filepath.exists(): + pts = pts_filepath.read_text() + pts_lines = pts.split('\n') + + lmrk_lines = None + for pts_line in pts_lines: + if pts_line == '{': + lmrk_lines = [] + elif pts_line == '}': + break + else: + if lmrk_lines is not None: + lmrk_lines.append (pts_line) + + if lmrk_lines is not None and len(lmrk_lines) == 68: + try: + lmrks = [ np.array ( lmrk_line.strip().split(' ') ).astype(np.float32).tolist() for lmrk_line in lmrk_lines] + except Exception as e: + print(e) + print(filepath) + continue + + rect = LandmarksProcessor.get_rect_from_landmarks(lmrks) + + output_filepath = output_path / (filepath.stem+'.jpg') + + img = cv2_imread(filepath) + img = imagelib.normalize_channels(img, 3) + cv2_imwrite(output_filepath, img, [int(cv2.IMWRITE_JPEG_QUALITY), 95] ) + + raise Exception("unimplemented") + #DFLJPG.x(output_filepath, face_type=FaceType.toString(FaceType.MARK_ONLY), + # landmarks=lmrks, + # source_filename=filepath.name, + # source_rect=rect, + # source_landmarks=lmrks + # ) + + io.log_info("Done.") + +#unused in end user workflow +def extract_umd_csv(input_file_csv, + face_type='full_face', + device_args={} ): + + #extract faces from umdfaces.io dataset csv file with pitch,yaw,roll info. + multi_gpu = device_args.get('multi_gpu', False) + cpu_only = device_args.get('cpu_only', False) + face_type = FaceType.fromString(face_type) + + input_file_csv_path = Path(input_file_csv) + if not input_file_csv_path.exists(): + raise ValueError('input_file_csv not found. Please ensure it exists.') + + input_file_csv_root_path = input_file_csv_path.parent + output_path = input_file_csv_path.parent / ('aligned_' + input_file_csv_path.name) + + io.log_info("Output dir is %s." % (str(output_path)) ) + + if output_path.exists(): + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) > 0: + io.input_bool("WARNING !!! \n %s contains files! \n They will be deleted. \n Press enter to continue." % (str(output_path)), False ) + for filename in output_images_paths: + Path(filename).unlink() + else: + output_path.mkdir(parents=True, exist_ok=True) + + try: + with open( str(input_file_csv_path), 'r') as f: + csv_file = f.read() + except Exception as e: + io.log_err("Unable to open or read file " + str(input_file_csv_path) + ": " + str(e) ) + return + + strings = csv_file.split('\n') + keys = strings[0].split(',') + keys_len = len(keys) + csv_data = [] + for i in range(1, len(strings)): + values = strings[i].split(',') + if keys_len != len(values): + io.log_err("Wrong string in csv file, skipping.") + continue + + csv_data += [ { keys[n] : values[n] for n in range(keys_len) } ] + + data = [] + for d in csv_data: + filename = input_file_csv_root_path / d['FILE'] + + + x,y,w,h = float(d['FACE_X']), float(d['FACE_Y']), float(d['FACE_WIDTH']), float(d['FACE_HEIGHT']) + + data += [ ExtractSubprocessor.Data(filename=filename, rects=[ [x,y,x+w,y+h] ]) ] + + images_found = len(data) + faces_detected = 0 + if len(data) > 0: + io.log_info ("Performing 2nd pass from csv file...") + data = ExtractSubprocessor (data, 'landmarks', multi_gpu=multi_gpu, cpu_only=cpu_only).run() + + io.log_info ('Performing 3rd pass...') + data = ExtractSubprocessor (data, 'final', face_type, None, multi_gpu=multi_gpu, cpu_only=cpu_only, manual=False, final_output_path=output_path).run() + faces_detected += sum([d.faces_detected for d in data]) + + + io.log_info ('-------------------------') + io.log_info ('Images found: %d' % (images_found) ) + io.log_info ('Faces detected: %d' % (faces_detected) ) + io.log_info ('-------------------------') + + + +def dev_test1(input_dir): + # LaPa dataset + + image_size = 1024 + face_type = FaceType.HEAD + + input_path = Path(input_dir) + images_path = input_path / 'images' + if not images_path.exists: + raise ValueError('LaPa dataset: images folder not found.') + labels_path = input_path / 'labels' + if not labels_path.exists: + raise ValueError('LaPa dataset: labels folder not found.') + landmarks_path = input_path / 'landmarks' + if not landmarks_path.exists: + raise ValueError('LaPa dataset: landmarks folder not found.') + + output_path = input_path / 'out' + if output_path.exists(): + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) != 0: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") + for filename in output_images_paths: + Path(filename).unlink() + output_path.mkdir(parents=True, exist_ok=True) + + data = [] + + img_paths = pathex.get_image_paths (images_path) + for filename in img_paths: + filepath = Path(filename) + + landmark_filepath = landmarks_path / (filepath.stem + '.txt') + if not landmark_filepath.exists(): + raise ValueError(f'no landmarks for {filepath}') + + #img = cv2_imread(filepath) + + lm = landmark_filepath.read_text() + lm = lm.split('\n') + if int(lm[0]) != 106: + raise ValueError(f'wrong landmarks format in {landmark_filepath}') + + lmrks = [] + for i in range(106): + x,y = lm[i+1].split(' ') + x,y = float(x), float(y) + lmrks.append ( (x,y) ) + + lmrks = np.array(lmrks) + + l,t = np.min(lmrks, 0) + r,b = np.max(lmrks, 0) + + l,t,r,b = ( int(x) for x in (l,t,r,b) ) + + #for x, y in lmrks: + # x,y = int(x), int(y) + # cv2.circle(img, (x, y), 1, (0,255,0) , 1, lineType=cv2.LINE_AA) + + #imagelib.draw_rect(img, (l,t,r,b), (0,255,0) ) + + + data += [ ExtractSubprocessor.Data(filepath=filepath, rects=[ (l,t,r,b) ]) ] + + #cv2.imshow("", img) + #cv2.waitKey(0) + + if len(data) > 0: + device_config = nn.DeviceConfig.BestGPU() + + io.log_info ("Performing 2nd pass...") + data = ExtractSubprocessor (data, 'landmarks', image_size, 95, face_type, device_config=device_config).run() + io.log_info ("Performing 3rd pass...") + data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=device_config).run() + + + for filename in pathex.get_image_paths (output_path): + filepath = Path(filename) + + + dflimg = DFLJPG.load(filepath) + + src_filename = dflimg.get_source_filename() + image_to_face_mat = dflimg.get_image_to_face_mat() + + label_filepath = labels_path / ( Path(src_filename).stem + '.png') + if not label_filepath.exists(): + raise ValueError(f'{label_filepath} does not exist') + + mask = cv2_imread(label_filepath) + #mask[mask == 10] = 0 # remove hair + mask[mask > 0] = 1 + mask = cv2.warpAffine(mask, image_to_face_mat, (image_size, image_size), cv2.INTER_LINEAR) + mask = cv2.blur(mask, (3,3) ) + + #cv2.imshow("", (mask*255).astype(np.uint8) ) + #cv2.waitKey(0) + + dflimg.set_xseg_mask(mask) + dflimg.save() + + + import code + code.interact(local=dict(globals(), **locals())) + + +def dev_resave_pngs(input_dir): + input_path = Path(input_dir) + if not input_path.exists(): + raise ValueError('input_dir not found. Please ensure it exists.') + + images_paths = pathex.get_image_paths(input_path, image_extensions=['.png'], subdirs=True, return_Path_class=True) + + for filepath in io.progress_bar_generator(images_paths,"Processing"): + cv2_imwrite(filepath, cv2_imread(filepath)) + + +def dev_segmented_trash(input_dir): + input_path = Path(input_dir) + if not input_path.exists(): + raise ValueError('input_dir not found. Please ensure it exists.') + + output_path = input_path.parent / (input_path.name+'_trash') + output_path.mkdir(parents=True, exist_ok=True) + + images_paths = pathex.get_image_paths(input_path, return_Path_class=True) + + trash_paths = [] + for filepath in images_paths: + json_file = filepath.parent / (filepath.stem +'.json') + if not json_file.exists(): + trash_paths.append(filepath) + + for filepath in trash_paths: + + try: + filepath.rename ( output_path / filepath.name ) + except: + io.log_info ('fail to trashing %s' % (src.name) ) + + + +def dev_test(input_dir): + """ + extract FaceSynthetics dataset https://github.com/microsoft/FaceSynthetics + + BACKGROUND = 0 + SKIN = 1 + NOSE = 2 + RIGHT_EYE = 3 + LEFT_EYE = 4 + RIGHT_BROW = 5 + LEFT_BROW = 6 + RIGHT_EAR = 7 + LEFT_EAR = 8 + MOUTH_INTERIOR = 9 + TOP_LIP = 10 + BOTTOM_LIP = 11 + NECK = 12 + HAIR = 13 + BEARD = 14 + CLOTHING = 15 + GLASSES = 16 + HEADWEAR = 17 + FACEWEAR = 18 + IGNORE = 255 + """ + + + image_size = 1024 + face_type = FaceType.WHOLE_FACE + + input_path = Path(input_dir) + + + + output_path = input_path.parent / f'{input_path.name}_out' + if output_path.exists(): + output_images_paths = pathex.get_image_paths(output_path) + if len(output_images_paths) != 0: + io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n") + for filename in output_images_paths: + Path(filename).unlink() + output_path.mkdir(parents=True, exist_ok=True) + + data = [] + + for filepath in io.progress_bar_generator(pathex.get_paths(input_path), "Processing"): + if filepath.suffix == '.txt': + + image_filepath = filepath.parent / f'{filepath.name.split("_")[0]}.png' + if not image_filepath.exists(): + print(f'{image_filepath} does not exist, skipping') + + lmrks = [] + for lmrk_line in filepath.read_text().split('\n'): + if len(lmrk_line) == 0: + continue + + x, y = lmrk_line.split(' ') + x, y = float(x), float(y) + + lmrks.append( (x,y) ) + + lmrks = np.array(lmrks[:68], np.float32) + rect = LandmarksProcessor.get_rect_from_landmarks(lmrks) + data += [ ExtractSubprocessor.Data(filepath=image_filepath, rects=[rect], landmarks=[ lmrks ] ) ] + + if len(data) > 0: + io.log_info ("Performing 3rd pass...") + data = ExtractSubprocessor (data, 'final', image_size, 95, face_type, final_output_path=output_path, device_config=nn.DeviceConfig.CPU()).run() + + for filename in io.progress_bar_generator(pathex.get_image_paths (output_path), "Processing"): + filepath = Path(filename) + + dflimg = DFLJPG.load(filepath) + + src_filename = dflimg.get_source_filename() + image_to_face_mat = dflimg.get_image_to_face_mat() + + seg_filepath = input_path / ( Path(src_filename).stem + '_seg.png') + if not seg_filepath.exists(): + raise ValueError(f'{seg_filepath} does not exist') + + seg = cv2_imread(seg_filepath) + seg_inds = np.isin(seg, [1,2,3,4,5,6,9,10,11]) + seg[~seg_inds] = 0 + seg[seg_inds] = 1 + seg = seg.astype(np.float32) + seg = cv2.warpAffine(seg, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4) + dflimg.set_xseg_mask(seg) + dflimg.save() + \ No newline at end of file diff --git a/merger/FrameInfo.py b/merger/FrameInfo.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8ebb0e62ed4c0db5b139c913dab2bd4f8852c9 --- /dev/null +++ b/merger/FrameInfo.py @@ -0,0 +1,8 @@ +from pathlib import Path + +class FrameInfo(object): + def __init__(self, filepath=None, landmarks_list=None): + self.filepath = filepath + self.landmarks_list = landmarks_list or [] + self.motion_deg = 0 + self.motion_power = 0 \ No newline at end of file diff --git a/merger/InteractiveMergerSubprocessor.py b/merger/InteractiveMergerSubprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..58db0c1fc1edbde16d5e371d385f1ad60db68b43 --- /dev/null +++ b/merger/InteractiveMergerSubprocessor.py @@ -0,0 +1,574 @@ +import multiprocessing +import os +import pickle +import sys +import traceback +from pathlib import Path + +import numpy as np + +from core import imagelib, pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import Subprocessor +from merger import MergeFaceAvatar, MergeMasked, MergerConfig + +from .MergerScreen import Screen, ScreenManager + +MERGER_DEBUG = False +class InteractiveMergerSubprocessor(Subprocessor): + + class Frame(object): + def __init__(self, prev_temporal_frame_infos=None, + frame_info=None, + next_temporal_frame_infos=None): + self.prev_temporal_frame_infos = prev_temporal_frame_infos + self.frame_info = frame_info + self.next_temporal_frame_infos = next_temporal_frame_infos + self.output_filepath = None + self.output_mask_filepath = None + + self.idx = None + self.cfg = None + self.is_done = False + self.is_processing = False + self.is_shown = False + self.image = None + + class ProcessingFrame(object): + def __init__(self, idx=None, + cfg=None, + prev_temporal_frame_infos=None, + frame_info=None, + next_temporal_frame_infos=None, + output_filepath=None, + output_mask_filepath=None, + need_return_image = False): + self.idx = idx + self.cfg = cfg + self.prev_temporal_frame_infos = prev_temporal_frame_infos + self.frame_info = frame_info + self.next_temporal_frame_infos = next_temporal_frame_infos + self.output_filepath = output_filepath + self.output_mask_filepath = output_mask_filepath + + self.need_return_image = need_return_image + if self.need_return_image: + self.image = None + + class Cli(Subprocessor.Cli): + + #override + def on_initialize(self, client_dict): + self.log_info ('Running on %s.' % (client_dict['device_name']) ) + self.device_idx = client_dict['device_idx'] + self.device_name = client_dict['device_name'] + self.predictor_func = client_dict['predictor_func'] + self.predictor_input_shape = client_dict['predictor_input_shape'] + self.face_enhancer_func = client_dict['face_enhancer_func'] + self.xseg_256_extract_func = client_dict['xseg_256_extract_func'] + + + #transfer and set stdin in order to work code.interact in debug subprocess + stdin_fd = client_dict['stdin_fd'] + if stdin_fd is not None: + sys.stdin = os.fdopen(stdin_fd) + + return None + + #override + def process_data(self, pf): #pf=ProcessingFrame + cfg = pf.cfg.copy() + + frame_info = pf.frame_info + filepath = frame_info.filepath + + if len(frame_info.landmarks_list) == 0: + + if cfg.mode == 'raw-predict': + h,w,c = self.predictor_input_shape + img_bgr = np.zeros( (h,w,3), dtype=np.uint8) + img_mask = np.zeros( (h,w,1), dtype=np.uint8) + else: + self.log_info (f'no faces found for {filepath.name}, copying without faces') + img_bgr = cv2_imread(filepath) + imagelib.normalize_channels(img_bgr, 3) + h,w,c = img_bgr.shape + img_mask = np.zeros( (h,w,1), dtype=img_bgr.dtype) + + cv2_imwrite (pf.output_filepath, img_bgr) + cv2_imwrite (pf.output_mask_filepath, img_mask) + + if pf.need_return_image: + pf.image = np.concatenate ([img_bgr, img_mask], axis=-1) + + else: + if cfg.type == MergerConfig.TYPE_MASKED: + try: + final_img = MergeMasked (self.predictor_func, self.predictor_input_shape, + face_enhancer_func=self.face_enhancer_func, + xseg_256_extract_func=self.xseg_256_extract_func, + cfg=cfg, + frame_info=frame_info) + except Exception as e: + e_str = traceback.format_exc() + if 'MemoryError' in e_str: + raise Subprocessor.SilenceException + else: + raise Exception( f'Error while merging file [{filepath}]: {e_str}' ) + + elif cfg.type == MergerConfig.TYPE_FACE_AVATAR: + final_img = MergeFaceAvatar (self.predictor_func, self.predictor_input_shape, + cfg, pf.prev_temporal_frame_infos, + pf.frame_info, + pf.next_temporal_frame_infos ) + + cv2_imwrite (pf.output_filepath, final_img[...,0:3] ) + cv2_imwrite (pf.output_mask_filepath, final_img[...,3:4] ) + + if pf.need_return_image: + pf.image = final_img + + return pf + + #overridable + def get_data_name (self, pf): + #return string identificator of your data + return pf.frame_info.filepath + + + + + #override + def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4): + if len (frames) == 0: + raise ValueError ("len (frames) == 0") + + super().__init__('Merger', InteractiveMergerSubprocessor.Cli, io_loop_sleep_time=0.001) + + self.is_interactive = is_interactive + self.merger_session_filepath = Path(merger_session_filepath) + self.merger_config = merger_config + + self.predictor_func = predictor_func + self.predictor_input_shape = predictor_input_shape + + self.face_enhancer_func = face_enhancer_func + self.xseg_256_extract_func = xseg_256_extract_func + + self.frames_root_path = frames_root_path + self.output_path = output_path + self.output_mask_path = output_mask_path + self.model_iter = model_iter + + self.prefetch_frame_count = self.process_count = subprocess_count + + session_data = None + if self.is_interactive and self.merger_session_filepath.exists(): + io.input_skip_pending() + if io.input_bool ("Use saved session?", True): + try: + with open( str(self.merger_session_filepath), "rb") as f: + session_data = pickle.loads(f.read()) + + except Exception as e: + pass + + rewind_to_frame_idx = None + self.frames = frames + self.frames_idxs = [ *range(len(self.frames)) ] + self.frames_done_idxs = [] + + if self.is_interactive and session_data is not None: + # Loaded session data, check it + s_frames = session_data.get('frames', None) + s_frames_idxs = session_data.get('frames_idxs', None) + s_frames_done_idxs = session_data.get('frames_done_idxs', None) + s_model_iter = session_data.get('model_iter', None) + + frames_equal = (s_frames is not None) and \ + (s_frames_idxs is not None) and \ + (s_frames_done_idxs is not None) and \ + (s_model_iter is not None) and \ + (len(frames) == len(s_frames)) # frames count must match + + if frames_equal: + for i in range(len(frames)): + frame = frames[i] + s_frame = s_frames[i] + # frames filenames must match + if frame.frame_info.filepath.name != s_frame.frame_info.filepath.name: + frames_equal = False + if not frames_equal: + break + + if frames_equal: + io.log_info ('Using saved session from ' + '/'.join (self.merger_session_filepath.parts[-2:]) ) + + for frame in s_frames: + if frame.cfg is not None: + # recreate MergerConfig class using constructor with get_config() as dict params + # so if any new param will be added, old merger session will work properly + frame.cfg = frame.cfg.__class__( **frame.cfg.get_config() ) + + self.frames = s_frames + self.frames_idxs = s_frames_idxs + self.frames_done_idxs = s_frames_done_idxs + + if self.model_iter != s_model_iter: + # model was more trained, recompute all frames + rewind_to_frame_idx = -1 + for frame in self.frames: + frame.is_done = False + elif len(self.frames_idxs) == 0: + # all frames are done? + rewind_to_frame_idx = -1 + + if len(self.frames_idxs) != 0: + cur_frame = self.frames[self.frames_idxs[0]] + cur_frame.is_shown = False + + if not frames_equal: + session_data = None + + if session_data is None: + for filename in pathex.get_image_paths(self.output_path): #remove all images in output_path + Path(filename).unlink() + + for filename in pathex.get_image_paths(self.output_mask_path): #remove all images in output_mask_path + Path(filename).unlink() + + + frames[0].cfg = self.merger_config.copy() + + for i in range( len(self.frames) ): + frame = self.frames[i] + frame.idx = i + frame.output_filepath = self.output_path / ( frame.frame_info.filepath.stem + '.png' ) + frame.output_mask_filepath = self.output_mask_path / ( frame.frame_info.filepath.stem + '.png' ) + + if not frame.output_filepath.exists() or \ + not frame.output_mask_filepath.exists(): + # if some frame does not exist, recompute and rewind + frame.is_done = False + frame.is_shown = False + + if rewind_to_frame_idx is None: + rewind_to_frame_idx = i-1 + else: + rewind_to_frame_idx = min(rewind_to_frame_idx, i-1) + + if rewind_to_frame_idx is not None: + while len(self.frames_done_idxs) > 0: + if self.frames_done_idxs[-1] > rewind_to_frame_idx: + prev_frame = self.frames[self.frames_done_idxs.pop()] + self.frames_idxs.insert(0, prev_frame.idx) + else: + break + #override + def process_info_generator(self): + r = [0] if MERGER_DEBUG else range(self.process_count) + + for i in r: + yield 'CPU%d' % (i), {}, {'device_idx': i, + 'device_name': 'CPU%d' % (i), + 'predictor_func': self.predictor_func, + 'predictor_input_shape' : self.predictor_input_shape, + 'face_enhancer_func': self.face_enhancer_func, + 'xseg_256_extract_func' : self.xseg_256_extract_func, + 'stdin_fd': sys.stdin.fileno() if MERGER_DEBUG else None + } + + #overridable optional + def on_clients_initialized(self): + io.progress_bar ("Merging", len(self.frames_idxs)+len(self.frames_done_idxs), initial=len(self.frames_done_idxs) ) + + self.process_remain_frames = not self.is_interactive + self.is_interactive_quitting = not self.is_interactive + + if self.is_interactive: + help_images = { + MergerConfig.TYPE_MASKED : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_masked.jpg') ), + MergerConfig.TYPE_FACE_AVATAR : cv2_imread ( str(Path(__file__).parent / 'gfx' / 'help_merger_face_avatar.jpg') ), + } + + self.main_screen = Screen(initial_scale_to_width=1368, image=None, waiting_icon=True) + self.help_screen = Screen(initial_scale_to_height=768, image=help_images[self.merger_config.type], waiting_icon=False) + self.screen_manager = ScreenManager( "Merger", [self.main_screen, self.help_screen], capture_keys=True ) + self.screen_manager.set_current (self.help_screen) + self.screen_manager.show_current() + + self.masked_keys_funcs = { + '`' : lambda cfg,shift_pressed: cfg.set_mode(0), + '1' : lambda cfg,shift_pressed: cfg.set_mode(1), + '2' : lambda cfg,shift_pressed: cfg.set_mode(2), + '3' : lambda cfg,shift_pressed: cfg.set_mode(3), + '4' : lambda cfg,shift_pressed: cfg.set_mode(4), + '5' : lambda cfg,shift_pressed: cfg.set_mode(5), + '6' : lambda cfg,shift_pressed: cfg.set_mode(6), + 'q' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(1 if not shift_pressed else 5), + 'a' : lambda cfg,shift_pressed: cfg.add_hist_match_threshold(-1 if not shift_pressed else -5), + 'w' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(1 if not shift_pressed else 5), + 's' : lambda cfg,shift_pressed: cfg.add_erode_mask_modifier(-1 if not shift_pressed else -5), + 'e' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(1 if not shift_pressed else 5), + 'd' : lambda cfg,shift_pressed: cfg.add_blur_mask_modifier(-1 if not shift_pressed else -5), + 'r' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(1 if not shift_pressed else 5), + 'f' : lambda cfg,shift_pressed: cfg.add_motion_blur_power(-1 if not shift_pressed else -5), + 't' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(1 if not shift_pressed else 5), + 'g' : lambda cfg,shift_pressed: cfg.add_super_resolution_power(-1 if not shift_pressed else -5), + 'y' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(1 if not shift_pressed else 5), + 'h' : lambda cfg,shift_pressed: cfg.add_blursharpen_amount(-1 if not shift_pressed else -5), + 'u' : lambda cfg,shift_pressed: cfg.add_output_face_scale(1 if not shift_pressed else 5), + 'j' : lambda cfg,shift_pressed: cfg.add_output_face_scale(-1 if not shift_pressed else -5), + 'i' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(1 if not shift_pressed else 5), + 'k' : lambda cfg,shift_pressed: cfg.add_image_denoise_power(-1 if not shift_pressed else -5), + 'o' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(1 if not shift_pressed else 5), + 'l' : lambda cfg,shift_pressed: cfg.add_bicubic_degrade_power(-1 if not shift_pressed else -5), + 'p' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(1 if not shift_pressed else 5), + ';' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-1), + ':' : lambda cfg,shift_pressed: cfg.add_color_degrade_power(-5), + 'z' : lambda cfg,shift_pressed: cfg.toggle_masked_hist_match(), + 'x' : lambda cfg,shift_pressed: cfg.toggle_mask_mode(), + 'c' : lambda cfg,shift_pressed: cfg.toggle_color_transfer_mode(), + 'n' : lambda cfg,shift_pressed: cfg.toggle_sharpen_mode(), + } + self.masked_keys = list(self.masked_keys_funcs.keys()) + + #overridable optional + def on_clients_finalized(self): + io.progress_bar_close() + + if self.is_interactive: + self.screen_manager.finalize() + + for frame in self.frames: + frame.output_filepath = None + frame.output_mask_filepath = None + frame.image = None + + session_data = { + 'frames': self.frames, + 'frames_idxs': self.frames_idxs, + 'frames_done_idxs': self.frames_done_idxs, + 'model_iter' : self.model_iter, + } + self.merger_session_filepath.write_bytes( pickle.dumps(session_data) ) + + io.log_info ("Session is saved to " + '/'.join (self.merger_session_filepath.parts[-2:]) ) + + #override + def on_tick(self): + io.process_messages() + + go_prev_frame = False + go_first_frame = False + go_prev_frame_overriding_cfg = False + go_first_frame_overriding_cfg = False + + go_next_frame = self.process_remain_frames + go_next_frame_overriding_cfg = False + go_last_frame_overriding_cfg = False + + cur_frame = None + if len(self.frames_idxs) != 0: + cur_frame = self.frames[self.frames_idxs[0]] + + if self.is_interactive: + + screen_image = None if self.process_remain_frames else \ + self.main_screen.get_image() + + self.main_screen.set_waiting_icon( self.process_remain_frames or \ + self.is_interactive_quitting ) + + if cur_frame is not None and not self.is_interactive_quitting: + + if not self.process_remain_frames: + if cur_frame.is_done: + if not cur_frame.is_shown: + if cur_frame.image is None: + image = cv2_imread (cur_frame.output_filepath, verbose=False) + image_mask = cv2_imread (cur_frame.output_mask_filepath, verbose=False) + if image is None or image_mask is None: + # unable to read? recompute then + cur_frame.is_done = False + else: + image = imagelib.normalize_channels(image, 3) + image_mask = imagelib.normalize_channels(image_mask, 1) + cur_frame.image = np.concatenate([image, image_mask], -1) + + if cur_frame.is_done: + io.log_info (cur_frame.cfg.to_string( cur_frame.frame_info.filepath.name) ) + cur_frame.is_shown = True + screen_image = cur_frame.image + else: + self.main_screen.set_waiting_icon(True) + + self.main_screen.set_image(screen_image) + self.screen_manager.show_current() + + key_events = self.screen_manager.get_key_events() + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + + if key == 9: #tab + self.screen_manager.switch_screens() + else: + if key == 27: #esc + self.is_interactive_quitting = True + elif self.screen_manager.get_current() is self.main_screen: + + if self.merger_config.type == MergerConfig.TYPE_MASKED and chr_key in self.masked_keys: + self.process_remain_frames = False + + if cur_frame is not None: + cfg = cur_frame.cfg + prev_cfg = cfg.copy() + + if cfg.type == MergerConfig.TYPE_MASKED: + self.masked_keys_funcs[chr_key](cfg, shift_pressed) + + if prev_cfg != cfg: + io.log_info ( cfg.to_string(cur_frame.frame_info.filepath.name) ) + cur_frame.is_done = False + cur_frame.is_shown = False + else: + + if chr_key == ',' or chr_key == 'm': + self.process_remain_frames = False + go_prev_frame = True + + if chr_key == ',': + if shift_pressed: + go_first_frame = True + + elif chr_key == 'm': + if not shift_pressed: + go_prev_frame_overriding_cfg = True + else: + go_first_frame_overriding_cfg = True + + elif chr_key == '.' or chr_key == '/': + self.process_remain_frames = False + go_next_frame = True + + if chr_key == '.': + if shift_pressed: + self.process_remain_frames = not self.process_remain_frames + + elif chr_key == '/': + if not shift_pressed: + go_next_frame_overriding_cfg = True + else: + go_last_frame_overriding_cfg = True + + elif chr_key == '-': + self.screen_manager.get_current().diff_scale(-0.1) + elif chr_key == '=': + self.screen_manager.get_current().diff_scale(0.1) + elif chr_key == 'v': + self.screen_manager.get_current().toggle_show_checker_board() + + if go_prev_frame: + if cur_frame is None or cur_frame.is_done: + if cur_frame is not None: + cur_frame.image = None + + while True: + if len(self.frames_done_idxs) > 0: + prev_frame = self.frames[self.frames_done_idxs.pop()] + self.frames_idxs.insert(0, prev_frame.idx) + prev_frame.is_shown = False + io.progress_bar_inc(-1) + + if cur_frame is not None and (go_prev_frame_overriding_cfg or go_first_frame_overriding_cfg): + if prev_frame.cfg != cur_frame.cfg: + prev_frame.cfg = cur_frame.cfg.copy() + prev_frame.is_done = False + + cur_frame = prev_frame + + if go_first_frame_overriding_cfg or go_first_frame: + if len(self.frames_done_idxs) > 0: + continue + break + + elif go_next_frame: + if cur_frame is not None and cur_frame.is_done: + cur_frame.image = None + cur_frame.is_shown = True + self.frames_done_idxs.append(cur_frame.idx) + self.frames_idxs.pop(0) + io.progress_bar_inc(1) + + f = self.frames + + if len(self.frames_idxs) != 0: + next_frame = f[ self.frames_idxs[0] ] + next_frame.is_shown = False + + if go_next_frame_overriding_cfg or go_last_frame_overriding_cfg: + + if go_next_frame_overriding_cfg: + to_frames = next_frame.idx+1 + else: + to_frames = len(f) + + for i in range( next_frame.idx, to_frames ): + f[i].cfg = None + + for i in range( min(len(self.frames_idxs), self.prefetch_frame_count) ): + frame = f[ self.frames_idxs[i] ] + if frame.cfg is None: + if i == 0: + frame.cfg = cur_frame.cfg.copy() + else: + frame.cfg = f[ self.frames_idxs[i-1] ].cfg.copy() + + frame.is_done = False #initiate solve again + frame.is_shown = False + + if len(self.frames_idxs) == 0: + self.process_remain_frames = False + + return (self.is_interactive and self.is_interactive_quitting) or \ + (not self.is_interactive and self.process_remain_frames == False) + + + #override + def on_data_return (self, host_dict, pf): + frame = self.frames[pf.idx] + frame.is_done = False + frame.is_processing = False + + #override + def on_result (self, host_dict, pf_sent, pf_result): + frame = self.frames[pf_result.idx] + frame.is_processing = False + if frame.cfg == pf_result.cfg: + frame.is_done = True + frame.image = pf_result.image + + #override + def get_data(self, host_dict): + if self.is_interactive and self.is_interactive_quitting: + return None + + for i in range ( min(len(self.frames_idxs), self.prefetch_frame_count) ): + frame = self.frames[ self.frames_idxs[i] ] + + if not frame.is_done and not frame.is_processing and frame.cfg is not None: + frame.is_processing = True + return InteractiveMergerSubprocessor.ProcessingFrame(idx=frame.idx, + cfg=frame.cfg.copy(), + prev_temporal_frame_infos=frame.prev_temporal_frame_infos, + frame_info=frame.frame_info, + next_temporal_frame_infos=frame.next_temporal_frame_infos, + output_filepath=frame.output_filepath, + output_mask_filepath=frame.output_mask_filepath, + need_return_image=True ) + + return None + + #override + def get_result(self): + return 0 \ No newline at end of file diff --git a/merger/MergeAvatar.py b/merger/MergeAvatar.py new file mode 100644 index 0000000000000000000000000000000000000000..cc59d2394e55d5e488a35d5c77a5ee9c53047a71 --- /dev/null +++ b/merger/MergeAvatar.py @@ -0,0 +1,41 @@ +import cv2 +import numpy as np + +from core import imagelib +from facelib import FaceType, LandmarksProcessor +from core.cv2ex import * + +def process_frame_info(frame_info, inp_sh): + img_uint8 = cv2_imread (frame_info.filename) + img_uint8 = imagelib.normalize_channels (img_uint8, 3) + img = img_uint8.astype(np.float32) / 255.0 + + img_mat = LandmarksProcessor.get_transform_mat (frame_info.landmarks_list[0], inp_sh[0], face_type=FaceType.FULL_NO_ALIGN) + img = cv2.warpAffine( img, img_mat, inp_sh[0:2], borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + return img + +def MergeFaceAvatar (predictor_func, predictor_input_shape, cfg, prev_temporal_frame_infos, frame_info, next_temporal_frame_infos): + inp_sh = predictor_input_shape + + prev_imgs=[] + next_imgs=[] + for i in range(cfg.temporal_face_count): + prev_imgs.append( process_frame_info(prev_temporal_frame_infos[i], inp_sh) ) + next_imgs.append( process_frame_info(next_temporal_frame_infos[i], inp_sh) ) + img = process_frame_info(frame_info, inp_sh) + + prd_f = predictor_func ( prev_imgs, img, next_imgs ) + + #if cfg.super_resolution_mode != 0: + # prd_f = cfg.superres_func(cfg.super_resolution_mode, prd_f) + + if cfg.sharpen_mode != 0 and cfg.sharpen_amount != 0: + prd_f = cfg.sharpen_func ( prd_f, cfg.sharpen_mode, 3, cfg.sharpen_amount) + + out_img = np.clip(prd_f, 0.0, 1.0) + + if cfg.add_source_image: + out_img = np.concatenate ( [cv2.resize ( img, (prd_f.shape[1], prd_f.shape[0]) ), + out_img], axis=1 ) + + return (out_img*255).astype(np.uint8) diff --git a/merger/MergeMasked.py b/merger/MergeMasked.py new file mode 100644 index 0000000000000000000000000000000000000000..0a5c633aa631a4ab4b5fff567be9ad5d30339a18 --- /dev/null +++ b/merger/MergeMasked.py @@ -0,0 +1,348 @@ +import sys +import traceback + +import cv2 +import numpy as np + +from core import imagelib +from core.cv2ex import * +from core.interact import interact as io +from facelib import FaceType, LandmarksProcessor + +is_windows = sys.platform[0:3] == 'win' +xseg_input_size = 256 + +def MergeMaskedFace (predictor_func, predictor_input_shape, + face_enhancer_func, + xseg_256_extract_func, + cfg, frame_info, img_bgr_uint8, img_bgr, img_face_landmarks): + + img_size = img_bgr.shape[1], img_bgr.shape[0] + img_face_mask_a = LandmarksProcessor.get_image_hull_mask (img_bgr.shape, img_face_landmarks) + + input_size = predictor_input_shape[0] + mask_subres_size = input_size*4 + output_size = input_size + if cfg.super_resolution_power != 0: + output_size *= 4 + + face_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type) + face_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, output_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) + + if mask_subres_size == output_size: + face_mask_output_mat = face_output_mat + else: + face_mask_output_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, mask_subres_size, face_type=cfg.face_type, scale= 1.0 + 0.01*cfg.output_face_scale) + + dst_face_bgr = cv2.warpAffine( img_bgr , face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) + dst_face_bgr = np.clip(dst_face_bgr, 0, 1) + + dst_face_mask_a_0 = cv2.warpAffine( img_face_mask_a, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) + dst_face_mask_a_0 = np.clip(dst_face_mask_a_0, 0, 1) + + predictor_input_bgr = cv2.resize (dst_face_bgr, (input_size,input_size) ) + + predicted = predictor_func (predictor_input_bgr) + prd_face_bgr = np.clip (predicted[0], 0, 1.0) + prd_face_mask_a_0 = np.clip (predicted[1], 0, 1.0) + prd_face_dst_mask_a_0 = np.clip (predicted[2], 0, 1.0) + + if cfg.super_resolution_power != 0: + prd_face_bgr_enhanced = face_enhancer_func(prd_face_bgr, is_tanh=True, preserve_size=False) + mod = cfg.super_resolution_power / 100.0 + prd_face_bgr = cv2.resize(prd_face_bgr, (output_size,output_size))*(1.0-mod) + prd_face_bgr_enhanced*mod + prd_face_bgr = np.clip(prd_face_bgr, 0, 1) + + if cfg.super_resolution_power != 0: + prd_face_mask_a_0 = cv2.resize (prd_face_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) + prd_face_dst_mask_a_0 = cv2.resize (prd_face_dst_mask_a_0, (output_size, output_size), interpolation=cv2.INTER_CUBIC) + + if cfg.mask_mode == 0: #full + wrk_face_mask_a_0 = np.ones_like(dst_face_mask_a_0) + elif cfg.mask_mode == 1: #dst + wrk_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC) + elif cfg.mask_mode == 2: #learned-prd + wrk_face_mask_a_0 = prd_face_mask_a_0 + elif cfg.mask_mode == 3: #learned-dst + wrk_face_mask_a_0 = prd_face_dst_mask_a_0 + elif cfg.mask_mode == 4: #learned-prd*learned-dst + wrk_face_mask_a_0 = prd_face_mask_a_0*prd_face_dst_mask_a_0 + elif cfg.mask_mode == 5: #learned-prd+learned-dst + wrk_face_mask_a_0 = np.clip( prd_face_mask_a_0+prd_face_dst_mask_a_0, 0, 1) + elif cfg.mask_mode >= 6 and cfg.mask_mode <= 9: #XSeg modes + if cfg.mask_mode == 6 or cfg.mask_mode == 8 or cfg.mask_mode == 9: + # obtain XSeg-prd + prd_face_xseg_bgr = cv2.resize (prd_face_bgr, (xseg_input_size,)*2, interpolation=cv2.INTER_CUBIC) + prd_face_xseg_mask = xseg_256_extract_func(prd_face_xseg_bgr) + X_prd_face_mask_a_0 = cv2.resize ( prd_face_xseg_mask, (output_size, output_size), interpolation=cv2.INTER_CUBIC) + + if cfg.mask_mode >= 7 and cfg.mask_mode <= 9: + # obtain XSeg-dst + xseg_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, xseg_input_size, face_type=cfg.face_type) + dst_face_xseg_bgr = cv2.warpAffine(img_bgr, xseg_mat, (xseg_input_size,)*2, flags=cv2.INTER_CUBIC ) + dst_face_xseg_mask = xseg_256_extract_func(dst_face_xseg_bgr) + X_dst_face_mask_a_0 = cv2.resize (dst_face_xseg_mask, (output_size,output_size), interpolation=cv2.INTER_CUBIC) + + if cfg.mask_mode == 6: #'XSeg-prd' + wrk_face_mask_a_0 = X_prd_face_mask_a_0 + elif cfg.mask_mode == 7: #'XSeg-dst' + wrk_face_mask_a_0 = X_dst_face_mask_a_0 + elif cfg.mask_mode == 8: #'XSeg-prd*XSeg-dst' + wrk_face_mask_a_0 = X_prd_face_mask_a_0 * X_dst_face_mask_a_0 + elif cfg.mask_mode == 9: #learned-prd*learned-dst*XSeg-prd*XSeg-dst + wrk_face_mask_a_0 = prd_face_mask_a_0 * prd_face_dst_mask_a_0 * X_prd_face_mask_a_0 * X_dst_face_mask_a_0 + + wrk_face_mask_a_0[ wrk_face_mask_a_0 < (1.0/255.0) ] = 0.0 # get rid of noise + + # resize to mask_subres_size + if wrk_face_mask_a_0.shape[0] != mask_subres_size: + wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (mask_subres_size, mask_subres_size), interpolation=cv2.INTER_CUBIC) + + # process mask in local predicted space + if 'raw' not in cfg.mode: + # add zero pad + wrk_face_mask_a_0 = np.pad (wrk_face_mask_a_0, input_size) + + ero = cfg.erode_mask_modifier + blur = cfg.blur_mask_modifier + + if ero > 0: + wrk_face_mask_a_0 = cv2.erode(wrk_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(ero,ero)), iterations = 1 ) + elif ero < 0: + wrk_face_mask_a_0 = cv2.dilate(wrk_face_mask_a_0, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(-ero,-ero)), iterations = 1 ) + + # clip eroded/dilated mask in actual predict area + # pad with half blur size in order to accuratelly fade to zero at the boundary + clip_size = input_size + blur // 2 + + wrk_face_mask_a_0[:clip_size,:] = 0 + wrk_face_mask_a_0[-clip_size:,:] = 0 + wrk_face_mask_a_0[:,:clip_size] = 0 + wrk_face_mask_a_0[:,-clip_size:] = 0 + + if blur > 0: + blur = blur + (1-blur % 2) + wrk_face_mask_a_0 = cv2.GaussianBlur(wrk_face_mask_a_0, (blur, blur) , 0) + + wrk_face_mask_a_0 = wrk_face_mask_a_0[input_size:-input_size,input_size:-input_size] + + wrk_face_mask_a_0 = np.clip(wrk_face_mask_a_0, 0, 1) + + img_face_mask_a = cv2.warpAffine( wrk_face_mask_a_0, face_mask_output_mat, img_size, np.zeros(img_bgr.shape[0:2], dtype=np.float32), flags=cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC )[...,None] + img_face_mask_a = np.clip (img_face_mask_a, 0.0, 1.0) + img_face_mask_a [ img_face_mask_a < (1.0/255.0) ] = 0.0 # get rid of noise + + if wrk_face_mask_a_0.shape[0] != output_size: + wrk_face_mask_a_0 = cv2.resize (wrk_face_mask_a_0, (output_size,output_size), interpolation=cv2.INTER_CUBIC) + + wrk_face_mask_a = wrk_face_mask_a_0[...,None] + + out_img = None + out_merging_mask_a = None + if cfg.mode == 'original': + return img_bgr, img_face_mask_a + + elif 'raw' in cfg.mode: + if cfg.mode == 'raw-rgb': + out_img_face = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC) + out_img_face_mask = cv2.warpAffine( np.ones_like(prd_face_bgr), face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC) + out_img = img_bgr*(1-out_img_face_mask) + out_img_face*out_img_face_mask + out_merging_mask_a = img_face_mask_a + elif cfg.mode == 'raw-predict': + out_img = prd_face_bgr + out_merging_mask_a = wrk_face_mask_a + else: + raise ValueError(f"undefined raw type {cfg.mode}") + + out_img = np.clip (out_img, 0.0, 1.0 ) + else: + + # Process if the mask meets minimum size + maxregion = np.argwhere( img_face_mask_a >= 0.1 ) + if maxregion.size != 0: + miny,minx = maxregion.min(axis=0)[:2] + maxy,maxx = maxregion.max(axis=0)[:2] + lenx = maxx - minx + leny = maxy - miny + if min(lenx,leny) >= 4: + wrk_face_mask_area_a = wrk_face_mask_a.copy() + wrk_face_mask_area_a[wrk_face_mask_area_a>0] = 1.0 + + if 'seamless' not in cfg.mode and cfg.color_transfer_mode != 0: + if cfg.color_transfer_mode == 1: #rct + prd_face_bgr = imagelib.reinhard_color_transfer (prd_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 2: #lct + prd_face_bgr = imagelib.linear_color_transfer (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 3: #mkl + prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 4: #mkl-m + prd_face_bgr = imagelib.color_transfer_mkl (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 5: #idt + prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 6: #idt-m + prd_face_bgr = imagelib.color_transfer_idt (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 7: #sot-m + prd_face_bgr = imagelib.color_transfer_sot (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30) + prd_face_bgr = np.clip (prd_face_bgr, 0.0, 1.0) + elif cfg.color_transfer_mode == 8: #mix-m + prd_face_bgr = imagelib.color_transfer_mix (prd_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + + if cfg.mode == 'hist-match': + hist_mask_a = np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) + + if cfg.masked_hist_match: + hist_mask_a *= wrk_face_mask_area_a + + white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32) + + hist_match_1 = prd_face_bgr*hist_mask_a + white + hist_match_1[ hist_match_1 > 1.0 ] = 1.0 + + hist_match_2 = dst_face_bgr*hist_mask_a + white + hist_match_2[ hist_match_1 > 1.0 ] = 1.0 + + prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, cfg.hist_match_threshold ).astype(dtype=np.float32) + + if 'seamless' in cfg.mode: + #mask used for cv2.seamlessClone + img_face_seamless_mask_a = None + for i in range(1,10): + a = img_face_mask_a > i / 10.0 + if len(np.argwhere(a)) == 0: + continue + img_face_seamless_mask_a = img_face_mask_a.copy() + img_face_seamless_mask_a[a] = 1.0 + img_face_seamless_mask_a[img_face_seamless_mask_a <= i / 10.0] = 0.0 + break + + out_img = cv2.warpAffine( prd_face_bgr, face_output_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC ) + out_img = np.clip(out_img, 0.0, 1.0) + + if 'seamless' in cfg.mode: + try: + #calc same bounding rect and center point as in cv2.seamlessClone to prevent jittering (not flickering) + l,t,w,h = cv2.boundingRect( (img_face_seamless_mask_a*255).astype(np.uint8) ) + s_maskx, s_masky = int(l+w/2), int(t+h/2) + out_img = cv2.seamlessClone( (out_img*255).astype(np.uint8), img_bgr_uint8, (img_face_seamless_mask_a*255).astype(np.uint8), (s_maskx,s_masky) , cv2.NORMAL_CLONE ) + out_img = out_img.astype(dtype=np.float32) / 255.0 + except Exception as e: + #seamlessClone may fail in some cases + e_str = traceback.format_exc() + + if 'MemoryError' in e_str: + raise Exception("Seamless fail: " + e_str) #reraise MemoryError in order to reprocess this data by other processes + else: + print ("Seamless fail: " + e_str) + + cfg_mp = cfg.motion_blur_power / 100.0 + + out_img = img_bgr*(1-img_face_mask_a) + (out_img*img_face_mask_a) + + if ('seamless' in cfg.mode and cfg.color_transfer_mode != 0) or \ + cfg.mode == 'seamless-hist-match' or \ + cfg_mp != 0 or \ + cfg.blursharpen_amount != 0 or \ + cfg.image_denoise_power != 0 or \ + cfg.bicubic_degrade_power != 0: + + out_face_bgr = cv2.warpAffine( out_img, face_mat, (output_size, output_size), flags=cv2.INTER_CUBIC ) + + if 'seamless' in cfg.mode and cfg.color_transfer_mode != 0: + if cfg.color_transfer_mode == 1: + out_face_bgr = imagelib.reinhard_color_transfer (out_face_bgr, dst_face_bgr, target_mask=wrk_face_mask_area_a, source_mask=wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 2: #lct + out_face_bgr = imagelib.linear_color_transfer (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 3: #mkl + out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 4: #mkl-m + out_face_bgr = imagelib.color_transfer_mkl (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 5: #idt + out_face_bgr = imagelib.color_transfer_idt (out_face_bgr, dst_face_bgr) + elif cfg.color_transfer_mode == 6: #idt-m + out_face_bgr = imagelib.color_transfer_idt (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + elif cfg.color_transfer_mode == 7: #sot-m + out_face_bgr = imagelib.color_transfer_sot (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a, steps=10, batch_size=30) + out_face_bgr = np.clip (out_face_bgr, 0.0, 1.0) + elif cfg.color_transfer_mode == 8: #mix-m + out_face_bgr = imagelib.color_transfer_mix (out_face_bgr*wrk_face_mask_area_a, dst_face_bgr*wrk_face_mask_area_a) + + if cfg.mode == 'seamless-hist-match': + out_face_bgr = imagelib.color_hist_match(out_face_bgr, dst_face_bgr, cfg.hist_match_threshold) + + if cfg_mp != 0: + k_size = int(frame_info.motion_power*cfg_mp) + if k_size >= 1: + k_size = np.clip (k_size+1, 2, 50) + if cfg.super_resolution_power != 0: + k_size *= 2 + out_face_bgr = imagelib.LinearMotionBlur (out_face_bgr, k_size , frame_info.motion_deg) + + if cfg.blursharpen_amount != 0: + out_face_bgr = imagelib.blursharpen ( out_face_bgr, cfg.sharpen_mode, 3, cfg.blursharpen_amount) + + if cfg.image_denoise_power != 0: + n = cfg.image_denoise_power + while n > 0: + img_bgr_denoised = cv2.medianBlur(img_bgr, 5) + if int(n / 100) != 0: + img_bgr = img_bgr_denoised + else: + pass_power = (n % 100) / 100.0 + img_bgr = img_bgr*(1.0-pass_power)+img_bgr_denoised*pass_power + n = max(n-10,0) + + if cfg.bicubic_degrade_power != 0: + p = 1.0 - cfg.bicubic_degrade_power / 101.0 + img_bgr_downscaled = cv2.resize (img_bgr, ( int(img_size[0]*p), int(img_size[1]*p ) ), interpolation=cv2.INTER_CUBIC) + img_bgr = cv2.resize (img_bgr_downscaled, img_size, interpolation=cv2.INTER_CUBIC) + + new_out = cv2.warpAffine( out_face_bgr, face_mat, img_size, np.empty_like(img_bgr), cv2.WARP_INVERSE_MAP | cv2.INTER_CUBIC ) + + out_img = np.clip( img_bgr*(1-img_face_mask_a) + (new_out*img_face_mask_a) , 0, 1.0 ) + + if cfg.color_degrade_power != 0: + out_img_reduced = imagelib.reduce_colors(out_img, 256) + if cfg.color_degrade_power == 100: + out_img = out_img_reduced + else: + alpha = cfg.color_degrade_power / 100.0 + out_img = (out_img*(1.0-alpha) + out_img_reduced*alpha) + out_merging_mask_a = img_face_mask_a + + if out_img is None: + out_img = img_bgr.copy() + + return out_img, out_merging_mask_a + + +def MergeMasked (predictor_func, + predictor_input_shape, + face_enhancer_func, + xseg_256_extract_func, + cfg, + frame_info): + img_bgr_uint8 = cv2_imread(frame_info.filepath) + img_bgr_uint8 = imagelib.normalize_channels (img_bgr_uint8, 3) + img_bgr = img_bgr_uint8.astype(np.float32) / 255.0 + + outs = [] + for face_num, img_landmarks in enumerate( frame_info.landmarks_list ): + out_img, out_img_merging_mask = MergeMaskedFace (predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, cfg, frame_info, img_bgr_uint8, img_bgr, img_landmarks) + outs += [ (out_img, out_img_merging_mask) ] + + #Combining multiple face outputs + final_img = None + final_mask = None + for img, merging_mask in outs: + h,w,c = img.shape + + if final_img is None: + final_img = img + final_mask = merging_mask + else: + final_img = final_img*(1-merging_mask) + img*merging_mask + final_mask = np.clip (final_mask + merging_mask, 0, 1 ) + + final_img = np.concatenate ( [final_img, final_mask], -1) + + return (final_img*255).astype(np.uint8) diff --git a/merger/MergerConfig.py b/merger/MergerConfig.py new file mode 100644 index 0000000000000000000000000000000000000000..eba1493d5f54d6876802c1857bafe8c28a21ee60 --- /dev/null +++ b/merger/MergerConfig.py @@ -0,0 +1,329 @@ +import numpy as np +import copy + +from facelib import FaceType +from core.interact import interact as io + + +class MergerConfig(object): + TYPE_NONE = 0 + TYPE_MASKED = 1 + TYPE_FACE_AVATAR = 2 + #### + + TYPE_IMAGE = 3 + TYPE_IMAGE_WITH_LANDMARKS = 4 + + def __init__(self, type=0, + sharpen_mode=0, + blursharpen_amount=0, + **kwargs + ): + self.type = type + + self.sharpen_dict = {0:"None", 1:'box', 2:'gaussian'} + + #default changeable params + self.sharpen_mode = sharpen_mode + self.blursharpen_amount = blursharpen_amount + + def copy(self): + return copy.copy(self) + + #overridable + def ask_settings(self): + s = """Choose sharpen mode: \n""" + for key in self.sharpen_dict.keys(): + s += f"""({key}) {self.sharpen_dict[key]}\n""" + io.log_info(s) + self.sharpen_mode = io.input_int ("", 0, valid_list=self.sharpen_dict.keys(), help_message="Enhance details by applying sharpen filter.") + + if self.sharpen_mode != 0: + self.blursharpen_amount = np.clip ( io.input_int ("Choose blur/sharpen amount", 0, add_info="-100..100"), -100, 100 ) + + def toggle_sharpen_mode(self): + a = list( self.sharpen_dict.keys() ) + self.sharpen_mode = a[ (a.index(self.sharpen_mode)+1) % len(a) ] + + def add_blursharpen_amount(self, diff): + self.blursharpen_amount = np.clip ( self.blursharpen_amount+diff, -100, 100) + + #overridable + def get_config(self): + d = self.__dict__.copy() + d.pop('type') + return d + + #overridable + def __eq__(self, other): + #check equality of changeable params + + if isinstance(other, MergerConfig): + return self.sharpen_mode == other.sharpen_mode and \ + self.blursharpen_amount == other.blursharpen_amount + + return False + + #overridable + def to_string(self, filename): + r = "" + r += f"sharpen_mode : {self.sharpen_dict[self.sharpen_mode]}\n" + r += f"blursharpen_amount : {self.blursharpen_amount}\n" + return r + +mode_dict = {0:'original', + 1:'overlay', + 2:'hist-match', + 3:'seamless', + 4:'seamless-hist-match', + 5:'raw-rgb', + 6:'raw-predict'} + +mode_str_dict = { mode_dict[key] : key for key in mode_dict.keys() } + +mask_mode_dict = {0:'full', + 1:'dst', + 2:'learned-prd', + 3:'learned-dst', + 4:'learned-prd*learned-dst', + 5:'learned-prd+learned-dst', + 6:'XSeg-prd', + 7:'XSeg-dst', + 8:'XSeg-prd*XSeg-dst', + 9:'learned-prd*learned-dst*XSeg-prd*XSeg-dst' + } + + +ctm_dict = { 0: "None", 1:"rct", 2:"lct", 3:"mkl", 4:"mkl-m", 5:"idt", 6:"idt-m", 7:"sot-m", 8:"mix-m" } +ctm_str_dict = {None:0, "rct":1, "lct":2, "mkl":3, "mkl-m":4, "idt":5, "idt-m":6, "sot-m":7, "mix-m":8 } + +class MergerConfigMasked(MergerConfig): + + def __init__(self, face_type=FaceType.FULL, + default_mode = 'overlay', + mode='overlay', + masked_hist_match=True, + hist_match_threshold = 238, + mask_mode = 4, + erode_mask_modifier = 0, + blur_mask_modifier = 0, + motion_blur_power = 0, + output_face_scale = 0, + super_resolution_power = 0, + color_transfer_mode = ctm_str_dict['rct'], + image_denoise_power = 0, + bicubic_degrade_power = 0, + color_degrade_power = 0, + **kwargs + ): + + super().__init__(type=MergerConfig.TYPE_MASKED, **kwargs) + + self.face_type = face_type + if self.face_type not in [FaceType.HALF, FaceType.MID_FULL, FaceType.FULL, FaceType.WHOLE_FACE, FaceType.HEAD ]: + raise ValueError("MergerConfigMasked does not support this type of face.") + + self.default_mode = default_mode + + #default changeable params + if mode not in mode_str_dict: + mode = mode_dict[1] + + self.mode = mode + self.masked_hist_match = masked_hist_match + self.hist_match_threshold = hist_match_threshold + self.mask_mode = mask_mode + self.erode_mask_modifier = erode_mask_modifier + self.blur_mask_modifier = blur_mask_modifier + self.motion_blur_power = motion_blur_power + self.output_face_scale = output_face_scale + self.super_resolution_power = super_resolution_power + self.color_transfer_mode = color_transfer_mode + self.image_denoise_power = image_denoise_power + self.bicubic_degrade_power = bicubic_degrade_power + self.color_degrade_power = color_degrade_power + + def copy(self): + return copy.copy(self) + + def set_mode (self, mode): + self.mode = mode_dict.get (mode, self.default_mode) + + def toggle_masked_hist_match(self): + if self.mode == 'hist-match': + self.masked_hist_match = not self.masked_hist_match + + def add_hist_match_threshold(self, diff): + if self.mode == 'hist-match' or self.mode == 'seamless-hist-match': + self.hist_match_threshold = np.clip ( self.hist_match_threshold+diff , 0, 255) + + def toggle_mask_mode(self): + a = list( mask_mode_dict.keys() ) + self.mask_mode = a[ (a.index(self.mask_mode)+1) % len(a) ] + + def add_erode_mask_modifier(self, diff): + self.erode_mask_modifier = np.clip ( self.erode_mask_modifier+diff , -400, 400) + + def add_blur_mask_modifier(self, diff): + self.blur_mask_modifier = np.clip ( self.blur_mask_modifier+diff , 0, 400) + + def add_motion_blur_power(self, diff): + self.motion_blur_power = np.clip ( self.motion_blur_power+diff, 0, 100) + + def add_output_face_scale(self, diff): + self.output_face_scale = np.clip ( self.output_face_scale+diff , -50, 50) + + def toggle_color_transfer_mode(self): + self.color_transfer_mode = (self.color_transfer_mode+1) % ( max(ctm_dict.keys())+1 ) + + def add_super_resolution_power(self, diff): + self.super_resolution_power = np.clip ( self.super_resolution_power+diff , 0, 100) + + def add_color_degrade_power(self, diff): + self.color_degrade_power = np.clip ( self.color_degrade_power+diff , 0, 100) + + def add_image_denoise_power(self, diff): + self.image_denoise_power = np.clip ( self.image_denoise_power+diff, 0, 500) + + def add_bicubic_degrade_power(self, diff): + self.bicubic_degrade_power = np.clip ( self.bicubic_degrade_power+diff, 0, 100) + + def ask_settings(self): + s = """Choose mode: \n""" + for key in mode_dict.keys(): + s += f"""({key}) {mode_dict[key]}\n""" + io.log_info(s) + mode = io.input_int ("", mode_str_dict.get(self.default_mode, 1) ) + + self.mode = mode_dict.get (mode, self.default_mode ) + + if 'raw' not in self.mode: + if self.mode == 'hist-match': + self.masked_hist_match = io.input_bool("Masked hist match?", True) + + if self.mode == 'hist-match' or self.mode == 'seamless-hist-match': + self.hist_match_threshold = np.clip ( io.input_int("Hist match threshold", 255, add_info="0..255"), 0, 255) + + s = """Choose mask mode: \n""" + for key in mask_mode_dict.keys(): + s += f"""({key}) {mask_mode_dict[key]}\n""" + io.log_info(s) + self.mask_mode = io.input_int ("", 1, valid_list=mask_mode_dict.keys() ) + + if 'raw' not in self.mode: + self.erode_mask_modifier = np.clip ( io.input_int ("Choose erode mask modifier", 0, add_info="-400..400"), -400, 400) + self.blur_mask_modifier = np.clip ( io.input_int ("Choose blur mask modifier", 0, add_info="0..400"), 0, 400) + self.motion_blur_power = np.clip ( io.input_int ("Choose motion blur power", 0, add_info="0..100"), 0, 100) + + self.output_face_scale = np.clip (io.input_int ("Choose output face scale modifier", 0, add_info="-50..50" ), -50, 50) + + if 'raw' not in self.mode: + self.color_transfer_mode = io.input_str ( "Color transfer to predicted face", None, valid_list=list(ctm_str_dict.keys())[1:] ) + self.color_transfer_mode = ctm_str_dict[self.color_transfer_mode] + + super().ask_settings() + + self.super_resolution_power = np.clip ( io.input_int ("Choose super resolution power", 0, add_info="0..100", help_message="Enhance details by applying superresolution network."), 0, 100) + + if 'raw' not in self.mode: + self.image_denoise_power = np.clip ( io.input_int ("Choose image degrade by denoise power", 0, add_info="0..500"), 0, 500) + self.bicubic_degrade_power = np.clip ( io.input_int ("Choose image degrade by bicubic rescale power", 0, add_info="0..100"), 0, 100) + self.color_degrade_power = np.clip ( io.input_int ("Degrade color power of final image", 0, add_info="0..100"), 0, 100) + + io.log_info ("") + + def __eq__(self, other): + #check equality of changeable params + + if isinstance(other, MergerConfigMasked): + return super().__eq__(other) and \ + self.mode == other.mode and \ + self.masked_hist_match == other.masked_hist_match and \ + self.hist_match_threshold == other.hist_match_threshold and \ + self.mask_mode == other.mask_mode and \ + self.erode_mask_modifier == other.erode_mask_modifier and \ + self.blur_mask_modifier == other.blur_mask_modifier and \ + self.motion_blur_power == other.motion_blur_power and \ + self.output_face_scale == other.output_face_scale and \ + self.color_transfer_mode == other.color_transfer_mode and \ + self.super_resolution_power == other.super_resolution_power and \ + self.image_denoise_power == other.image_denoise_power and \ + self.bicubic_degrade_power == other.bicubic_degrade_power and \ + self.color_degrade_power == other.color_degrade_power + + return False + + def to_string(self, filename): + r = ( + f"""MergerConfig {filename}:\n""" + f"""Mode: {self.mode}\n""" + ) + + if self.mode == 'hist-match': + r += f"""masked_hist_match: {self.masked_hist_match}\n""" + + if self.mode == 'hist-match' or self.mode == 'seamless-hist-match': + r += f"""hist_match_threshold: {self.hist_match_threshold}\n""" + + r += f"""mask_mode: { mask_mode_dict[self.mask_mode] }\n""" + + if 'raw' not in self.mode: + r += (f"""erode_mask_modifier: {self.erode_mask_modifier}\n""" + f"""blur_mask_modifier: {self.blur_mask_modifier}\n""" + f"""motion_blur_power: {self.motion_blur_power}\n""") + + r += f"""output_face_scale: {self.output_face_scale}\n""" + + if 'raw' not in self.mode: + r += f"""color_transfer_mode: {ctm_dict[self.color_transfer_mode]}\n""" + r += super().to_string(filename) + + r += f"""super_resolution_power: {self.super_resolution_power}\n""" + + if 'raw' not in self.mode: + r += (f"""image_denoise_power: {self.image_denoise_power}\n""" + f"""bicubic_degrade_power: {self.bicubic_degrade_power}\n""" + f"""color_degrade_power: {self.color_degrade_power}\n""") + + r += "================" + + return r + + +class MergerConfigFaceAvatar(MergerConfig): + + def __init__(self, temporal_face_count=0, + add_source_image=False): + super().__init__(type=MergerConfig.TYPE_FACE_AVATAR) + self.temporal_face_count = temporal_face_count + + #changeable params + self.add_source_image = add_source_image + + def copy(self): + return copy.copy(self) + + #override + def ask_settings(self): + self.add_source_image = io.input_bool("Add source image?", False, help_message="Add source image for comparison.") + super().ask_settings() + + def toggle_add_source_image(self): + self.add_source_image = not self.add_source_image + + #override + def __eq__(self, other): + #check equality of changeable params + + if isinstance(other, MergerConfigFaceAvatar): + return super().__eq__(other) and \ + self.add_source_image == other.add_source_image + + return False + + #override + def to_string(self, filename): + return (f"MergerConfig {filename}:\n" + f"add_source_image : {self.add_source_image}\n") + \ + super().to_string(filename) + "================" + diff --git a/merger/MergerScreen/MergerScreen.py b/merger/MergerScreen/MergerScreen.py new file mode 100644 index 0000000000000000000000000000000000000000..fad2a8d04fc4823ebe55e38f3478e368cc99824a --- /dev/null +++ b/merger/MergerScreen/MergerScreen.py @@ -0,0 +1,149 @@ +import math +from pathlib import Path + +import numpy as np + +from core import imagelib +from core.interact import interact as io +from core.cv2ex import * +from core import osex + + +class ScreenAssets(object): + waiting_icon_image = cv2_imread ( str(Path(__file__).parent / 'gfx' / 'sand_clock_64.png') ) + + @staticmethod + def build_checkerboard_a( sh, size=5): + h,w = sh[0], sh[1] + tile = np.array([[0,1],[1,0]]).repeat(size, axis=0).repeat(size, axis=1) + grid = np.tile(tile,(int(math.ceil((h+0.0)/(2*size))),int(math.ceil((w+0.0)/(2*size))))) + return grid[:h,:w,None] + +class Screen(object): + def __init__(self, initial_scale_to_width=0, initial_scale_to_height=0, image=None, waiting_icon=False, **kwargs): + self.initial_scale_to_width = initial_scale_to_width + self.initial_scale_to_height = initial_scale_to_height + self.image = image + self.waiting_icon = waiting_icon + + self.state = -1 + self.scale = 1 + self.force_update = True + self.is_first_appear = True + self.show_checker_board = False + + self.last_screen_shape = (480,640,3) + self.checkerboard_image = None + self.set_image (image) + self.scrn_manager = None + + def set_waiting_icon(self, b): + self.waiting_icon = b + + def toggle_show_checker_board(self): + self.show_checker_board = not self.show_checker_board + self.force_update = True + + def get_image(self): + return self.image + + def set_image(self, img): + if not img is self.image: + self.force_update = True + + self.image = img + + if self.image is not None: + self.last_screen_shape = self.image.shape + + if self.initial_scale_to_width != 0: + if self.last_screen_shape[1] > self.initial_scale_to_width: + self.scale = self.initial_scale_to_width / self.last_screen_shape[1] + self.force_update = True + self.initial_scale_to_width = 0 + + elif self.initial_scale_to_height != 0: + if self.last_screen_shape[0] > self.initial_scale_to_height: + self.scale = self.initial_scale_to_height / self.last_screen_shape[0] + self.force_update = True + self.initial_scale_to_height = 0 + + + def diff_scale(self, diff): + self.scale = np.clip (self.scale + diff, 0.1, 4.0) + self.force_update = True + + def show(self, force=False): + new_state = 0 | self.waiting_icon + + if self.state != new_state or self.force_update or force: + self.state = new_state + self.force_update = False + + if self.image is None: + screen = np.zeros ( self.last_screen_shape, dtype=np.uint8 ) + else: + screen = self.image.copy() + + if self.waiting_icon: + imagelib.overlay_alpha_image (screen, ScreenAssets.waiting_icon_image, (0,0) ) + + h,w,c = screen.shape + if self.scale != 1.0: + screen = cv2.resize ( screen, ( int(w*self.scale), int(h*self.scale) ) ) + + if c == 4: + if not self.show_checker_board: + screen = screen[...,0:3] + else: + if self.checkerboard_image is None or self.checkerboard_image.shape[0:2] != screen.shape[0:2]: + self.checkerboard_image = ScreenAssets.build_checkerboard_a(screen.shape) + + screen = screen[...,0:3]*0.75 + 64*self.checkerboard_image*(1- (screen[...,3:4].astype(np.float32)/255.0) ) + screen = screen.astype(np.uint8) + + io.show_image(self.scrn_manager.wnd_name, screen) + + if self.is_first_appear: + self.is_first_appear = False + #center window + desktop_w, desktop_h = osex.get_screen_size() + h,w,c = screen.shape + cv2.moveWindow(self.scrn_manager.wnd_name, max(0,(desktop_w-w) // 2), max(0, (desktop_h-h) // 2) ) + + io.process_messages(0.0001) + +class ScreenManager(object): + def __init__(self, window_name="ScreenManager", screens=None, capture_keys=False ): + self.screens = screens or [] + self.current_screen_id = 0 + + if self.screens is not None: + for screen in self.screens: + screen.scrn_manager = self + + self.wnd_name = window_name + io.named_window(self.wnd_name) + + + if capture_keys: + io.capture_keys(self.wnd_name) + + def finalize(self): + io.destroy_all_windows() + + def get_key_events(self): + return io.get_key_events(self.wnd_name) + + def switch_screens(self): + self.current_screen_id = (self.current_screen_id + 1) % len(self.screens) + self.screens[self.current_screen_id].show(force=True) + + def show_current(self): + self.screens[self.current_screen_id].show() + + def get_current(self): + return self.screens[self.current_screen_id] + + def set_current(self, screen): + self.current_screen_id = self.screens.index(screen) diff --git a/merger/MergerScreen/__init__.py b/merger/MergerScreen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3e32085c5f5330c050698a71046f22538e8fc4 --- /dev/null +++ b/merger/MergerScreen/__init__.py @@ -0,0 +1 @@ +from .MergerScreen import Screen, ScreenManager \ No newline at end of file diff --git a/merger/MergerScreen/gfx/sand_clock_64.png b/merger/MergerScreen/gfx/sand_clock_64.png new file mode 100644 index 0000000000000000000000000000000000000000..c8d34869aba357a4a99acae8ec4a475db20d0fc0 Binary files /dev/null and b/merger/MergerScreen/gfx/sand_clock_64.png differ diff --git a/merger/__init__.py b/merger/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..82a44144b1ef989d6b53bd186fa456d45dc91ff4 --- /dev/null +++ b/merger/__init__.py @@ -0,0 +1,5 @@ +from .FrameInfo import FrameInfo +from .MergerConfig import MergerConfig, MergerConfigMasked, MergerConfigFaceAvatar +from .MergeMasked import MergeMasked +from .MergeAvatar import MergeFaceAvatar +from .InteractiveMergerSubprocessor import InteractiveMergerSubprocessor \ No newline at end of file diff --git a/merger/gfx/help_merger_face_avatar.jpg b/merger/gfx/help_merger_face_avatar.jpg new file mode 100644 index 0000000000000000000000000000000000000000..29b7e728075e218c2d3f6e9d6bdd105b8b7a1a56 Binary files /dev/null and b/merger/gfx/help_merger_face_avatar.jpg differ diff --git a/merger/gfx/help_merger_face_avatar_source.psd b/merger/gfx/help_merger_face_avatar_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..04adff001400255fac286f4bc5e2bc568152c574 --- /dev/null +++ b/merger/gfx/help_merger_face_avatar_source.psd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:457d44d105aa60addd06e7d5cec9f9d5ffa3762960c0b085e9e559b1cf7e4a3e +size 8371144 diff --git a/merger/gfx/help_merger_masked.jpg b/merger/gfx/help_merger_masked.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f1822f34fb827b107d7de80885c5fed4207963fa Binary files /dev/null and b/merger/gfx/help_merger_masked.jpg differ diff --git a/merger/gfx/help_merger_masked_source.psd b/merger/gfx/help_merger_masked_source.psd new file mode 100644 index 0000000000000000000000000000000000000000..fad88ae9c28a40dd68d539a0ce9963e6659595c2 --- /dev/null +++ b/merger/gfx/help_merger_masked_source.psd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02c7cdd5d3be810268dfdd5512c2c5f4a35c80e6c25a53e02d7a801d46d70033 +size 8677652 diff --git a/models/ModelBase.py b/models/ModelBase.py new file mode 100644 index 0000000000000000000000000000000000000000..f446efa3f8b7f6ae1a3dcbeedf12f8aa82b3a27b --- /dev/null +++ b/models/ModelBase.py @@ -0,0 +1,685 @@ +import colorsys +import inspect +import json +import multiprocessing +import operator +import os +import pickle +import shutil +import tempfile +import time +from pathlib import Path + +import cv2 +import numpy as np + +from core import imagelib, pathex +from core.cv2ex import * +from core.interact import interact as io +from core.leras import nn +from samplelib import SampleGeneratorBase + + +class ModelBase(object): + def __init__(self, is_training=False, + is_exporting=False, + saved_models_path=None, + training_data_src_path=None, + training_data_dst_path=None, + pretraining_data_path=None, + pretrained_model_path=None, + no_preview=False, + force_model_name=None, + force_gpu_idxs=None, + cpu_only=False, + debug=False, + force_model_class_name=None, + silent_start=False, + **kwargs): + self.is_training = is_training + self.is_exporting = is_exporting + self.saved_models_path = saved_models_path + self.training_data_src_path = training_data_src_path + self.training_data_dst_path = training_data_dst_path + self.pretraining_data_path = pretraining_data_path + self.pretrained_model_path = pretrained_model_path + self.no_preview = no_preview + self.debug = debug + + self.model_class_name = model_class_name = Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1] + + if force_model_class_name is None: + if force_model_name is not None: + self.model_name = force_model_name + else: + while True: + # gather all model dat files + saved_models_names = [] + for filepath in pathex.get_file_paths(saved_models_path): + filepath_name = filepath.name + if filepath_name.endswith(f'{model_class_name}_data.dat'): + saved_models_names += [ (filepath_name.split('_')[0], os.path.getmtime(filepath)) ] + + # sort by modified datetime + saved_models_names = sorted(saved_models_names, key=operator.itemgetter(1), reverse=True ) + saved_models_names = [ x[0] for x in saved_models_names ] + + + if len(saved_models_names) != 0: + if silent_start: + self.model_name = saved_models_names[0] + io.log_info(f'Silent start: choosed model "{self.model_name}"') + else: + io.log_info ("Choose one of saved models, or enter a name to create a new model.") + io.log_info ("[r] : rename") + io.log_info ("[d] : delete") + io.log_info ("") + for i, model_name in enumerate(saved_models_names): + s = f"[{i}] : {model_name} " + if i == 0: + s += "- latest" + io.log_info (s) + + inp = io.input_str(f"", "0", show_default_value=False ) + model_idx = -1 + try: + model_idx = np.clip ( int(inp), 0, len(saved_models_names)-1 ) + except: + pass + + if model_idx == -1: + if len(inp) == 1: + is_rename = inp[0] == 'r' + is_delete = inp[0] == 'd' + + if is_rename or is_delete: + if len(saved_models_names) != 0: + + if is_rename: + name = io.input_str(f"Enter the name of the model you want to rename") + elif is_delete: + name = io.input_str(f"Enter the name of the model you want to delete") + + if name in saved_models_names: + + if is_rename: + new_model_name = io.input_str(f"Enter new name of the model") + + for filepath in pathex.get_paths(saved_models_path): + filepath_name = filepath.name + + model_filename, remain_filename = filepath_name.split('_', 1) + if model_filename == name: + + if is_rename: + new_filepath = filepath.parent / ( new_model_name + '_' + remain_filename ) + filepath.rename (new_filepath) + elif is_delete: + filepath.unlink() + continue + + self.model_name = inp + else: + self.model_name = saved_models_names[model_idx] + + else: + self.model_name = io.input_str(f"No saved models found. Enter a name of a new model", "new") + self.model_name = self.model_name.replace('_', ' ') + break + + + self.model_name = self.model_name + '_' + self.model_class_name + else: + self.model_name = force_model_class_name + + self.iter = 0 + self.options = {} + self.options_show_override = {} + self.loss_history = [] + self.sample_for_preview = None + self.choosed_gpu_indexes = None + + model_data = {} + self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') ) + if self.model_data_path.exists(): + io.log_info (f"Loading {self.model_name} model...") + model_data = pickle.loads ( self.model_data_path.read_bytes() ) + self.iter = model_data.get('iter',0) + if self.iter != 0: + self.options = model_data['options'] + self.loss_history = model_data.get('loss_history', []) + self.sample_for_preview = model_data.get('sample_for_preview', None) + self.choosed_gpu_indexes = model_data.get('choosed_gpu_indexes', None) + + if self.is_first_run(): + io.log_info ("\nModel first run.") + + if silent_start: + self.device_config = nn.DeviceConfig.BestGPU() + io.log_info (f"Silent start: choosed device {'CPU' if self.device_config.cpu_only else self.device_config.devices[0].name}") + else: + self.device_config = nn.DeviceConfig.GPUIndexes( force_gpu_idxs or nn.ask_choose_device_idxs(suggest_best_multi_gpu=True)) \ + if not cpu_only else nn.DeviceConfig.CPU() + + nn.initialize(self.device_config) + + #### + self.default_options_path = saved_models_path / f'{self.model_class_name}_default_options.dat' + self.default_options = {} + if self.default_options_path.exists(): + try: + self.default_options = pickle.loads ( self.default_options_path.read_bytes() ) + except: + pass + + self.choose_preview_history = False + self.batch_size = self.load_or_def_option('batch_size', 1) + ##### + + io.input_skip_pending() + self.on_initialize_options() + + if self.is_first_run(): + # save as default options only for first run model initialize + self.default_options_path.write_bytes( pickle.dumps (self.options) ) + + self.autobackup_hour = self.options.get('autobackup_hour', 0) + self.write_preview_history = self.options.get('write_preview_history', False) + self.target_iter = self.options.get('target_iter',0) + self.random_flip = self.options.get('random_flip',True) + self.random_src_flip = self.options.get('random_src_flip', False) + self.random_dst_flip = self.options.get('random_dst_flip', True) + + self.on_initialize() + self.options['batch_size'] = self.batch_size + + self.preview_history_writer = None + if self.is_training: + self.preview_history_path = self.saved_models_path / ( f'{self.get_model_name()}_history' ) + self.autobackups_path = self.saved_models_path / ( f'{self.get_model_name()}_autobackups' ) + + if self.write_preview_history or io.is_colab(): + if not self.preview_history_path.exists(): + self.preview_history_path.mkdir(exist_ok=True) + else: + if self.iter == 0: + for filename in pathex.get_image_paths(self.preview_history_path): + Path(filename).unlink() + + if self.generator_list is None: + raise ValueError( 'You didnt set_training_data_generators()') + else: + for i, generator in enumerate(self.generator_list): + if not isinstance(generator, SampleGeneratorBase): + raise ValueError('training data generator is not subclass of SampleGeneratorBase') + + self.update_sample_for_preview(choose_preview_history=self.choose_preview_history) + + if self.autobackup_hour != 0: + self.autobackup_start_time = time.time() + + if not self.autobackups_path.exists(): + self.autobackups_path.mkdir(exist_ok=True) + + io.log_info( self.get_summary_text() ) + + def update_sample_for_preview(self, choose_preview_history=False, force_new=False): + if self.sample_for_preview is None or choose_preview_history or force_new: + if choose_preview_history and io.is_support_windows(): + wnd_name = "[p] - next. [space] - switch preview type. [enter] - confirm." + io.log_info (f"Choose image for the preview history. {wnd_name}") + io.named_window(wnd_name) + io.capture_keys(wnd_name) + choosed = False + preview_id_counter = 0 + while not choosed: + self.sample_for_preview = self.generate_next_samples() + previews = self.get_history_previews() + + io.show_image( wnd_name, ( previews[preview_id_counter % len(previews) ][1] *255).astype(np.uint8) ) + + while True: + key_events = io.get_key_events(wnd_name) + key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0,0,False,False,False) + if key == ord('\n') or key == ord('\r'): + choosed = True + break + elif key == ord(' '): + preview_id_counter += 1 + break + elif key == ord('p'): + break + + try: + io.process_messages(0.1) + except KeyboardInterrupt: + choosed = True + + io.destroy_window(wnd_name) + else: + self.sample_for_preview = self.generate_next_samples() + + try: + self.get_history_previews() + except: + self.sample_for_preview = self.generate_next_samples() + + self.last_sample = self.sample_for_preview + + def load_or_def_option(self, name, def_value): + options_val = self.options.get(name, None) + if options_val is not None: + return options_val + + def_opt_val = self.default_options.get(name, None) + if def_opt_val is not None: + return def_opt_val + + return def_value + + def ask_override(self): + return self.is_training and self.iter != 0 and io.input_in_time ("Press enter in 2 seconds to override model settings.", 5 if io.is_colab() else 2 ) + + def ask_autobackup_hour(self, default_value=0): + default_autobackup_hour = self.options['autobackup_hour'] = self.load_or_def_option('autobackup_hour', default_value) + self.options['autobackup_hour'] = io.input_int(f"Autobackup every N hour", default_autobackup_hour, add_info="0..24", help_message="Autobackup model files with preview every N hour. Latest backup located in model/<>_autobackups/01") + + def ask_write_preview_history(self, default_value=False): + default_write_preview_history = self.load_or_def_option('write_preview_history', default_value) + self.options['write_preview_history'] = io.input_bool(f"Write preview history", default_write_preview_history, help_message="Preview history will be writed to _history folder.") + + if self.options['write_preview_history']: + if io.is_support_windows(): + self.choose_preview_history = io.input_bool("Choose image for the preview history", False) + elif io.is_colab(): + self.choose_preview_history = io.input_bool("Randomly choose new image for preview history", False, help_message="Preview image history will stay stuck with old faces if you reuse the same model on different celebs. Choose no unless you are changing src/dst to a new person") + + def ask_target_iter(self, default_value=0): + default_target_iter = self.load_or_def_option('target_iter', default_value) + self.options['target_iter'] = max(0, io.input_int("Target iteration", default_target_iter)) + + def ask_random_flip(self): + default_random_flip = self.load_or_def_option('random_flip', True) + self.options['random_flip'] = io.input_bool("Flip faces randomly", default_random_flip, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.") + + def ask_random_src_flip(self): + default_random_src_flip = self.load_or_def_option('random_src_flip', False) + self.options['random_src_flip'] = io.input_bool("Flip SRC faces randomly", default_random_src_flip, help_message="Random horizontal flip SRC faceset. Covers more angles, but the face may look less naturally.") + + def ask_random_dst_flip(self): + default_random_dst_flip = self.load_or_def_option('random_dst_flip', True) + self.options['random_dst_flip'] = io.input_bool("Flip DST faces randomly", default_random_dst_flip, help_message="Random horizontal flip DST faceset. Makes generalization of src->dst better, if src random flip is not enabled.") + + def ask_batch_size(self, suggest_batch_size=None, range=None): + default_batch_size = self.load_or_def_option('batch_size', suggest_batch_size or self.batch_size) + + batch_size = max(0, io.input_int("Batch_size", default_batch_size, valid_range=range, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually.")) + + if range is not None: + batch_size = np.clip(batch_size, range[0], range[1]) + + self.options['batch_size'] = self.batch_size = batch_size + + + #overridable + def on_initialize_options(self): + pass + + #overridable + def on_initialize(self): + ''' + initialize your models + + store and retrieve your model options in self.options[''] + + check example + ''' + pass + + #overridable + def onSave(self): + #save your models here + pass + + #overridable + def onTrainOneIter(self, sample, generator_list): + #train your models here + + #return array of losses + return ( ('loss_src', 0), ('loss_dst', 0) ) + + #overridable + def onGetPreview(self, sample, for_history=False): + #you can return multiple previews + #return [ ('preview_name',preview_rgb), ... ] + return [] + + #overridable if you want model name differs from folder name + def get_model_name(self): + return self.model_name + + #overridable , return [ [model, filename],... ] list + def get_model_filename_list(self): + return [] + + #overridable + def get_MergerConfig(self): + #return predictor_func, predictor_input_shape, MergerConfig() for the model + raise NotImplementedError + + def get_pretraining_data_path(self): + return self.pretraining_data_path + + def get_target_iter(self): + return self.target_iter + + def is_reached_iter_goal(self): + return self.target_iter != 0 and self.iter >= self.target_iter + + def get_previews(self): + return self.onGetPreview ( self.last_sample ) + + def get_history_previews(self): + return self.onGetPreview (self.sample_for_preview, for_history=True) + + def get_preview_history_writer(self): + if self.preview_history_writer is None: + self.preview_history_writer = PreviewHistoryWriter() + return self.preview_history_writer + + def save(self): + Path( self.get_summary_path() ).write_text( self.get_summary_text() ) + + self.onSave() + + model_data = { + 'iter': self.iter, + 'options': self.options, + 'loss_history': self.loss_history, + 'sample_for_preview' : self.sample_for_preview, + 'choosed_gpu_indexes' : self.choosed_gpu_indexes, + } + pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) ) + + if self.autobackup_hour != 0: + diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 ) + + if diff_hour > 0 and diff_hour % self.autobackup_hour == 0: + self.autobackup_start_time += self.autobackup_hour*3600 + self.create_backup() + + def create_backup(self): + io.log_info ("Creating backup...", end='\r') + + if not self.autobackups_path.exists(): + self.autobackups_path.mkdir(exist_ok=True) + + bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ] + bckp_filename_list += [ str(self.get_summary_path()), str(self.model_data_path) ] + + for i in range(24,0,-1): + idx_str = '%.2d' % i + next_idx_str = '%.2d' % (i+1) + + idx_backup_path = self.autobackups_path / idx_str + next_idx_packup_path = self.autobackups_path / next_idx_str + + if idx_backup_path.exists(): + if i == 24: + pathex.delete_all_files(idx_backup_path) + else: + next_idx_packup_path.mkdir(exist_ok=True) + pathex.move_all_files (idx_backup_path, next_idx_packup_path) + + if i == 1: + idx_backup_path.mkdir(exist_ok=True) + for filename in bckp_filename_list: + shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) ) + + previews = self.get_previews() + plist = [] + for i in range(len(previews)): + name, bgr = previews[i] + plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ] + + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) + + def debug_one_iter(self): + images = [] + for generator in self.generator_list: + for i,batch in enumerate(next(generator)): + if len(batch.shape) == 4: + images.append( batch[0] ) + + return imagelib.equalize_and_stack_square (images) + + def generate_next_samples(self): + sample = [] + for generator in self.generator_list: + if generator.is_initialized(): + sample.append ( generator.generate_next() ) + else: + sample.append ( [] ) + self.last_sample = sample + return sample + + #overridable + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % 10 == 0) or (io.is_colab() and self.iter % 100 == 0) + + def train_one_iter(self): + + iter_time = time.time() + losses = self.onTrainOneIter() + iter_time = time.time() - iter_time + + self.loss_history.append ( [float(loss[1]) for loss in losses] ) + + if self.should_save_preview_history(): + plist = [] + + if io.is_colab(): + previews = self.get_previews() + for i in range(len(previews)): + name, bgr = previews[i] + plist += [ (bgr, self.get_strpath_storage_for_file('preview_%s.jpg' % (name) ) ) ] + + if self.write_preview_history: + previews = self.get_history_previews() + for i in range(len(previews)): + name, bgr = previews[i] + path = self.preview_history_path / name + plist += [ ( bgr, str ( path / ( f'{self.iter:07d}.jpg') ) ) ] + if not io.is_colab(): + plist += [ ( bgr, str ( path / ( '_last.jpg' ) )) ] + + if len(plist) != 0: + self.get_preview_history_writer().post(plist, self.loss_history, self.iter) + + self.iter += 1 + + return self.iter, iter_time + + def pass_one_iter(self): + self.generate_next_samples() + + def finalize(self): + nn.close_session() + + def is_first_run(self): + return self.iter == 0 + + def is_debug(self): + return self.debug + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + + def get_batch_size(self): + return self.batch_size + + def get_iter(self): + return self.iter + + def set_iter(self, iter): + self.iter = iter + self.loss_history = self.loss_history[:iter] + + def get_loss_history(self): + return self.loss_history + + def set_training_data_generators (self, generator_list): + self.generator_list = generator_list + + def get_training_data_generators (self): + return self.generator_list + + def get_model_root_path(self): + return self.saved_models_path + + def get_strpath_storage_for_file(self, filename): + return str( self.saved_models_path / ( self.get_model_name() + '_' + filename) ) + + def get_summary_path(self): + return self.get_strpath_storage_for_file('summary.txt') + + def get_summary_text(self): + visible_options = self.options.copy() + visible_options.update(self.options_show_override) + + ###Generate text summary of model hyperparameters + #Find the longest key name and value string. Used as column widths. + width_name = max([len(k) for k in visible_options.keys()] + [17]) + 1 # Single space buffer to left edge. Minimum of 17, the length of the longest static string used "Current iteration" + width_value = max([len(str(x)) for x in visible_options.values()] + [len(str(self.get_iter())), len(self.get_model_name())]) + 1 # Single space buffer to right edge + if len(self.device_config.devices) != 0: #Check length of GPU names + width_value = max([len(device.name)+1 for device in self.device_config.devices] + [width_value]) + width_total = width_name + width_value + 2 #Plus 2 for ": " + + summary_text = [] + summary_text += [f'=={" Model Summary ":=^{width_total}}=='] # Model/status summary + summary_text += [f'=={" "*width_total}=='] + summary_text += [f'=={"Model name": >{width_name}}: {self.get_model_name(): <{width_value}}=='] # Name + summary_text += [f'=={" "*width_total}=='] + summary_text += [f'=={"Current iteration": >{width_name}}: {str(self.get_iter()): <{width_value}}=='] # Iter + summary_text += [f'=={" "*width_total}=='] + + summary_text += [f'=={" Model Options ":-^{width_total}}=='] # Model options + summary_text += [f'=={" "*width_total}=='] + for key in visible_options.keys(): + summary_text += [f'=={key: >{width_name}}: {str(visible_options[key]): <{width_value}}=='] # visible_options key/value pairs + summary_text += [f'=={" "*width_total}=='] + + summary_text += [f'=={" Running On ":-^{width_total}}=='] # Training hardware info + summary_text += [f'=={" "*width_total}=='] + if len(self.device_config.devices) == 0: + summary_text += [f'=={"Using device": >{width_name}}: {"CPU": <{width_value}}=='] # cpu_only + else: + for device in self.device_config.devices: + summary_text += [f'=={"Device index": >{width_name}}: {device.index: <{width_value}}=='] # GPU hardware device index + summary_text += [f'=={"Name": >{width_name}}: {device.name: <{width_value}}=='] # GPU name + vram_str = f'{device.total_mem_gb:.2f}GB' # GPU VRAM - Formated as #.## (or ##.##) + summary_text += [f'=={"VRAM": >{width_name}}: {vram_str: <{width_value}}=='] + summary_text += [f'=={" "*width_total}=='] + summary_text += [f'=={"="*width_total}=='] + summary_text = "\n".join (summary_text) + return summary_text + + @staticmethod + def get_loss_history_preview(loss_history, iter, w, c): + loss_history = np.array (loss_history.copy()) + + lh_height = 100 + lh_img = np.ones ( (lh_height,w,c) ) * 0.1 + + if len(loss_history) != 0: + loss_count = len(loss_history[0]) + lh_len = len(loss_history) + + l_per_col = lh_len / w + plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p], + *[ loss_history[i_ab][p] + for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] + + plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p], + *[ loss_history[i_ab][p] + for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) ) + ] + ) + for p in range(loss_count) + ] + for col in range(w) + ] + + plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2 + + for col in range(0, w): + for p in range(0,loss_count): + point_color = [1.0]*c + point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 ) + + ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) ) + ph_max = np.clip( ph_max, 0, lh_height-1 ) + + ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) ) + ph_min = np.clip( ph_min, 0, lh_height-1 ) + + for ph in range(ph_min, ph_max+1): + lh_img[ (lh_height-ph-1), col ] = point_color + + lh_lines = 5 + lh_line_height = (lh_height-1)/lh_lines + for i in range(0,lh_lines+1): + lh_img[ int(i*lh_line_height), : ] = (0.8,)*c + + last_line_t = int((lh_lines-1)*lh_line_height) + last_line_b = int(lh_lines*lh_line_height) + + lh_text = 'Iter: %d' % (iter) if iter != 0 else '' + + lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image ( (last_line_b-last_line_t,w,c), lh_text, color=[0.8]*c ) + return lh_img + +class PreviewHistoryWriter(): + def __init__(self): + self.sq = multiprocessing.Queue() + self.p = multiprocessing.Process(target=self.process, args=( self.sq, )) + self.p.daemon = True + self.p.start() + + def process(self, sq): + while True: + while not sq.empty(): + plist, loss_history, iter = sq.get() + + preview_lh_cache = {} + for preview, filepath in plist: + filepath = Path(filepath) + i = (preview.shape[1], preview.shape[2]) + + preview_lh = preview_lh_cache.get(i, None) + if preview_lh is None: + preview_lh = ModelBase.get_loss_history_preview(loss_history, iter, preview.shape[1], preview.shape[2]) + preview_lh_cache[i] = preview_lh + + img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8) + + filepath.parent.mkdir(parents=True, exist_ok=True) + cv2_imwrite (filepath, img ) + + time.sleep(0.01) + + def post(self, plist, loss_history, iter): + self.sq.put ( (plist, loss_history, iter) ) + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..82b0dc55dbdc93008e8888f61509e79bc44e0a57 --- /dev/null +++ b/models/Model_AMP/Model.py @@ -0,0 +1,725 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * +from core.cv2ex import * + +class AMPModel(ModelBase): + + #override + def on_initialize_options(self): + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024) + + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(8) + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .") + resolution = np.clip ( (resolution // 32) * 32, 64, 640) + self.options['resolution'] = resolution + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 ) + self.options['morph_factor'] = morph_factor + + if self.is_first_run() or ask_override: + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If src faceset is deverse enough, then lct mode is fine in most cases.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + input_ch=3 + resolution = self.resolution = self.options['resolution'] + e_dims = self.options['e_dims'] + ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] + inter_res = self.inter_res = resolution // 32 + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + face_type = self.face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + morph_factor = self.options['morph_factor'] + gan_power = self.gan_power = self.options['gan_power'] + random_warp = self.options['random_warp'] + + blur_out_mask = self.options['blur_out_mask'] + + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + x = nn.pixel_norm(nn.flatten(x), axes=-1) + x = self.dense1(x) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + return x + + + class Decoder(nn.ModelBase): + def on_build(self ): + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + + self.res0 = ResidualBlock(d_dims*8, kernel_size=3) + self.res1 = ResidualBlock(d_dims*8, kernel_size=3) + self.res2 = ResidualBlock(d_dims*4, kernel_size=3) + self.res3 = ResidualBlock(d_dims*2, kernel_size=3) + + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + def forward(self, z): + if use_fp16: + z = tf.cast(z, tf.float16) + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = self.upscale3(x) + x = self.res3(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + m = self.upscalem3(m) + m = self.upscalem4(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') + + # Initializing model classes + with tf.device (models_opt_device): + self.encoder = Encoder(name='encoder') + self.inter_src = Inter(name='inter_src') + self.inter_dst = Inter(name='inter_dst') + self.decoder = Decoder(name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_src, 'inter_src.npy'], + [self.inter_dst , 'inter_dst.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + # Initialize optimizers + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + if self.options['lr_dropout'] in ['y','cpu']: + lr_cos = 500 + lr_dropout = 0.3 + else: + lr_cos = 0 + lr_dropout = 1.0 + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() + + self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") + self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt') + self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ [self.GAN, 'GAN.npy'], + [self.GAN_opt, 'GAN_opt.npy'] ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gradients = [] + gpu_GAN_loss_gradients = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + # process model tensors + gpu_src_code = self.encoder (gpu_warped_src) + gpu_dst_code = self.encoder (gpu_warped_dst) + + gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) + gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) + + inter_dims_bin = int(inter_dims*morph_factor) + with tf.device(f'/CPU:0'): + inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) + gpu_dst_code = gpu_dst_inter_dst_code + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) + gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) + + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur + gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur + + # Structural loss + gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + + # Pixel loss + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) + + # Eyes+mouth prio loss + gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) + + # Mask loss + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + gpu_G_loss = gpu_src_loss + gpu_dst_loss + # dst-dst background weak loss + gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + + + if gan_power != 0: + gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) + gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) + gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) + gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) + + gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ + DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ + DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ + DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) + ) * (1.0 / 8) + + gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] + + gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ + DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) + ) * gan_power + + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) + + if gan_power != 0: + GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) + + # Initializing training and view functions + def train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + }) + return s, d + self.train = train + + if gan_power != 0: + def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.GAN_train = GAN_train + + def AE_view(warped_src, warped_dst, morph_value): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_view = AE_view + else: + #Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + def AE_merge(warped_dst, morph_value): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.GAN: + if self.gan_model_changed: + do_init = True + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + if do_init: + model.init_weights() + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=self.random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=dst_generators_count ) + ]) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value') + + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code) + + inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='AMP', + input_names=['in_face:0','morph_value:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=12, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.gan_power != 0: + self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ] + + _, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ] + _, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ] + _, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ] + _, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ] + _, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ] + + (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + result = [] + + i = np.random.randint(n_samples) if not for_history else 0 + + st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] + st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] + + result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ] + result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ] + result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ] + + return result + + def predictor_func (self, face, morph_value): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 ) + + def predictor_morph(face): + return self.predictor_func(face, morph_factor) + + + import merger + return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = AMPModel diff --git a/models/Model_AMP/__init__.py b/models/Model_AMP/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_AMP/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/Model_AMP/defModel.py b/models/Model_AMP/defModel.py new file mode 100644 index 0000000000000000000000000000000000000000..04a1c2653194222d56f6239d534cfb4b8e60a9c1 --- /dev/null +++ b/models/Model_AMP/defModel.py @@ -0,0 +1,724 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * +from core.cv2ex import * + +class AMPModel(ModelBase): + + #override + def on_initialize_options(self): + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024) + + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(8) + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .") + resolution = np.clip ( (resolution // 32) * 32, 64, 640) + self.options['resolution'] = resolution + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 ) + self.options['morph_factor'] = morph_factor + + if self.is_first_run() or ask_override: + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If src faceset is deverse enough, then lct mode is fine in most cases.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + input_ch=3 + resolution = self.resolution = self.options['resolution'] + e_dims = self.options['e_dims'] + ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] + inter_res = self.inter_res = resolution // 32 + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + face_type = self.face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + morph_factor = self.options['morph_factor'] + gan_power = self.gan_power = self.options['gan_power'] + random_warp = self.options['random_warp'] + + blur_out_mask = self.options['blur_out_mask'] + + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + x = nn.pixel_norm(nn.flatten(x), axes=-1) + x = self.dense1(x) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + return x + + + class Decoder(nn.ModelBase): + def on_build(self ): + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + + self.res0 = ResidualBlock(d_dims*8, kernel_size=3) + self.res1 = ResidualBlock(d_dims*8, kernel_size=3) + self.res2 = ResidualBlock(d_dims*4, kernel_size=3) + self.res3 = ResidualBlock(d_dims*2, kernel_size=3) + + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + def forward(self, z): + if use_fp16: + z = tf.cast(z, tf.float16) + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = self.upscale3(x) + x = self.res3(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + m = self.upscalem3(m) + m = self.upscalem4(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') + + # Initializing model classes + with tf.device (models_opt_device): + self.encoder = Encoder(name='encoder') + self.inter_src = Inter(name='inter_src') + self.inter_dst = Inter(name='inter_dst') + self.decoder = Decoder(name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_src, 'inter_src.npy'], + [self.inter_dst , 'inter_dst.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + # Initialize optimizers + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 + + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() + + #if random_warp: + # self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() + + self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") + self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') + self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ [self.GAN, 'GAN.npy'], + [self.GAN_opt, 'GAN_opt.npy'] ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gradients = [] + gpu_GAN_loss_gradients = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + # process model tensors + gpu_src_code = self.encoder (gpu_warped_src) + gpu_dst_code = self.encoder (gpu_warped_dst) + + gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) + gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) + + inter_dims_bin = int(inter_dims*morph_factor) + with tf.device(f'/CPU:0'): + inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) + gpu_dst_code = gpu_dst_inter_dst_code + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) + gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) + + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur + gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur + + # Structural loss + gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + + # Pixel loss + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) + + # Eyes+mouth prio loss + gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) + + # Mask loss + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + gpu_G_loss = gpu_src_loss + gpu_dst_loss + # dst-dst background weak loss + gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + + + if gan_power != 0: + gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) + gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) + gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) + gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) + + gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ + DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ + DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ + DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) + ) * (1.0 / 8) + + gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] + + gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ + DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) + ) * gan_power + + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) + + if gan_power != 0: + GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) + + # Initializing training and view functions + def train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + }) + return s, d + self.train = train + + if gan_power != 0: + def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.GAN_train = GAN_train + + def AE_view(warped_src, warped_dst, morph_value): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_view = AE_view + else: + #Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + def AE_merge(warped_dst, morph_value): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.GAN: + if self.gan_model_changed: + do_init = True + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + if do_init: + model.init_weights() + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=self.random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.125, 0.125], random_flip=self.random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=dst_generators_count ) + ]) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value') + + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code) + + inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='AMP', + input_names=['in_face:0','morph_value:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=9, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.gan_power != 0: + self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ] + + _, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ] + _, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ] + _, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ] + _, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ] + _, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ] + + (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + result = [] + + i = np.random.randint(n_samples) if not for_history else 0 + + st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] + st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] + + result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ] + result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ] + result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ] + + return result + + def predictor_func (self, face, morph_value): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 ) + + def predictor_morph(face): + return self.predictor_func(face, morph_factor) + + + import merger + return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = AMPModel diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9e21544055a5fac5528bf1897180b79d028da2 --- /dev/null +++ b/models/Model_Quick96/Model.py @@ -0,0 +1,321 @@ +import multiprocessing +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * + +class QModel(ModelBase): + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + resolution = self.resolution = 96 + self.face_type = FaceType.FULL + ae_dims = 128 + e_dims = 64 + d_dims = 64 + d_mask_dims = 16 + self.pretrain = False + self.pretrain_just_disabled = False + + masked_training = True + + models_opt_on_gpu = len(devices) >= 1 and all([dev.total_mem_gb >= 4 for dev in devices]) + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + input_ch = 3 + bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + + self.model_filename_list = [] + + model_archi = nn.DeepFakeArchi(resolution, opts='ud') + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape) + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape) + + self.target_src = tf.placeholder (nn.floatx, bgr_shape) + self.target_dst = tf.placeholder (nn.floatx, bgr_shape) + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape) + self.target_dstm = tf.placeholder (nn.floatx, mask_shape) + + # Initializing model classes + with tf.device (models_opt_device): + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 + + self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') + inter_out_ch = self.inter.get_out_ch() + + self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') + self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') + + self.model_filename_list += [ [self.encoder, 'encoder.npy' ], + [self.inter, 'inter.npy' ], + [self.decoder_src, 'decoder_src.npy'], + [self.decoder_dst, 'decoder_dst.npy'] ] + + if self.is_training: + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() + + # Initialize optimizers + self.src_dst_opt = nn.RMSprop(lr=2e-4, lr_dropout=0.3, name='src_dst_opt') + self.src_dst_opt.initialize_variables(self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu ) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, 4 // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_src_dst_loss_gvs = [] + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + + # process model tensors + gpu_src_code = self.inter(self.encoder(gpu_warped_src)) + gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) + gpu_pred_src_dst_list.append(gpu_pred_src_dst) + + gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_dst_anti_masked = gpu_target_dst*(1.0 - gpu_target_dstm_blur) + + gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src + gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst + + gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src + gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst + + gpu_psd_target_dst_masked = gpu_pred_src_dst*gpu_target_dstm_blur + gpu_psd_target_dst_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_blur) + + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + + gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + + gpu_G_loss = gpu_src_loss + gpu_dst_loss + gpu_src_dst_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights ) ] + + + # Average losses and gradients, and create optimizer update ops + with tf.device (models_opt_device): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + src_loss = nn.average_tensor_list(gpu_src_losses) + dst_loss = nn.average_tensor_list(gpu_dst_losses) + src_dst_loss_gv = nn.average_gv_list (gpu_src_dst_loss_gvs) + src_dst_loss_gv_op = self.src_dst_opt.get_update_op (src_dst_loss_gv) + + # Initializing training and view functions + def src_dst_train(warped_src, target_src, target_srcm, \ + warped_dst, target_dst, target_dstm): + s, d, _ = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + }) + s = np.mean(s) + d = np.mean(d) + return s, d + self.src_dst_train = src_dst_train + + def AE_view(warped_src, warped_dst): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, + self.warped_dst:warped_dst}) + + self.AE_view = AE_view + else: + # Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.inter(self.encoder(self.warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + def AE_merge( warped_dst): + + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + if self.pretrain_just_disabled: + do_init = False + if model == self.inter: + do_init = True + else: + do_init = self.is_first_run() + + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + + if do_init and self.pretrained_model_path is not None: + pretrained_filepath = self.pretrained_model_path / filename + if pretrained_filepath.exists(): + do_init = not model.load_weights(pretrained_filepath) + + if do_init: + model.init_weights() + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() + + cpu_count = min(multiprocessing.cpu_count(), 8) + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution} + ], + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True if self.pretrain else False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution} + ], + generators_count=dst_generators_count ) + ]) + + self.last_samples = None + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def onTrainOneIter(self): + + if self.get_iter() % 3 == 0 and self.last_samples is not None: + ( (warped_src, target_src, target_srcm), \ + (warped_dst, target_dst, target_dstm) ) = self.last_samples + warped_src = target_src + warped_dst = target_dst + else: + samples = self.last_samples = self.generate_next_samples() + ( (warped_src, target_src, target_srcm), \ + (warped_dst, target_dst, target_dstm) ) = samples + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, + warped_dst, target_dst, target_dstm) + + return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm), + (warped_dst, target_dst, target_dstm) ) = samples + + S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size() ) + result = [] + st = [] + for i in range(n_samples): + ar = S[i], SS[i], D[i], DD[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('Quick96', np.concatenate (st, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + st_m.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ] + + return result + + def predictor_func (self, face=None): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x, "NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + import merger + return self.predictor_func, (self.resolution, self.resolution, 3), merger.MergerConfigMasked(face_type=self.face_type, + default_mode = 'overlay', + ) + +Model = QModel diff --git a/models/Model_Quick96/__init__.py b/models/Model_Quick96/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_Quick96/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/Model_RTM/Model.py b/models/Model_RTM/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..7a025436c89de2a628c50e9bc250012824de41b9 --- /dev/null +++ b/models/Model_RTM/Model.py @@ -0,0 +1,794 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * +from core.cv2ex import * + +class RTMModel(ModelBase): + + #override + def on_initialize_options(self): + device_config = nn.getCurrentDeviceConfig() + + lowest_vram = 2 + if len(device_config.devices) != 0: + lowest_vram = device_config.devices.get_worst_device().total_mem_gb + + if lowest_vram >= 4: + suggest_batch_size = 8 + else: + suggest_batch_size = 4 + + yn_str = {True:'y',False:'n'} + min_res = 64 + max_res = 640 + + default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False) + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + + inter_dims = self.load_or_def_option('inter_dims', None) + if inter_dims is None: + inter_dims = self.options['ae_dims'] + default_inter_dims = self.options['inter_dims'] = inter_dims + + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) + default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', True) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + #default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(suggest_batch_size) + self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.') + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .") + resolution = np.clip ( (resolution // 32) * 32, min_res, max_res) + self.options['resolution'] = resolution + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + if self.is_first_run() or ask_override: + if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': + self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") + + self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-64", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 64 ) + self.options['gan_dims'] = gan_dims + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + #self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, uniform_yaw=Y") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + #self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + self.resolution = resolution = self.options['resolution'] + + input_ch=3 + ae_dims = self.ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] + e_dims = self.options['e_dims'] + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + inter_res = self.inter_res = resolution // (2**5) + + use_fp16 = True#self.options['use_fp16'] + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + + x = nn.pixel_norm(x, axes=[1,2,3]) + x = self.dense1(nn.flatten(x)) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + return x + + + class Decoder(nn.ModelBase): + def on_build(self): + self.upscale5 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale4 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale3 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale2 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + self.res5 = ResidualBlock(d_dims*8, kernel_size=3) + self.res4 = ResidualBlock(d_dims*8, kernel_size=3) + self.res3 = ResidualBlock(d_dims*4, kernel_size=3) + self.res2 = ResidualBlock(d_dims*2, kernel_size=3) + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + self.upscalem5 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + + def forward(self, z): + if use_fp16: + z = tf.cast(z, tf.float16) + + x = self.upscale5(z) + x = self.res5(x) + x = self.upscale4(x) + x = self.res4(x) + x = self.upscale3(x) + x = self.res3(x) + x = self.upscale2(x) + x = self.res2(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem5(z) + m = self.upscalem4(m) + m = self.upscalem3(m) + m = self.upscalem2(m) + m = self.upscalem1(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + self.face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + if 'eyes_prio' in self.options: + self.options.pop('eyes_prio') + + eyes_mouth_prio = self.options['eyes_mouth_prio'] + + + gan_power = self.gan_power = self.options['gan_power'] + random_warp = self.options['random_warp'] + random_src_flip = self.random_src_flip + random_dst_flip = self.random_dst_flip + + #pretrain = self.pretrain = self.options['pretrain'] + #if self.pretrain_just_disabled: + # self.set_iter(0) + # self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power'] + # random_warp = False if self.pretrain else self.options['random_warp'] + # random_src_flip = self.random_src_flip if not self.pretrain else True + # random_dst_flip = self.random_dst_flip if not self.pretrain else True + + # if self.pretrain: + # self.options_show_override['gan_power'] = 0.0 + # self.options_show_override['random_warp'] = False + # self.options_show_override['lr_dropout'] = 'n' + # self.options_show_override['uniform_yaw'] = True + + masked_training = self.options['masked_training'] + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + # Initializing model classes + + with tf.device (models_opt_device): + self.encoder = Encoder(name='encoder') + self.inter = Inter(name='inter') + self.decoder_src = Decoder(name='decoder_src') + self.decoder_dst = Decoder(name='decoder_dst') + self.true_face_gan = nn.CodeDiscriminator(inter_dims, code_res=self.inter_res, name='true_face_gan' ) + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter, 'inter.npy'], + [self.decoder_src, 'decoder_src.npy'], + [self.decoder_dst, 'decoder_dst.npy'], + [self.true_face_gan, 'true_face_gan.npy'], + ] + + if self.is_training: + if gan_power != 0: + self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], use_fp16=use_fp16, name="GAN") + self.model_filename_list += [ [self.GAN, 'GAN.npy'] ] + + # Initialize optimizers + lr=5e-5 + lr_dropout = 0.3 + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + + self.all_weights = self.true_face_gan.get_weights() + self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() + + self.src_dst_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.all_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.GAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') + self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.GAN_opt, 'GAN_opt.npy') ] + + #self.BGGAN_opt = nn.AdaBelief(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='BGGAN_opt') + #self.BGGAN_opt.initialize_variables ( self.BGGAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + #self.model_filename_list += [ (self.BGGAN_opt, 'BGGAN_opt.npy') ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_pred_test_list = [] + gpu_pred_src_dst_bg_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gvs = [] + gpu_D_src_dst_loss_gvs = [] + gpu_D_code_loss_gvs = [] + gpu_D_bg_loss_gvs = [] + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + # process model tensors + gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + + gpu_target_srcm_antiblur = 1.0-gpu_target_srcm_blur + gpu_target_dstm_antiblur = 1.0-gpu_target_dstm_blur + #gpu_target_dstm_edge = tf.clip_by_value(gpu_target_dstm_blur*gpu_target_dstm_antiblur*4, 0, 1) + #gpu_target_dst_edge = gpu_target_dst*gpu_target_dstm_edge + #gpu_pred_dst_dst_edge = gpu_pred_dst_dst*gpu_target_dstm_edge + #gpu_pred_src_dst_edge = gpu_pred_src_dst*gpu_target_dstm_edge + #gpu_pred_test_list.append( tf.tile( tf.clip_by_value(gpu_target_dstm_blur*gpu_target_dstm_antiblur*4, 0, 1), (1,3,1,1) ) ) + + gpu_src_code, gpu_dst_code = self.inter(self.encoder(gpu_warped_src)), self.inter(self.encoder(gpu_warped_dst)) + gpu_src_code_d, gpu_dst_code_d = self.true_face_gan(gpu_src_code), self.true_face_gan(gpu_dst_code) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_src_blur = gpu_target_src*gpu_target_srcm_blur + gpu_pred_src_src_blur = gpu_pred_src_src*gpu_target_srcm_blur + gpu_pred_dst_dst_blur = gpu_pred_dst_dst*gpu_target_dstm_blur + + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src, gpu_pred_src_src_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src, gpu_pred_src_src_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src-gpu_pred_src_src_blur), axis=[1,2,3]) + gpu_src_loss += tf.reduce_mean ( 0.1*nn.dssim(gpu_pred_src_src*gpu_target_srcm_antiblur, gpu_target_src*gpu_target_srcm_antiblur, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + + if eyes_mouth_prio: + gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src_blur*gpu_target_srcm_em ), axis=[1,2,3]) + + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_srcm-gpu_pred_src_srcm),axis=[1,2,3] ) + + # sewing loss + #src_dstm_diff = tf.stop_gradient( (1-gpu_pred_src_dstm)*gpu_pred_dst_dstm + (1-gpu_pred_dst_dstm)*gpu_pred_src_dstm ) + src_dstm_diff = nn.gaussian_blur(gpu_pred_dst_dstm, resolution//4) + src_dstm_diff += nn.gaussian_blur(gpu_pred_dst_dstm, resolution//8) + src_dstm_diff += nn.gaussian_blur(gpu_pred_dst_dstm, resolution//16) + src_dstm_diff *= (1-gpu_pred_dst_dstm) + src_dstm_diff = tf.stop_gradient(src_dstm_diff) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim (gpu_target_dst*src_dstm_diff, gpu_pred_src_dst*src_dstm_diff, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_pred_test_list.append( tf.tile( tf.clip_by_value( src_dstm_diff,0,1), (1,3,1,1) ) ) #src_dstm_diff + src_dstm_diff_blur * (1-gpu_pred_dst_dstm) + + gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst, gpu_pred_dst_dst_blur, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst, gpu_pred_dst_dst_blur, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst-gpu_pred_dst_dst_blur ), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean ( 1*nn.dssim(gpu_pred_dst_dst*gpu_target_dstm_antiblur, gpu_target_dst*gpu_target_dstm_antiblur, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + + if eyes_mouth_prio: + gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst_blur*gpu_target_dstm_em ), axis=[1,2,3]) + + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dstm-gpu_pred_dst_dstm),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + #dst_dst_edge_loss = tf.reduce_mean ( 5*nn.dssim (gpu_target_dst,gpu_pred_dst_dst_edge, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + #dst_dst_edge_loss += tf.reduce_mean ( 5*nn.dssim (gpu_target_dst,gpu_pred_dst_dst_edge, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + #dst_dst_edge_loss += tf.reduce_mean (10*tf.square(gpu_target_dst-gpu_pred_dst_dst_edge), axis=[1,2,3]) + + gpu_G_loss = gpu_src_loss + gpu_dst_loss + 1.0*DLossOnes(gpu_src_code_d) + + + #gpu_G_loss += 0.1*(# DLossOnes(gpu_bg_src_dst_d) + DLossOnes(gpu_bg_src_dst_d2) + \ + # DLossOnes(gpu_bg_fg_dst_dst_d) + DLossOnes(gpu_bg_fg_dst_dst_d2) ) / 4.0 + + gpu_D_code_loss = ( DLossOnes(gpu_dst_code_d) + DLossZeros(gpu_src_code_d) ) * 0.5 + #gpu_bg_target_dst_d, gpu_bg_target_dst_d2 + #gpu_D_bg_loss = ( DLossOnes(gpu_bg_dst_dst_d) + DLossOnes(gpu_bg_dst_dst_d2) + \ + # DLossZeros(gpu_bg_fg_dst_dst_d) + DLossZeros(gpu_bg_fg_dst_dst_d2) \ + # ) / 6.0 + #DLossZeros(gpu_bg_src_dst_d) + DLossZeros(gpu_bg_src_dst_d2) \ + + + + #gpu_D_bg_loss_gvs += [ nn.gradients (gpu_D_bg_loss, self.BGGAN.get_weights() ) ] + + # + # Suppress random bright dots from BGGAN + #gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + #gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_dst_anti_masked) + + if gan_power != 0: + gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_blur) + gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_blur) + #gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked_opt) + #gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked_opt) + + gpu_D_src_dst_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ + DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + ) * ( 1.0 / 8) + #DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ + #DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) + + gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.GAN.get_weights() ) ] + + gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) ) * gan_power + #DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) + + if masked_training: + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + #gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() ) ] + gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.true_face_gan.get_weights() ) ] + + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + pred_test = nn.concat(gpu_pred_test_list, 0) + + #pred_dst_dst_bg = nn.concat(gpu_pred_dst_dst_bg_list, 0) + #pred_src_dst_bg = nn.concat(gpu_pred_src_dst_bg_list, 0) + + + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + + src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) + D_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs)) + + #D_bg_loss_gv_op = self.BGGAN_opt.get_update_op (nn.average_gv_list(gpu_D_bg_loss_gvs) ) + + if gan_power != 0: + src_D_src_dst_loss_gv_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) ) + + # Initializing training and view functions + def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d, _, _, = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op, D_loss_gv_op, ],#D_bg_loss_gv_op + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + }) + return s, d + self.src_dst_train = src_dst_train + + if gan_power != 0: + def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.D_src_dst_train = D_src_dst_train + + def AE_view(warped_src, warped_dst, target_dstm): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm, + pred_test], + feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, + self.target_dstm : target_dstm }) + + self.AE_view = AE_view + else: + #Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.inter(self.encoder (self.warped_dst)) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + def AE_merge(warped_dst, morph_value): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + # if self.pretrain_just_disabled: + # do_init = False + # if model == self.inter_src or model == self.inter_dst: + # do_init = True + # else: + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.GAN: + if self.gan_model_changed: + do_init = True + + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + if do_init: + model.init_weights() + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain + + cpu_count = min(multiprocessing.cpu_count(), 8) + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=dst_generators_count ) + ]) + + self.last_src_samples_loss = [] + self.last_dst_samples_loss = [] + #if self.pretrain_just_disabled: + # self.update_sample_for_preview(force_new=True) + + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + + gpu_dst_code = self.inter(self.encoder (warped_dst)) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='AMP', + input_names=['in_face:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=13, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + for i in range(bs): + self.last_src_samples_loss.append ( (src_loss[i], warped_src[i], target_src[i], target_srcm[i], target_srcm_em[i]) ) + self.last_dst_samples_loss.append ( (dst_loss[i], warped_dst[i], target_dst[i], target_dstm[i], target_dstm_em[i]) ) + + if len(self.last_src_samples_loss) >= bs*16: + src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True) + dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True) + + warped_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) + target_src = np.stack( [ x[2] for x in src_samples_loss[:bs] ] ) + target_srcm = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) + target_srcm_em = np.stack( [ x[4] for x in src_samples_loss[:bs] ] ) + + warped_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) + target_dst = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) + target_dstm = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) + target_dstm_em = np.stack( [ x[4] for x in dst_samples_loss[:bs] ] ) + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + self.last_src_samples_loss = [] + self.last_dst_samples_loss = [] + + if self.gan_power != 0: + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM, SD, SDM, TEST = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + \ + self.AE_view (target_src, target_dst, target_dstm) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i], D[i], DD[i], SD[i], TEST[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('RTM', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SDM[i], TEST[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('RTM masked', np.concatenate (st_m, axis=0 )), ] + + return result + + def predictor_func (self, face, morph_value): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + import merger + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = RTMModel diff --git a/models/Model_RTM/__init__.py b/models/Model_RTM/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_RTM/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/Model_SAEHD.zip b/models/Model_SAEHD.zip new file mode 100644 index 0000000000000000000000000000000000000000..3e625c90cdc859574604f9ae011b5ed816d3f822 --- /dev/null +++ b/models/Model_SAEHD.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6af4e64495c1ac27fcf72056771b6e0e8e086a44d2bae76cdd0dc7c2029aefa +size 24081 diff --git "a/models/Model_SAEHD/Model \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" "b/models/Model_SAEHD/Model \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" new file mode 100644 index 0000000000000000000000000000000000000000..b4f03de3e7048cd22dfc97e2dfe739511c49e8b2 --- /dev/null +++ "b/models/Model_SAEHD/Model \342\200\224 \320\272\320\276\320\277\320\270\321\217.py" @@ -0,0 +1,889 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * + +class SAEHDModel(ModelBase): + + #override + def on_initialize_options(self): + device_config = nn.getCurrentDeviceConfig() + + lowest_vram = 2 + if len(device_config.devices) != 0: + lowest_vram = device_config.devices.get_worst_device().total_mem_gb + + if lowest_vram >= 4: + suggest_batch_size = 8 + else: + suggest_batch_size = 4 + + yn_str = {True:'y',False:'n'} + min_res = 64 + max_res = 640 + + #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False) + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud') + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) + default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + + default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True) + + lr_dropout = self.load_or_def_option('lr_dropout', 'n') + lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp + default_lr_dropout = self.options['lr_dropout'] = lr_dropout + + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) + default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) + default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(suggest_batch_size) + #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.') + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.") + resolution = np.clip ( (resolution // 16) * 16, min_res, max_res) + self.options['resolution'] = resolution + + + + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() + + while True: + archi = io.input_str ("AE architecture", default_archi, help_message=\ +""" +'df' keeps more identity-preserved face. +'liae' can fix overly different face shapes. +'-u' increased likeness of the face. +'-d' (experimental) doubling the resolution using the same computation cost. +Examples: df, liae, df-d, df-ud, liae-ud, ... +""").lower() + + archi_split = archi.split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + else: + continue + + if archi_type not in ['df', 'liae']: + continue + + if archi_opts is not None: + if len(archi_opts) == 0: + continue + if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0: + continue + + if 'd' in archi_opts: + self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res) + + break + self.options['archi'] = archi + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + if self.is_first_run() or ask_override: + if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': + self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") + + self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.") + + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + if 'df' in self.options['archi']: + self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 ) + else: + self.options['true_face_power'] = 0.0 + + self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside mask. If you want to use this option with 'whole_face' you have to use XSeg trained mask. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) + self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y") + + if self.options['pretrain'] and self.get_pretraining_data_path() is None: + raise Exception("pretraining_data_path is not defined") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + self.resolution = resolution = self.options['resolution'] + self.face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + if 'eyes_prio' in self.options: + self.options.pop('eyes_prio') + + eyes_mouth_prio = self.options['eyes_mouth_prio'] + + archi_split = self.options['archi'].split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + + self.archi_type = archi_type + + ae_dims = self.options['ae_dims'] + e_dims = self.options['e_dims'] + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + + adabelief = self.options['adabelief'] + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power'] + random_warp = False if self.pretrain else self.options['random_warp'] + random_src_flip = self.random_src_flip if not self.pretrain else True + random_dst_flip = self.random_dst_flip if not self.pretrain else True + blur_out_mask = self.options['blur_out_mask'] + learn_dst_bg = False#True + + if self.pretrain: + self.options_show_override['gan_power'] = 0.0 + self.options_show_override['random_warp'] = False + self.options_show_override['lr_dropout'] = 'n' + self.options_show_override['face_style_power'] = 0.0 + self.options_show_override['bg_style_power'] = 0.0 + self.options_show_override['uniform_yaw'] = True + + masked_training = self.options['masked_training'] + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + input_ch=3 + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + self.warped_src2 = tf.placeholder (nn.floatx, bgr_shape, name='warped_src2') + self.warped_dst2 = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst2') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + # Initializing model classes + model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts) + + with tf.device (models_opt_device): + if 'df' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 + + self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') + inter_out_ch = self.inter.get_out_ch() + + self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') + self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') + + self.model_filename_list += [ [self.encoder, 'encoder.npy' ], + [self.inter, 'inter.npy' ], + [self.decoder_src, 'decoder_src.npy'], + [self.decoder_dst, 'decoder_dst.npy'] ] + + if self.is_training: + if self.options['true_face_power'] != 0: + self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' ) + self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] + + elif 'liae' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 + + self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') + self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') + + inter_out_ch = self.inter_AB.get_out_ch() + inters_out_ch = inter_out_ch*2 + self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_AB, 'inter_AB.npy'], + [self.inter_B , 'inter_B.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + if gan_power != 0: + self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") + self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] + + # Initialize optimizers + lr=5e-5 + lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0 + OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + + if 'df' in archi_type: + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() + elif 'liae' in archi_type: + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + + self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.src_dst_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if self.options['true_face_power'] != 0: + self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_code_opt') + self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] + + if gan_power != 0: + self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') + self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() + self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gvs = [] + gpu_D_code_loss_gvs = [] + gpu_D_src_dst_loss_gvs = [] + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_warped_src2 = self.warped_src2 [batch_slice,:,:,:] + gpu_warped_dst2 = self.warped_dst2 [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + + # process model tensors + if 'df' in archi_type: + gpu_src_code = self.inter(self.encoder(gpu_warped_src)) + gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code)) + + elif 'liae' in archi_type: + gpu_src_code = self.encoder (gpu_warped_src) + gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) + gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) + gpu_dst_code = self.encoder (gpu_warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code)) + + gpu_src2_code = self.encoder (gpu_warped_src2) + gpu_src2_inter_AB_code = self.inter_AB (gpu_src2_code) + gpu_src2_code = tf.concat([gpu_src2_inter_AB_code,gpu_src2_inter_AB_code], nn.conv2d_ch_axis ) + + gpu_dst2_code = self.encoder (gpu_warped_dst2) + gpu_dst2_inter_B_code = self.inter_B (gpu_dst2_code) + gpu_dst2_inter_AB_code = self.inter_AB (gpu_dst2_code) + gpu_dst2_code = tf.concat([gpu_dst2_inter_B_code,gpu_dst2_inter_AB_code], nn.conv2d_ch_axis ) + + gpu_pred_src_src2, gpu_pred_src_srcm2 = self.decoder(gpu_src2_code) + gpu_pred_dst_dst2, gpu_pred_dst_dstm2 = self.decoder(gpu_dst_code) + + + gpu_pred_src_src_list.append(gpu_pred_src_src) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) + gpu_pred_src_dst_list.append(gpu_pred_src_dst) + + gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary + gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_dst_style_masked = gpu_target_dst*gpu_target_dstm_style_blur + gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur + + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur + + gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src + gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst + gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src + gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst + + gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur + gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur + + if resolution < 256: + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) + + if eyes_mouth_prio: + gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3]) + + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + + + + face_style_power = self.options['face_style_power'] / 100.0 + if face_style_power != 0 and not self.pretrain: + gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power) + #gpu_src_loss += nn.style_loss(gpu_psd_target_dst_style_masked, gpu_target_dst_style_masked, gaussian_blur_radius=resolution//16, loss_weight=10000*face_style_power) + + bg_style_power = self.options['bg_style_power'] / 100.0 + if bg_style_power != 0 and not self.pretrain: + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) + + if resolution < 256: + gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + else: + gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) + + if eyes_mouth_prio: + gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3]) + + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + # Invariance loss + #gpu_src_loss += tf.reduce_mean ( tf.square( gpu_pred_src_src - gpu_pred_src_src2 ),axis=[1,2,3] ) + #gpu_dst_loss += tf.reduce_mean ( tf.square( gpu_pred_dst_dst - gpu_pred_dst_dst2 ),axis=[1,2,3] ) + + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + + gpu_G_loss = gpu_src_loss + gpu_dst_loss + + if learn_dst_bg and masked_training and 'liae' in archi_type: + gpu_G_loss += tf.reduce_mean( tf.square(gpu_pred_dst_dst_no_code_grad*gpu_target_dstm_anti_blur-gpu_target_dst_anti_masked),axis=[1,2,3] ) + + def DLoss(labels,logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) + + if self.options['true_face_power'] != 0: + gpu_src_code_d = self.code_discriminator( gpu_src_code ) + gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d) + gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d) + gpu_dst_code_d = self.code_discriminator( gpu_dst_code ) + gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d) + + gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) + + gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \ + DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 + + gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] + + if gan_power != 0: + gpu_pred_src_src_d, \ + gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt) + + gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) + gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) + + gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2) + gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2) + + gpu_target_src_d, \ + gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt) + + gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) + gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2) + + gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ + DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \ + (DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \ + DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5 + + gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights() + + gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \ + DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2)) + + if masked_training: + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )] + + + + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) + + if self.options['true_face_power'] != 0: + D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs)) + + if gan_power != 0: + src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) ) + + + # Initializing training and view functions + def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + })[:2] + return s, d + self.src_dst_train = src_dst_train + + if self.options['true_face_power'] != 0: + def D_train(warped_src, warped_dst): + nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst}) + self.D_train = D_train + + if gan_power != 0: + def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.D_src_dst_train = D_src_dst_train + + + def AE_view(warped_src, warped_dst): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, + self.warped_dst:warped_dst}) + self.AE_view = AE_view + else: + # Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + if 'df' in archi_type: + gpu_dst_code = self.inter(self.encoder(self.warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + elif 'liae' in archi_type: + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + + def AE_merge( warped_dst): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + if self.pretrain_just_disabled: + do_init = False + if 'df' in archi_type: + if model == self.inter: + do_init = True + elif 'liae' in archi_type: + if model == self.inter_AB or model == self.inter_B: + do_init = True + else: + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.D_src: + if self.gan_model_changed: + do_init = True + + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + + if do_init: + model.init_weights() + + + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.1, 0.1], random_flip=random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'warp_rnd_seed_shift': 0, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'warp_rnd_seed_shift': 1, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.1, 0.1], random_flip=random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'warp_rnd_seed_shift': 0, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'warp_rnd_seed_shift': 1, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=dst_generators_count ) + ]) + + if self.pretrain_just_disabled: + self.update_sample_for_preview(force_new=True) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + nn.set_data_format('NCHW') + + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + + + if 'df' in self.archi_type: + gpu_dst_code = self.inter(self.encoder(warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + elif 'liae' in self.archi_type: + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='SAEHD', + input_names=['in_face:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=9, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled: + io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n') + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.options['true_face_power'] != 0 and not self.pretrain: + self.D_train (warped_src, warped_dst) + + if self.gan_power != 0: + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + if self.resolution <= 256: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i], D[i], DD[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] + else: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD src-src', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], DD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD dst-dst', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = D[i]*target_dstm[i], DD[i]*DDM[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + ar = D[i]*target_dstm[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ] + + return result + + def predictor_func (self, face=None): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + import merger + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = SAEHDModel diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..ecfaa7344daa78b99304a37720a1a59c76930423 --- /dev/null +++ b/models/Model_SAEHD/Model.py @@ -0,0 +1,869 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * + +class SAEHDModel(ModelBase): + + #override + def on_initialize_options(self): + device_config = nn.getCurrentDeviceConfig() + + lowest_vram = 2 + if len(device_config.devices) != 0: + lowest_vram = device_config.devices.get_worst_device().total_mem_gb + + if lowest_vram >= 4: + suggest_batch_size = 8 + else: + suggest_batch_size = 4 + + yn_str = {True:'y',False:'n'} + min_res = 64 + max_res = 640 + + #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False) + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_archi = self.options['archi'] = self.load_or_def_option('archi', 'liae-ud') + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) + default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', False) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + + default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True) + + lr_dropout = self.load_or_def_option('lr_dropout', 'n') + lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp + default_lr_dropout = self.options['lr_dropout'] = lr_dropout + + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_random_hsv_power = self.options['random_hsv_power'] = self.load_or_def_option('random_hsv_power', 0.0) + default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) + default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) + default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(suggest_batch_size) + #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.') + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.") + resolution = np.clip ( (resolution // 16) * 16, min_res, max_res) + self.options['resolution'] = resolution + + + + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() + + while True: + archi = io.input_str ("AE architecture", default_archi, help_message=\ +""" +'df' keeps more identity-preserved face. +'liae' can fix overly different face shapes. +'-u' increased likeness of the face. +'-d' (experimental) doubling the resolution using the same computation cost. +Examples: df, liae, df-d, df-ud, liae-ud, ... +""").lower() + + archi_split = archi.split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + else: + continue + + if archi_type not in ['df', 'liae']: + continue + + if archi_opts is not None: + if len(archi_opts) == 0: + continue + if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0: + continue + + if 'd' in archi_opts: + self.options['resolution'] = np.clip ( (self.options['resolution'] // 32) * 32, min_res, max_res) + + break + self.options['archi'] = archi + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + if self.is_first_run() or ask_override: + if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': + self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") + + self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.") + + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['random_hsv_power'] = np.clip ( io.input_number ("Random hue/saturation/light intensity", default_random_hsv_power, add_info="0.0 .. 0.3", help_message="Random hue/saturation/light intensity applied to the src face set only at the input of the neural network. Stabilizes color perturbations during face swapping. Reduces the quality of the color transfer by selecting the closest one in the src faceset. Thus the src faceset must be diverse enough. Typical fine value is 0.05"), 0.0, 0.3 ) + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + if 'df' in self.options['archi']: + self.options['true_face_power'] = np.clip ( io.input_number ("'True face' power.", default_true_face_power, add_info="0.0000 .. 1.0", help_message="Experimental option. Discriminates result face to be more like src face. Higher value - stronger discrimination. Typical value is 0.01 . Comparison - https://i.imgur.com/czScS9q.png"), 0.0, 1.0 ) + else: + self.options['true_face_power'] = 0.0 + + self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside mask. If you want to use this option with 'whole_face' you have to use XSeg trained mask. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) + self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y") + + if self.options['pretrain'] and self.get_pretraining_data_path() is None: + raise Exception("pretraining_data_path is not defined") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + self.resolution = resolution = self.options['resolution'] + self.face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + if 'eyes_prio' in self.options: + self.options.pop('eyes_prio') + + eyes_mouth_prio = self.options['eyes_mouth_prio'] + + archi_split = self.options['archi'].split('-') + + if len(archi_split) == 2: + archi_type, archi_opts = archi_split + elif len(archi_split) == 1: + archi_type, archi_opts = archi_split[0], None + + self.archi_type = archi_type + + ae_dims = self.options['ae_dims'] + e_dims = self.options['e_dims'] + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + + adabelief = self.options['adabelief'] + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power'] + random_warp = False if self.pretrain else self.options['random_warp'] + random_src_flip = self.random_src_flip if not self.pretrain else True + random_dst_flip = self.random_dst_flip if not self.pretrain else True + random_hsv_power = self.options['random_hsv_power'] if not self.pretrain else 0.0 + blur_out_mask = self.options['blur_out_mask'] + + if self.pretrain: + self.options_show_override['lr_dropout'] = 'n' + self.options_show_override['random_warp'] = False + self.options_show_override['gan_power'] = 0.0 + self.options_show_override['random_hsv_power'] = 0.0 + self.options_show_override['face_style_power'] = 0.0 + self.options_show_override['bg_style_power'] = 0.0 + self.options_show_override['uniform_yaw'] = True + + masked_training = self.options['masked_training'] + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + input_ch=3 + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + # Initializing model classes + model_archi = nn.DeepFakeArchi(resolution, use_fp16=use_fp16, opts=archi_opts) + + with tf.device (models_opt_device): + if 'df' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 + + self.inter = model_archi.Inter (in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims, name='inter') + inter_out_ch = self.inter.get_out_ch() + + self.decoder_src = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_src') + self.decoder_dst = model_archi.Decoder(in_ch=inter_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder_dst') + + self.model_filename_list += [ [self.encoder, 'encoder.npy' ], + [self.inter, 'inter.npy' ], + [self.decoder_src, 'decoder_src.npy'], + [self.decoder_dst, 'decoder_dst.npy'] ] + + if self.is_training: + if self.options['true_face_power'] != 0: + self.code_discriminator = nn.CodeDiscriminator(ae_dims, code_res=self.inter.get_out_res(), name='dis' ) + self.model_filename_list += [ [self.code_discriminator, 'code_discriminator.npy'] ] + + elif 'liae' in archi_type: + self.encoder = model_archi.Encoder(in_ch=input_ch, e_ch=e_dims, name='encoder') + encoder_out_ch = self.encoder.get_out_ch()*self.encoder.get_out_res(resolution)**2 + + self.inter_AB = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_AB') + self.inter_B = model_archi.Inter(in_ch=encoder_out_ch, ae_ch=ae_dims, ae_out_ch=ae_dims*2, name='inter_B') + + inter_out_ch = self.inter_AB.get_out_ch() + inters_out_ch = inter_out_ch*2 + self.decoder = model_archi.Decoder(in_ch=inters_out_ch, d_ch=d_dims, d_mask_ch=d_mask_dims, name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_AB, 'inter_AB.npy'], + [self.inter_B , 'inter_B.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + if gan_power != 0: + self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") + self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] + + # Initialize optimizers + lr=5e-5 + if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain: + lr_cos = 500 + lr_dropout = 0.3 + else: + lr_cos = 0 + lr_dropout = 1.0 + OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + + if 'df' in archi_type: + self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights() + self.src_dst_trainable_weights = self.src_dst_saveable_weights + elif 'liae' in archi_type: + self.src_dst_saveable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + if random_warp: + self.src_dst_trainable_weights = self.src_dst_saveable_weights + else: + self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights() + + self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.src_dst_saveable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if self.options['true_face_power'] != 0: + self.D_code_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='D_code_opt') + self.D_code_opt.initialize_variables ( self.code_discriminator.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] + + if gan_power != 0: + self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, lr_cos=lr_cos, clipnorm=clipnorm, name='GAN_opt') + self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() + self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gvs = [] + gpu_D_code_loss_gvs = [] + gpu_D_src_dst_loss_gvs = [] + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + + # process model tensors + if 'df' in archi_type: + gpu_src_code = self.inter(self.encoder(gpu_warped_src)) + gpu_dst_code = self.inter(self.encoder(gpu_warped_dst)) + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder_src(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder_src(tf.stop_gradient(gpu_dst_code)) + + elif 'liae' in archi_type: + gpu_src_code = self.encoder (gpu_warped_src) + gpu_src_inter_AB_code = self.inter_AB (gpu_src_code) + gpu_src_code = tf.concat([gpu_src_inter_AB_code,gpu_src_inter_AB_code], nn.conv2d_ch_axis ) + gpu_dst_code = self.encoder (gpu_warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code)) + + gpu_pred_src_src_list.append(gpu_pred_src_src) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) + gpu_pred_src_dst_list.append(gpu_pred_src_dst) + + gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + + gpu_style_mask_blur = nn.gaussian_blur(gpu_pred_src_dstm*gpu_pred_dst_dstm, max(1, resolution // 32) ) + gpu_style_mask_blur = tf.stop_gradient(tf.clip_by_value(gpu_target_srcm_blur, 0, 1.0)) + gpu_style_mask_anti_blur = 1.0 - gpu_style_mask_blur + + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + + gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src + gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst + gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src + gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst + + if resolution < 256: + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) + + if eyes_mouth_prio: + gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3]) + + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + + face_style_power = self.options['face_style_power'] / 100.0 + if face_style_power != 0 and not self.pretrain: + gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power) + + bg_style_power = self.options['bg_style_power'] / 100.0 + if bg_style_power != 0 and not self.pretrain: + gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_style_mask_anti_blur + gpu_psd_style_anti_masked = gpu_pred_src_dst*gpu_style_mask_anti_blur + + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) + + if resolution < 256: + gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + else: + gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) + + if eyes_mouth_prio: + gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3]) + + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + + gpu_G_loss = gpu_src_loss + gpu_dst_loss + + def DLoss(labels,logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[1,2,3]) + + if self.options['true_face_power'] != 0: + gpu_src_code_d = self.code_discriminator( gpu_src_code ) + gpu_src_code_d_ones = tf.ones_like (gpu_src_code_d) + gpu_src_code_d_zeros = tf.zeros_like(gpu_src_code_d) + gpu_dst_code_d = self.code_discriminator( gpu_dst_code ) + gpu_dst_code_d_ones = tf.ones_like(gpu_dst_code_d) + + gpu_G_loss += self.options['true_face_power']*DLoss(gpu_src_code_d_ones, gpu_src_code_d) + + gpu_D_code_loss = (DLoss(gpu_dst_code_d_ones , gpu_dst_code_d) + \ + DLoss(gpu_src_code_d_zeros, gpu_src_code_d) ) * 0.5 + + gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] + + if gan_power != 0: + gpu_pred_src_src_d, \ + gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt) + + gpu_pred_src_src_d_ones = tf.ones_like (gpu_pred_src_src_d) + gpu_pred_src_src_d_zeros = tf.zeros_like(gpu_pred_src_src_d) + + gpu_pred_src_src_d2_ones = tf.ones_like (gpu_pred_src_src_d2) + gpu_pred_src_src_d2_zeros = tf.zeros_like(gpu_pred_src_src_d2) + + gpu_target_src_d, \ + gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt) + + gpu_target_src_d_ones = tf.ones_like(gpu_target_src_d) + gpu_target_src_d2_ones = tf.ones_like(gpu_target_src_d2) + + gpu_D_src_dst_loss = (DLoss(gpu_target_src_d_ones , gpu_target_src_d) + \ + DLoss(gpu_pred_src_src_d_zeros , gpu_pred_src_src_d) ) * 0.5 + \ + (DLoss(gpu_target_src_d2_ones , gpu_target_src_d2) + \ + DLoss(gpu_pred_src_src_d2_zeros , gpu_pred_src_src_d2) ) * 0.5 + + gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ]#+self.D_src_x2.get_weights() + + gpu_G_loss += gan_power*(DLoss(gpu_pred_src_src_d_ones, gpu_pred_src_src_d) + \ + DLoss(gpu_pred_src_src_d2_ones, gpu_pred_src_src_d2)) + + if masked_training: + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )] + + + + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) + + if self.options['true_face_power'] != 0: + D_loss_gv_op = self.D_code_opt.get_update_op (nn.average_gv_list(gpu_D_code_loss_gvs)) + + if gan_power != 0: + src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) ) + + + # Initializing training and view functions + def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d = nn.tf_sess.run ( [ src_loss, dst_loss, src_dst_loss_gv_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + })[:2] + return s, d + self.src_dst_train = src_dst_train + + if self.options['true_face_power'] != 0: + def D_train(warped_src, warped_dst): + nn.tf_sess.run ([D_loss_gv_op], feed_dict={self.warped_src: warped_src, self.warped_dst: warped_dst}) + self.D_train = D_train + + if gan_power != 0: + def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.D_src_dst_train = D_src_dst_train + + + def AE_view(warped_src, warped_dst): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, + self.warped_dst:warped_dst}) + self.AE_view = AE_view + else: + # Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + if 'df' in archi_type: + gpu_dst_code = self.inter(self.encoder(self.warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + elif 'liae' in archi_type: + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + + def AE_merge( warped_dst): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + if self.pretrain_just_disabled: + do_init = False + if 'df' in archi_type: + if model == self.inter: + do_init = True + elif 'liae' in archi_type: + if model == self.inter_AB or model == self.inter_B: + do_init = True + else: + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.D_src: + if self.gan_model_changed: + do_init = True + + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + + if do_init: + model.init_weights() + + + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'random_hsv_shift_amount' : random_hsv_power, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(scale_range=[-0.15, 0.15], random_flip=random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=dst_generators_count ) + ]) + + if self.pretrain_just_disabled: + self.update_sample_for_preview(force_new=True) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + nn.set_data_format('NCHW') + + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + + + if 'df' in self.archi_type: + gpu_dst_code = self.inter(self.encoder(warped_dst)) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code) + _, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code) + + elif 'liae' in self.archi_type: + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='SAEHD', + input_names=['in_face:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=12, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled: + io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n') + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.options['true_face_power'] != 0 and not self.pretrain: + self.D_train (warped_src, warped_dst) + + if self.gan_power != 0: + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + if self.resolution <= 256: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i], D[i], DD[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] + else: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD src-src', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], DD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD dst-dst', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD pred', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = D[i]*target_dstm[i], DD[i]*DDM[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked dst-dst', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + ar = D[i]*target_dstm[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('SAEHD masked pred', np.concatenate (st_m, axis=0 )), ] + + return result + + def predictor_func (self, face=None): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + import merger + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = SAEHDModel diff --git a/models/Model_SAEHD/__init__.py b/models/Model_SAEHD/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_SAEHD/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/Model_TEST/Model C1.py b/models/Model_TEST/Model C1.py new file mode 100644 index 0000000000000000000000000000000000000000..f15b51c267926c92908d5bc8d92479c457818ec6 --- /dev/null +++ b/models/Model_TEST/Model C1.py @@ -0,0 +1,890 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * + +class TESTModel(ModelBase): + + #override + def on_initialize_options(self): + device_config = nn.getCurrentDeviceConfig() + + lowest_vram = 2 + if len(device_config.devices) != 0: + lowest_vram = device_config.devices.get_worst_device().total_mem_gb + + if lowest_vram >= 4: + suggest_batch_size = 8 + else: + suggest_batch_size = 4 + + yn_str = {True:'y',False:'n'} + min_res = 64 + max_res = 640 + + #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False) + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 96) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_masked_training = self.options['masked_training'] = self.load_or_def_option('masked_training', True) + default_eyes_mouth_prio = self.options['eyes_mouth_prio'] = self.load_or_def_option('eyes_mouth_prio', True) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', True) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', True) + + default_adabelief = self.options['adabelief'] = self.load_or_def_option('adabelief', True) + + lr_dropout = self.load_or_def_option('lr_dropout', 'n') + lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp + default_lr_dropout = self.options['lr_dropout'] = lr_dropout + + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_face_style_power = self.options['face_style_power'] = self.load_or_def_option('face_style_power', 0.0) + default_bg_style_power = self.options['bg_style_power'] = self.load_or_def_option('bg_style_power', 0.0) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(suggest_batch_size) + #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.') + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 16 and 32 for -d archi.") + resolution = np.clip ( (resolution // 16) * 16, min_res, max_res) + self.options['resolution'] = resolution + + + + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Half face has better resolution, but covers less area of cheeks. Mid face is 30% wider than half face. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', 32) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + if self.is_first_run() or ask_override: + if self.options['face_type'] == 'wf' or self.options['face_type'] == 'head': + self.options['masked_training'] = io.input_bool ("Masked training", default_masked_training, help_message="This option is available only for 'whole_face' or 'head' type. Masked training clips training area to full_face mask or XSeg mask, thus network will train the faces properly.") + + self.options['eyes_mouth_prio'] = io.input_bool ("Eyes and mouth priority", default_eyes_mouth_prio, help_message='Helps to fix eye problems during training like "alien eyes" and wrong eyes direction. Also makes the detail of the teeth higher.') + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['adabelief'] = io.input_bool ("Use AdaBelief optimizer?", default_adabelief, help_message="Use AdaBelief optimizer. It requires more VRAM, but the accuracy and the generalization of the model is higher.") + + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + self.options['face_style_power'] = np.clip ( io.input_number("Face style power", default_face_style_power, add_info="0.0..100.0", help_message="Learn the color of the predicted face to be the same as dst inside mask. If you want to use this option with 'whole_face' you have to use XSeg trained mask. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.001 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 ) + self.options['bg_style_power'] = np.clip ( io.input_number("Background style power", default_bg_style_power, add_info="0.0..100.0", help_message="Learn the area outside mask of the predicted face to be the same as dst. If you want to use this option with 'whole_face' you have to use XSeg trained mask. For whole_face you have to use XSeg trained mask. This can make face more like dst. Enabling this option increases the chance of model collapse. Typical value is 2.0"), 0.0, 100.0 ) + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. Try all modes to find the best.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain, help_message="Pretrain the model with large amount of various faces. After that, model can be used to train the fakes more quickly. Forces random_warp=N, random_flips=Y, gan_power=0.0, lr_dropout=N, styles=0.0, uniform_yaw=Y") + + if self.options['pretrain'] and self.get_pretraining_data_path() is None: + raise Exception("pretraining_data_path is not defined") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" if len(devices) != 0 and not self.is_debug() else "NHWC" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + self.resolution = resolution = self.options['resolution'] + inter_res = self.inter_res = resolution // 32 + self.face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + if 'eyes_prio' in self.options: + self.options.pop('eyes_prio') + + eyes_mouth_prio = self.options['eyes_mouth_prio'] + + ae_dims = self.options['ae_dims'] + inter_dims = ae_dims + e_dims = self.options['e_dims'] + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + + adabelief = self.options['adabelief'] + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + self.gan_power = gan_power = 0.0 if self.pretrain else self.options['gan_power'] + random_warp = False if self.pretrain else self.options['random_warp'] + random_src_flip = self.random_src_flip if not self.pretrain else True + random_dst_flip = self.random_dst_flip if not self.pretrain else True + blur_out_mask = self.options['blur_out_mask'] + learn_dst_bg = False#True + + if self.pretrain: + self.options_show_override['gan_power'] = 0.0 + self.options_show_override['random_warp'] = False + self.options_show_override['lr_dropout'] = 'n' + self.options_show_override['face_style_power'] = 0.0 + self.options_show_override['bg_style_power'] = 0.0 + self.options_show_override['uniform_yaw'] = True + + masked_training = self.options['masked_training'] + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + input_ch=3 + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + # Initializing model classes + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + x = nn.pixel_norm(nn.flatten(x), axes=-1) + x = self.dense1(x) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + self.res0 = ResidualBlock(inter_dims) + self.res1 = ResidualBlock(inter_dims) + self.res2 = ResidualBlock(inter_dims) + self.res3 = ResidualBlock(inter_dims) + self.res4 = ResidualBlock(inter_dims) + self.res5 = ResidualBlock(inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + x = self.res0(x) + x = self.res1(x) + x = self.res2(x) + x = self.res3(x) + x = self.res4(x) + x = self.res5(x) + return x + + class Decoder(nn.ModelBase): + def on_build(self ): + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + + self.res0 = ResidualBlock(d_dims*8, kernel_size=3) + self.res1 = ResidualBlock(d_dims*8, kernel_size=3) + self.res2 = ResidualBlock(d_dims*4, kernel_size=3) + self.res3 = ResidualBlock(d_dims*2, kernel_size=3) + + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + def forward(self, z): + if use_fp16: + z = tf.cast(z, tf.float16) + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = self.upscale3(x) + x = self.res3(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + m = self.upscalem3(m) + m = self.upscalem4(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + with tf.device (models_opt_device): + + self.encoder = Encoder(name='encoder') + self.inter_src = Inter(name='inter_src') + self.inter_dst = Inter(name='inter_dst') + self.decoder = Decoder(name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_src, 'inter_src.npy'], + [self.inter_dst, 'inter_dst.npy'], + [self.decoder, 'decoder.npy'] ] + + if self.is_training: + if gan_power != 0: + self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") + self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] + + # Initialize optimizers + lr=5e-5 + lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] and not self.pretrain else 1.0 + OptimizerClass = nn.AdaBelief if adabelief else nn.RMSprop + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + + self.all_trainable_weights = self.encoder.get_weights() + self.inter_src.get_weights() + self.inter_dst.get_weights() + self.decoder.get_weights() + #if random_warp: + # self.src_dst_trainable_weights += self.inter_B.get_weights() + self.inter_AB.get_weights() + + self.src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.all_trainable_weights, vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu') + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') + self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() + self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gvs = [] + gpu_src_loss_gvs = [] + gpu_dst_loss_gvs = [] + + gpu_D_code_loss_gvs = [] + gpu_D_src_dst_loss_gvs = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + if blur_out_mask: + #gpu_target_src = gpu_target_src*gpu_target_srcm_blur + nn.gaussian_blur(gpu_target_src, resolution // 32)*gpu_target_srcm_anti_blur + #gpu_target_dst = gpu_target_dst*gpu_target_dstm_blur + nn.gaussian_blur(gpu_target_dst, resolution // 32)*gpu_target_dstm_anti_blur + bg_blur_div = 128 + + gpu_target_src = gpu_target_src*gpu_target_srcm + \ + tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, resolution / bg_blur_div), + (1-nn.gaussian_blur(gpu_target_srcm, resolution / bg_blur_div) ) ) * gpu_target_srcm_anti + + gpu_target_dst = gpu_target_dst*gpu_target_dstm + \ + tf.math.divide_no_nan(nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, resolution / bg_blur_div), + (1-nn.gaussian_blur(gpu_target_dstm, resolution / bg_blur_div)) ) * gpu_target_dstm_anti + + # process model tensors + + gpu_src_code = self.encoder (gpu_warped_src) + gpu_src_code = self.inter_src (gpu_src_code) + gpu_dst_code = self.encoder (gpu_warped_dst) + gpu_dst_code, gpu_src_dst_code = self.inter_dst (gpu_dst_code), self.inter_src (gpu_dst_code) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + + #gpu_pred_src_src_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_code)) + #gpu_pred_dst_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_dst_code)) + #gpu_pred_src_dst_no_code_grad, _ = self.decoder(tf.stop_gradient(gpu_src_dst_code)) + + gpu_pred_src_src_list.append(gpu_pred_src_src) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst) + gpu_pred_src_dst_list.append(gpu_pred_src_dst) + + gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_blur = nn.gaussian_blur(gpu_target_srcm, max(1, resolution // 32) ) + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_blur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + + gpu_target_dstm_blur = nn.gaussian_blur(gpu_target_dstm, max(1, resolution // 32) ) + gpu_target_dstm_style_blur = gpu_target_dstm_blur #default style mask is 0.5 on boundary + gpu_target_dstm_style_anti_blur = 1.0 - gpu_target_dstm_style_blur + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_blur, 0, 0.5) * 2 + + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_dst_style_anti_masked = gpu_target_dst*gpu_target_dstm_style_anti_blur + + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + + gpu_target_src_masked_opt = gpu_target_src*gpu_target_srcm_blur if masked_training else gpu_target_src + gpu_target_dst_masked_opt = gpu_target_dst_masked if masked_training else gpu_target_dst + gpu_pred_src_src_masked_opt = gpu_pred_src_src*gpu_target_srcm_blur if masked_training else gpu_pred_src_src + gpu_pred_dst_dst_masked_opt = gpu_pred_dst_dst*gpu_target_dstm_blur if masked_training else gpu_pred_dst_dst + gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*gpu_target_dstm_style_anti_blur + + # Structural loss + + gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + + # Pixel loss + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked_opt - gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) + + # Eyes+mouth prio loss + if eyes_mouth_prio: + gpu_src_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_src*gpu_target_srcm_em - gpu_pred_src_src*gpu_target_srcm_em ), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean ( 300*tf.abs ( gpu_target_dst*gpu_target_dstm_em - gpu_pred_dst_dst*gpu_target_dstm_em ), axis=[1,2,3]) + + # Mask loss + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + #gpu_src_loss += nn.style_loss(gpu_pred_src_src_no_code_grad*tf.stop_gradient(gpu_pred_src_srcm), gpu_target_src*gpu_target_srcm, gaussian_blur_radius=resolution//8, loss_weight=10000*0.05) + #gpu_dst_loss += nn.style_loss(gpu_pred_dst_dst_no_code_grad*tf.stop_gradient(gpu_pred_dst_dstm), gpu_target_dst*gpu_target_dstm, gaussian_blur_radius=resolution//8, loss_weight=10000*0.05) + + # face/bg style loss + # face_style_power = self.options['face_style_power'] / 100.0 + # if face_style_power != 0 and not self.pretrain: + # gpu_src_loss += nn.style_loss(gpu_pred_src_dst_no_code_grad*tf.stop_gradient(gpu_pred_src_dstm), tf.stop_gradient(gpu_pred_dst_dst*gpu_pred_dst_dstm), gaussian_blur_radius=resolution//8, loss_weight=10000*face_style_power) + + # bg_style_power = self.options['bg_style_power'] / 100.0 + # if bg_style_power != 0 and not self.pretrain: + # gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + # gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + + #gpu_G_loss = gpu_src_loss + gpu_dst_loss + #gpu_G_loss_gvs += [ nn.gradients ( gpu_G_loss, self.src_dst_trainable_weights )] + + gpu_src_loss_gvs += [ nn.gradients ( gpu_src_loss, self.encoder.get_weights() + self.inter_src.get_weights()+ self.decoder.get_weights() )] + gpu_dst_loss_gvs += [ nn.gradients ( gpu_dst_loss, self.encoder.get_weights() + self.inter_dst.get_weights()+ self.decoder.get_weights() )] + + # # residual dst background transfer loss + # if learn_dst_bg and 'liae' in archi_type: + # psd_bg_mask = 1.0 - tf.where( tf.greater_equal( gpu_pred_src_dstm + gpu_pred_dst_dstm, tf.constant([0.1], nn.floatx) ), tf.ones_like(gpu_pred_src_dstm), tf.zeros_like(gpu_pred_src_dstm) ) + # psd_bg_mask = tf.clip_by_value( (nn.gaussian_blur(psd_bg_mask, max(1, resolution // 16) ) - 0.5), 0, 0.5) * 2.0 + # psd_bg_mask = tf.stop_gradient(psd_bg_mask) + + # gpu_G_loss += tf.reduce_mean( 10*tf.square(gpu_pred_dst_dst_no_code_grad*psd_bg_mask-gpu_target_dst*psd_bg_mask ),axis=[1,2,3] ) + + # if self.options['true_face_power'] != 0: + # gpu_src_code_d = self.code_discriminator( gpu_src_code ) + # gpu_dst_code_d = self.code_discriminator( gpu_dst_code ) + # gpu_G_loss += self.options['true_face_power']*DLossOnes(gpu_src_code_d) + # gpu_D_code_loss = (DLossOnes(gpu_dst_code_d) + DLossZeros(gpu_src_code_d))*0.5 + # gpu_D_code_loss_gvs += [ nn.gradients (gpu_D_code_loss, self.code_discriminator.get_weights() ) ] + + # if gan_power != 0: + # gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.D_src(gpu_pred_src_src_masked_opt) + # gpu_target_src_d, gpu_target_src_d2 = self.D_src(gpu_target_src_masked_opt) + + # gpu_D_src_dst_loss = (DLossOnes(gpu_target_src_d) + DLossZeros(gpu_pred_src_src_d) ) * 0.5 + \ + # (DLossOnes(gpu_target_src_d2) + DLossZeros(gpu_pred_src_src_d2) ) * 0.5 + + # gpu_D_src_dst_loss_gvs += [ nn.gradients (gpu_D_src_dst_loss, self.D_src.get_weights() ) ] + + # gpu_G_loss += gan_power*(DLossOnes(gpu_pred_src_src_d) + \ + # DLossOnes(gpu_pred_src_src_d2)) + + # if masked_training: + # # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + # gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + # gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + + + + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + #src_dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gvs)) + + src_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_src_loss_gvs)) + dst_loss_gv_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_dst_loss_gvs)) + + if gan_power != 0: + src_D_src_dst_loss_gv_op = self.D_src_dst_opt.get_update_op (nn.average_gv_list(gpu_D_src_dst_loss_gvs) ) + + + # Initializing training and view functions + def src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, = nn.tf_sess.run ( [ src_loss, src_loss_gv_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em + })[:1] + d, = nn.tf_sess.run ( [ dst_loss, dst_loss_gv_op], + feed_dict={self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em + })[:1] + return s, d + self.src_dst_train = src_dst_train + + + if gan_power != 0: + def D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.D_src_dst_train = D_src_dst_train + + + def AE_view(warped_src, warped_dst): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, + self.warped_dst:warped_dst}) + self.AE_view = AE_view + else: + # Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + + def AE_merge( warped_dst): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst}) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + if self.pretrain_just_disabled: + do_init = False + if model == self.inter_src or model == self.inter_dst: + do_init = True + else: + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.D_src: + if self.gan_model_changed: + do_init = True + + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + + if do_init: + model.init_weights() + + + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None + + cpu_count = min(multiprocessing.cpu_count(), 8) + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'] or self.pretrain, + generators_count=dst_generators_count ) + ]) + + self.last_src_samples_loss = [] + self.last_dst_samples_loss = [] + + if self.pretrain_just_disabled: + self.update_sample_for_preview(force_new=True) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + nn.set_data_format('NCHW') + + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + + + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_B_code = self.inter_B (gpu_dst_code) + gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code) + gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='TEST', + input_names=['in_face:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=13, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled: + io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n') + + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + # for i in range(bs): + # self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) ) + # self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) ) + + # if len(self.last_src_samples_loss) >= bs*16: + # src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True) + # dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True) + + # target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] ) + # target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] ) + # target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] ) + + # target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] ) + # target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] ) + # target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] ) + + # src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em) + # self.last_src_samples_loss = [] + # self.last_dst_samples_loss = [] + + if self.gan_power != 0: + self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] + DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + if self.resolution <= 256: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i], D[i], DD[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('TEST masked', np.concatenate (st_m, axis=0 )), ] + else: + result = [] + + st = [] + for i in range(n_samples): + ar = S[i], SS[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST src-src', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], DD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST dst-dst', np.concatenate (st, axis=0 )), ] + + st = [] + for i in range(n_samples): + ar = D[i], SD[i] + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST pred', np.concatenate (st, axis=0 )), ] + + + st_m = [] + for i in range(n_samples): + ar = S[i]*target_srcm[i], SS[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST masked src-src', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + ar = D[i]*target_dstm[i], DD[i]*DDM[i] + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST masked dst-dst', np.concatenate (st_m, axis=0 )), ] + + st_m = [] + for i in range(n_samples): + SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] + ar = D[i]*target_dstm[i], SD[i]*SD_mask + st_m.append ( np.concatenate ( ar, axis=1) ) + result += [ ('TEST masked pred', np.concatenate (st_m, axis=0 )), ] + + return result + + def predictor_func (self, face=None): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + import merger + return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = TESTModel diff --git a/models/Model_TEST/Model.py b/models/Model_TEST/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae8eb6175e0e60c9c6ba7d97f70207c2bd398f08 --- /dev/null +++ b/models/Model_TEST/Model.py @@ -0,0 +1,724 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType +from models import ModelBase +from samplelib import * +from core.cv2ex import * + +class AMPModel(ModelBase): + + #override + def on_initialize_options(self): + default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224) + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True) + + default_ae_dims = self.options['ae_dims'] = self.load_or_def_option('ae_dims', 256) + default_inter_dims = self.options['inter_dims'] = self.load_or_def_option('inter_dims', 1024) + + default_e_dims = self.options['e_dims'] = self.load_or_def_option('e_dims', 64) + default_d_dims = self.options['d_dims'] = self.options.get('d_dims', None) + default_d_mask_dims = self.options['d_mask_dims'] = self.options.get('d_mask_dims', None) + default_morph_factor = self.options['morph_factor'] = self.options.get('morph_factor', 0.5) + default_uniform_yaw = self.options['uniform_yaw'] = self.load_or_def_option('uniform_yaw', False) + default_blur_out_mask = self.options['blur_out_mask'] = self.load_or_def_option('blur_out_mask', False) + default_lr_dropout = self.options['lr_dropout'] = self.load_or_def_option('lr_dropout', 'n') + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) + default_ct_mode = self.options['ct_mode'] = self.load_or_def_option('ct_mode', 'none') + default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False) + + ask_override = self.ask_override() + if self.is_first_run() or ask_override: + self.ask_autobackup_hour() + self.ask_write_preview_history() + self.ask_target_iter() + self.ask_random_src_flip() + self.ask_random_dst_flip() + self.ask_batch_size(8) + + if self.is_first_run(): + resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .") + resolution = np.clip ( (resolution // 32) * 32, 64, 640) + self.options['resolution'] = resolution + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['f','wf','head'], help_message="whole face / head").lower() + + + default_d_dims = self.options['d_dims'] = self.load_or_def_option('d_dims', 64) + + default_d_mask_dims = default_d_dims // 3 + default_d_mask_dims += default_d_mask_dims % 2 + default_d_mask_dims = self.options['d_mask_dims'] = self.load_or_def_option('d_mask_dims', default_d_mask_dims) + + if self.is_first_run(): + self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dimensions", default_ae_dims, add_info="32-1024", help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 ) + self.options['inter_dims'] = np.clip ( io.input_int("Inter dimensions", default_inter_dims, add_info="32-2048", help_message="Should be equal or more than AutoEncoder dimensions. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 2048 ) + + e_dims = np.clip ( io.input_int("Encoder dimensions", default_e_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['e_dims'] = e_dims + e_dims % 2 + + d_dims = np.clip ( io.input_int("Decoder dimensions", default_d_dims, add_info="16-256", help_message="More dims help to recognize more facial features and achieve sharper result, but require more VRAM. You can fine-tune model size to fit your GPU." ), 16, 256 ) + self.options['d_dims'] = d_dims + d_dims % 2 + + d_mask_dims = np.clip ( io.input_int("Decoder mask dimensions", default_d_mask_dims, add_info="16-256", help_message="Typical mask dimensions = decoder dimensions / 3. If you manually cut out obstacles from the dst mask, you can increase this parameter to achieve better quality." ), 16, 256 ) + self.options['d_mask_dims'] = d_mask_dims + d_mask_dims % 2 + + morph_factor = np.clip ( io.input_number ("Morph factor.", default_morph_factor, add_info="0.1 .. 0.5", help_message="Typical fine value is 0.5"), 0.1, 0.5 ) + self.options['morph_factor'] = morph_factor + + if self.is_first_run() or ask_override: + self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') + self.options['blur_out_mask'] = io.input_bool ("Blur out mask", default_blur_out_mask, help_message='Blurs nearby area outside of applied face mask of training samples. The result is the background near the face is smoothed and less noticeable on swapped face. The exact xseg mask in src and dst faceset is required.') + self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + + default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) + default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) + default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) + + if self.is_first_run() or ask_override: + self.options['models_opt_on_gpu'] = io.input_bool ("Place models and optimizer on GPU", default_models_opt_on_gpu, help_message="When you train on one GPU, by default model and optimizer weights are placed on GPU to accelerate the process. You can place they on CPU to free up extra VRAM, thus set bigger dimensions.") + + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") + + self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 5.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 5.0 ) + + if self.options['gan_power'] != 0.0: + gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) + self.options['gan_patch_size'] = gan_patch_size + + gan_dims = np.clip ( io.input_int("GAN dimensions", default_gan_dims, add_info="4-512", help_message="The dimensions of the GAN network. The higher dimensions, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is 16." ), 4, 512 ) + self.options['gan_dims'] = gan_dims + + self.options['ct_mode'] = io.input_str (f"Color transfer for src faceset", default_ct_mode, ['none','rct','lct','mkl','idt','sot'], help_message="Change color distribution of src samples close to dst samples. If src faceset is deverse enough, then lct mode is fine in most cases.") + self.options['clipgrad'] = io.input_bool ("Enable gradient clipping", default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.") + + self.gan_model_changed = (default_gan_patch_size != self.options['gan_patch_size']) or (default_gan_dims != self.options['gan_dims']) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + self.model_data_format = "NCHW" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + input_ch=3 + resolution = self.resolution = self.options['resolution'] + e_dims = self.options['e_dims'] + ae_dims = self.options['ae_dims'] + inter_dims = self.inter_dims = self.options['inter_dims'] + inter_res = self.inter_res = resolution // 32 + d_dims = self.options['d_dims'] + d_mask_dims = self.options['d_mask_dims'] + face_type = self.face_type = {'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + morph_factor = self.options['morph_factor'] + gan_power = self.gan_power = self.options['gan_power'] + random_warp = self.options['random_warp'] + + blur_out_mask = self.options['blur_out_mask'] + + ct_mode = self.options['ct_mode'] + if ct_mode == 'none': + ct_mode = None + + use_fp16 = False + if self.is_exporting: + use_fp16 = io.input_bool ("Export quantized?", False, help_message='Makes the exported model faster. If you have problems, disable this option.') + + conv_dtype = tf.float16 if use_fp16 else tf.float32 + + class Downscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + return tf.nn.leaky_relu(self.conv1(x), 0.1) + + class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, x): + x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2) + return x + + class ResidualBlock(nn.ModelBase): + def on_build(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype) + + def forward(self, inp): + x = self.conv1(inp) + x = tf.nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = tf.nn.leaky_relu(inp+x, 0.2) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.down1 = Downscale(input_ch, e_dims, kernel_size=5) + self.res1 = ResidualBlock(e_dims) + self.down2 = Downscale(e_dims, e_dims*2, kernel_size=5) + self.down3 = Downscale(e_dims*2, e_dims*4, kernel_size=5) + self.down4 = Downscale(e_dims*4, e_dims*8, kernel_size=5) + self.down5 = Downscale(e_dims*8, e_dims*8, kernel_size=5) + self.res5 = ResidualBlock(e_dims*8) + self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims ) + + def forward(self, x): + if use_fp16: + x = tf.cast(x, tf.float16) + x = self.down1(x) + x = self.res1(x) + x = self.down2(x) + x = self.down3(x) + x = self.down4(x) + x = self.down5(x) + x = self.res5(x) + if use_fp16: + x = tf.cast(x, tf.float32) + x = nn.pixel_norm(nn.flatten(x), axes=-1) + x = self.dense1(x) + return x + + + class Inter(nn.ModelBase): + def on_build(self): + self.dense2 = nn.Dense(ae_dims, inter_res * inter_res * inter_dims) + + def forward(self, inp): + x = inp + x = self.dense2(x) + x = nn.reshape_4D (x, inter_res, inter_res, inter_dims) + return x + + + class Decoder(nn.ModelBase): + def on_build(self ): + self.upscale0 = Upscale(inter_dims, d_dims*8, kernel_size=3) + self.upscale1 = Upscale(d_dims*8, d_dims*8, kernel_size=3) + self.upscale2 = Upscale(d_dims*8, d_dims*4, kernel_size=3) + self.upscale3 = Upscale(d_dims*4, d_dims*2, kernel_size=3) + + self.res0 = ResidualBlock(d_dims*8, kernel_size=3) + self.res1 = ResidualBlock(d_dims*8, kernel_size=3) + self.res2 = ResidualBlock(d_dims*4, kernel_size=3) + self.res3 = ResidualBlock(d_dims*2, kernel_size=3) + + self.upscalem0 = Upscale(inter_dims, d_mask_dims*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_dims*8, d_mask_dims*8, kernel_size=3) + self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3) + self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3) + self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype) + + self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype) + self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype) + + def forward(self, z): + if use_fp16: + z = tf.cast(z, tf.float16) + + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = self.upscale3(x) + x = self.res3(x) + + x = tf.nn.sigmoid( nn.depth_to_space(tf.concat( (self.out_conv(x), + self.out_conv1(x), + self.out_conv2(x), + self.out_conv3(x)), nn.conv2d_ch_axis), 2) ) + m = self.upscalem0(z) + m = self.upscalem1(m) + m = self.upscalem2(m) + m = self.upscalem3(m) + m = self.upscalem4(m) + m = tf.nn.sigmoid(self.out_convm(m)) + + if use_fp16: + x = tf.cast(x, tf.float32) + m = tf.cast(m, tf.float32) + return x, m + + models_opt_on_gpu = False if len(devices) == 0 else self.options['models_opt_on_gpu'] + models_opt_device = nn.tf_default_device_name if models_opt_on_gpu and self.is_training else '/CPU:0' + optimizer_vars_on_cpu = models_opt_device=='/CPU:0' + + bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch) + mask_shape = nn.get4Dshape(resolution,resolution,1) + self.model_filename_list = [] + + with tf.device ('/CPU:0'): + #Place holders on CPU + self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src') + self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst') + + self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src') + self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst') + + self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm') + self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em') + self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm') + self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em') + + self.morph_value_t = tf.placeholder (nn.floatx, (1,), name='morph_value_t') + + # Initializing model classes + with tf.device (models_opt_device): + self.encoder = Encoder(name='encoder') + self.inter_src = Inter(name='inter_src') + self.inter_dst = Inter(name='inter_dst') + self.decoder = Decoder(name='decoder') + + self.model_filename_list += [ [self.encoder, 'encoder.npy'], + [self.inter_src, 'inter_src.npy'], + [self.inter_dst , 'inter_dst.npy'], + [self.decoder , 'decoder.npy'] ] + + if self.is_training: + # Initialize optimizers + clipnorm = 1.0 if self.options['clipgrad'] else 0.0 + lr_dropout = 0.3 if self.options['lr_dropout'] in ['y','cpu'] else 1.0 + + self.G_weights = self.encoder.get_weights() + self.decoder.get_weights() + + #if random_warp: + # self.G_weights += self.inter_src.get_weights() + self.inter_dst.get_weights() + + self.src_dst_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt') + self.src_dst_opt.initialize_variables (self.G_weights, vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ] + + if gan_power != 0: + self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN") + self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') + self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu) + self.model_filename_list += [ [self.GAN, 'GAN.npy'], + [self.GAN_opt, 'GAN_opt.npy'] ] + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_src_src_list = [] + gpu_pred_dst_dst_list = [] + gpu_pred_src_dst_list = [] + gpu_pred_src_srcm_list = [] + gpu_pred_dst_dstm_list = [] + gpu_pred_src_dstm_list = [] + + gpu_src_losses = [] + gpu_dst_losses = [] + gpu_G_loss_gradients = [] + gpu_GAN_loss_gradients = [] + + def DLossOnes(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits), axis=[1,2,3]) + + def DLossZeros(logits): + return tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits), axis=[1,2,3]) + + for gpu_id in range(gpu_count): + with tf.device( f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_warped_src = self.warped_src [batch_slice,:,:,:] + gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] + gpu_target_src = self.target_src [batch_slice,:,:,:] + gpu_target_dst = self.target_dst [batch_slice,:,:,:] + gpu_target_srcm = self.target_srcm[batch_slice,:,:,:] + gpu_target_srcm_em = self.target_srcm_em[batch_slice,:,:,:] + gpu_target_dstm = self.target_dstm[batch_slice,:,:,:] + gpu_target_dstm_em = self.target_dstm_em[batch_slice,:,:,:] + + # process model tensors + gpu_src_code = self.encoder (gpu_warped_src) + gpu_dst_code = self.encoder (gpu_warped_dst) + + gpu_src_inter_src_code, gpu_src_inter_dst_code = self.inter_src (gpu_src_code), self.inter_dst (gpu_src_code) + gpu_dst_inter_src_code, gpu_dst_inter_dst_code = self.inter_src (gpu_dst_code), self.inter_dst (gpu_dst_code) + + inter_dims_bin = int(inter_dims*morph_factor) + with tf.device(f'/CPU:0'): + inter_rnd_binomial = tf.stack([tf.random.shuffle(tf.concat([tf.tile(tf.constant([1], tf.float32), ( inter_dims_bin, )), + tf.tile(tf.constant([0], tf.float32), ( inter_dims-inter_dims_bin, ))], 0 )) for _ in range(bs_per_gpu)], 0) + + inter_rnd_binomial = tf.stop_gradient(inter_rnd_binomial[...,None,None]) + + gpu_src_code = gpu_src_inter_src_code * inter_rnd_binomial + gpu_src_inter_dst_code * (1-inter_rnd_binomial) + gpu_dst_code = gpu_dst_inter_dst_code + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) + gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + + gpu_pred_src_src_list.append(gpu_pred_src_src), gpu_pred_src_srcm_list.append(gpu_pred_src_srcm) + gpu_pred_dst_dst_list.append(gpu_pred_dst_dst), gpu_pred_dst_dstm_list.append(gpu_pred_dst_dstm) + gpu_pred_src_dst_list.append(gpu_pred_src_dst), gpu_pred_src_dstm_list.append(gpu_pred_src_dstm) + + gpu_target_srcm_anti = 1-gpu_target_srcm + gpu_target_dstm_anti = 1-gpu_target_dstm + + gpu_target_srcm_gblur = nn.gaussian_blur(gpu_target_srcm, resolution // 32) + gpu_target_dstm_gblur = nn.gaussian_blur(gpu_target_dstm, resolution // 32) + + gpu_target_srcm_blur = tf.clip_by_value(gpu_target_srcm_gblur, 0, 0.5) * 2 + gpu_target_dstm_blur = tf.clip_by_value(gpu_target_dstm_gblur, 0, 0.5) * 2 + gpu_target_srcm_anti_blur = 1.0-gpu_target_srcm_blur + gpu_target_dstm_anti_blur = 1.0-gpu_target_dstm_blur + + if blur_out_mask: + sigma = resolution / 128 + + x = nn.gaussian_blur(gpu_target_src*gpu_target_srcm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_srcm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_src = gpu_target_src*gpu_target_srcm + (x/y)*gpu_target_srcm_anti + + x = nn.gaussian_blur(gpu_target_dst*gpu_target_dstm_anti, sigma) + y = 1-nn.gaussian_blur(gpu_target_dstm, sigma) + y = tf.where(tf.equal(y, 0), tf.ones_like(y), y) + gpu_target_dst = gpu_target_dst*gpu_target_dstm + (x/y)*gpu_target_dstm_anti + + gpu_target_src_masked = gpu_target_src*gpu_target_srcm_blur + gpu_target_dst_masked = gpu_target_dst*gpu_target_dstm_blur + gpu_target_src_anti_masked = gpu_target_src*gpu_target_srcm_anti_blur + gpu_target_dst_anti_masked = gpu_target_dst*gpu_target_dstm_anti_blur + + gpu_pred_src_src_masked = gpu_pred_src_src*gpu_target_srcm_blur + gpu_pred_dst_dst_masked = gpu_pred_dst_dst*gpu_target_dstm_blur + gpu_pred_src_src_anti_masked = gpu_pred_src_src*gpu_target_srcm_anti_blur + gpu_pred_dst_dst_anti_masked = gpu_pred_dst_dst*gpu_target_dstm_anti_blur + + # Structural loss + gpu_src_loss = tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean (5*nn.dssim(gpu_target_src_masked, gpu_pred_src_src_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + gpu_dst_loss = tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean (5*nn.dssim(gpu_target_dst_masked, gpu_pred_dst_dst_masked, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + + # Pixel loss + gpu_src_loss += tf.reduce_mean (10*tf.square(gpu_target_src_masked-gpu_pred_src_src_masked), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (10*tf.square(gpu_target_dst_masked-gpu_pred_dst_dst_masked), axis=[1,2,3]) + + # Eyes+mouth prio loss + gpu_src_loss += tf.reduce_mean (300*tf.abs (gpu_target_src*gpu_target_srcm_em-gpu_pred_src_src*gpu_target_srcm_em), axis=[1,2,3]) + gpu_dst_loss += tf.reduce_mean (300*tf.abs (gpu_target_dst*gpu_target_dstm_em-gpu_pred_dst_dst*gpu_target_dstm_em), axis=[1,2,3]) + + # Mask loss + gpu_src_loss += tf.reduce_mean ( 10*tf.square( gpu_target_srcm - gpu_pred_src_srcm ),axis=[1,2,3] ) + gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) + + gpu_src_losses += [gpu_src_loss] + gpu_dst_losses += [gpu_dst_loss] + gpu_G_loss = gpu_src_loss + gpu_dst_loss + # dst-dst background weak loss + gpu_G_loss += tf.reduce_mean(0.1*tf.square(gpu_pred_dst_dst_anti_masked-gpu_target_dst_anti_masked),axis=[1,2,3] ) + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_dst_dst_anti_masked) + + + if gan_power != 0: + gpu_pred_src_src_d, gpu_pred_src_src_d2 = self.GAN(gpu_pred_src_src_masked) + gpu_pred_dst_dst_d, gpu_pred_dst_dst_d2 = self.GAN(gpu_pred_dst_dst_masked) + gpu_target_src_d, gpu_target_src_d2 = self.GAN(gpu_target_src_masked) + gpu_target_dst_d, gpu_target_dst_d2 = self.GAN(gpu_target_dst_masked) + + gpu_GAN_loss = (DLossOnes (gpu_target_src_d) + DLossOnes (gpu_target_src_d2) + \ + DLossZeros(gpu_pred_src_src_d) + DLossZeros(gpu_pred_src_src_d2) + \ + DLossOnes (gpu_target_dst_d) + DLossOnes (gpu_target_dst_d2) + \ + DLossZeros(gpu_pred_dst_dst_d) + DLossZeros(gpu_pred_dst_dst_d2) + ) * (1.0 / 8) + + gpu_GAN_loss_gradients += [ nn.gradients (gpu_GAN_loss, self.GAN.get_weights() ) ] + + gpu_G_loss += (DLossOnes(gpu_pred_src_src_d) + DLossOnes(gpu_pred_src_src_d2) + \ + DLossOnes(gpu_pred_dst_dst_d) + DLossOnes(gpu_pred_dst_dst_d2) + ) * gan_power + + # Minimal src-src-bg rec with total_variation_mse to suppress random bright dots from gan + gpu_G_loss += 0.000001*nn.total_variation_mse(gpu_pred_src_src) + gpu_G_loss += 0.02*tf.reduce_mean(tf.square(gpu_pred_src_src_anti_masked-gpu_target_src_anti_masked),axis=[1,2,3] ) + + gpu_G_loss_gradients += [ nn.gradients ( gpu_G_loss, self.G_weights ) ] + + # Average losses and gradients, and create optimizer update ops + with tf.device(f'/CPU:0'): + pred_src_src = nn.concat(gpu_pred_src_src_list, 0) + pred_dst_dst = nn.concat(gpu_pred_dst_dst_list, 0) + pred_src_dst = nn.concat(gpu_pred_src_dst_list, 0) + pred_src_srcm = nn.concat(gpu_pred_src_srcm_list, 0) + pred_dst_dstm = nn.concat(gpu_pred_dst_dstm_list, 0) + pred_src_dstm = nn.concat(gpu_pred_src_dstm_list, 0) + + with tf.device (models_opt_device): + src_loss = tf.concat(gpu_src_losses, 0) + dst_loss = tf.concat(gpu_dst_losses, 0) + train_op = self.src_dst_opt.get_update_op (nn.average_gv_list (gpu_G_loss_gradients)) + + if gan_power != 0: + GAN_train_op = self.GAN_opt.get_update_op (nn.average_gv_list(gpu_GAN_loss_gradients) ) + + # Initializing training and view functions + def train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + s, d, _ = nn.tf_sess.run ([src_loss, dst_loss, train_op], + feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em, + }) + return s, d + self.train = train + + if gan_power != 0: + def GAN_train(warped_src, target_src, target_srcm, target_srcm_em, \ + warped_dst, target_dst, target_dstm, target_dstm_em, ): + nn.tf_sess.run ([GAN_train_op], feed_dict={self.warped_src :warped_src, + self.target_src :target_src, + self.target_srcm:target_srcm, + self.target_srcm_em:target_srcm_em, + self.warped_dst :warped_dst, + self.target_dst :target_dst, + self.target_dstm:target_dstm, + self.target_dstm_em:target_dstm_em}) + self.GAN_train = GAN_train + + def AE_view(warped_src, warped_dst, morph_value): + return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], + feed_dict={self.warped_src:warped_src, self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_view = AE_view + else: + #Initializing merge function + with tf.device( nn.tf_default_device_name if len(devices) != 0 else f'/CPU:0'): + gpu_dst_code = self.encoder (self.warped_dst) + gpu_dst_inter_src_code = self.inter_src (gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst (gpu_dst_code) + + inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) + gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , inter_res, inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, inter_res,inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + def AE_merge(warped_dst, morph_value): + return nn.tf_sess.run ( [gpu_pred_src_dst, gpu_pred_dst_dstm, gpu_pred_src_dstm], feed_dict={self.warped_dst:warped_dst, self.morph_value_t:[morph_value] }) + + self.AE_merge = AE_merge + + # Loading/initializing all models/optimizers weights + for model, filename in io.progress_bar_generator(self.model_filename_list, "Initializing models"): + do_init = self.is_first_run() + if self.is_training and gan_power != 0 and model == self.GAN: + if self.gan_model_changed: + do_init = True + if not do_init: + do_init = not model.load_weights( self.get_strpath_storage_for_file(filename) ) + if do_init: + model.init_weights() + ############### + + # initializing sample generators + if self.is_training: + training_data_src_path = self.training_data_src_path #if not self.pretrain else self.get_pretraining_data_path() + training_data_dst_path = self.training_data_dst_path #if not self.pretrain else self.get_pretraining_data_path() + + random_ct_samples_path=training_data_dst_path if ct_mode is not None else None #and not self.pretrain + + cpu_count = multiprocessing.cpu_count() + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + if ct_mode is not None: + src_generators_count = int(src_generators_count * 1.5) + + + + self.set_training_data_generators ([ + SampleGeneratorFace(training_data_src_path, random_ct_samples_path=random_ct_samples_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=self.random_src_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': ct_mode, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=src_generators_count ), + + SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=self.random_dst_flip), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':random_warp, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.EYES_MOUTH, 'face_type':face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=self.options['uniform_yaw'],# or self.pretrain, + generators_count=dst_generators_count ) + ]) + + def export_dfm (self): + output_path=self.get_strpath_storage_for_file('model.dfm') + + io.log_info(f'Dumping .dfm to {output_path}') + + tf = nn.tf + with tf.device (nn.tf_default_device_name): + warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + warped_dst = tf.transpose(warped_dst, (0,3,1,2)) + morph_value = tf.placeholder (nn.floatx, (1,), name='morph_value') + + gpu_dst_code = self.encoder (warped_dst) + gpu_dst_inter_src_code = self.inter_src ( gpu_dst_code) + gpu_dst_inter_dst_code = self.inter_dst ( gpu_dst_code) + + inter_dims_slice = tf.cast(self.inter_dims*morph_value[0], tf.int32) + gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , self.inter_res, self.inter_res]), + tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,self.inter_dims-inter_dims_slice, self.inter_res,self.inter_res]) ), 1 ) + + gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) + _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) + + gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1)) + gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1)) + gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1)) + + tf.identity(gpu_pred_dst_dstm, name='out_face_mask') + tf.identity(gpu_pred_src_dst, name='out_celeb_face') + tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_face_mask','out_celeb_face','out_celeb_face_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='AMP', + input_names=['in_face:0','morph_value:0'], + output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'], + opset=9, + output_path=output_path) + + #override + def get_model_filename_list(self): + return self.model_filename_list + + #override + def onSave(self): + for model, filename in io.progress_bar_generator(self.get_model_filename_list(), "Saving", leave=False): + model.save_weights ( self.get_strpath_storage_for_file(filename) ) + + #override + def should_save_preview_history(self): + return (not io.is_colab() and self.iter % ( 10*(max(1,self.resolution // 64)) ) == 0) or \ + (io.is_colab() and self.iter % 100 == 0) + + #override + def onTrainOneIter(self): + bs = self.get_batch_size() + + ( (warped_src, target_src, target_srcm, target_srcm_em), \ + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples() + + src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + if self.gan_power != 0: + self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em) + + return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + ( (warped_src, target_src, target_srcm, target_srcm_em), + (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples + + S, D, SS, DD, DDM_000, _, _ = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst, 0.0) ) ] + + _, _, DDM_025, SD_025, SDM_025 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.25) ] + _, _, DDM_050, SD_050, SDM_050 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.50) ] + _, _, DDM_065, SD_065, SDM_065 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.65) ] + _, _, DDM_075, SD_075, SDM_075 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 0.75) ] + _, _, DDM_100, SD_100, SDM_100 = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in self.AE_view (target_src, target_dst, 1.00) ] + + (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) = [ np.repeat (x, (3,), -1) for x in (DDM_000, + DDM_025, SDM_025, + DDM_050, SDM_050, + DDM_065, SDM_065, + DDM_075, SDM_075, + DDM_100, SDM_100) ] + + target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + result = [] + + i = np.random.randint(n_samples) if not for_history else 0 + + st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] + st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] + + result += [ ('AMP morph 1.0', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i], SD_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i], SD_075[i], SD_100[i]), axis=1) ] + result += [ ('AMP morph list', np.concatenate (st, axis=0 )), ] + + st = [ np.concatenate ((DD[i], SD_025[i]*DDM_025[i]*SDM_025[i], SD_050[i]*DDM_050[i]*SDM_050[i]), axis=1) ] + st += [ np.concatenate ((SD_065[i]*DDM_065[i]*SDM_065[i], SD_075[i]*DDM_075[i]*SDM_075[i], SD_100[i]*DDM_100[i]*SDM_100[i]), axis=1) ] + result += [ ('AMP morph list masked', np.concatenate (st, axis=0 )), ] + + return result + + def predictor_func (self, face, morph_value): + face = nn.to_data_format(face[None,...], self.model_data_format, "NHWC") + + bgr, mask_dst_dstm, mask_src_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format).astype(np.float32) for x in self.AE_merge (face, morph_value) ] + + return bgr[0], mask_src_dstm[0][...,0], mask_dst_dstm[0][...,0] + + #override + def get_MergerConfig(self): + morph_factor = np.clip ( io.input_number ("Morph factor", 1.0, add_info="0.0 .. 1.0"), 0.0, 1.0 ) + + def predictor_morph(face): + return self.predictor_func(face, morph_factor) + + + import merger + return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + +Model = AMPModel diff --git a/models/Model_TEST/__init__.py b/models/Model_TEST/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_TEST/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..b0addfd7dc44931b09146836d2b3b8b2627a6338 --- /dev/null +++ b/models/Model_XSeg/Model.py @@ -0,0 +1,283 @@ +import multiprocessing +import operator +from functools import partial + +import numpy as np + +from core import mathlib +from core.interact import interact as io +from core.leras import nn +from facelib import FaceType, XSegNet +from models import ModelBase +from samplelib import * + +class XSegModel(ModelBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, force_model_class_name='XSeg', **kwargs) + + #override + def on_initialize_options(self): + ask_override = self.ask_override() + + if not self.is_first_run() and ask_override: + if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."): + self.set_iter(0) + + default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf') + default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) + + if self.is_first_run(): + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() + + if self.is_first_run() or ask_override: + self.ask_batch_size(4, range=[2,16]) + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain) + + if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None): + raise Exception("pretraining_data_path is not defined") + + self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False) + + #override + def on_initialize(self): + device_config = nn.getCurrentDeviceConfig() + self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC" + nn.initialize(data_format=self.model_data_format) + tf = nn.tf + + device_config = nn.getCurrentDeviceConfig() + devices = device_config.devices + + self.resolution = resolution = 256 + + + self.face_type = {'h' : FaceType.HALF, + 'mf' : FaceType.MID_FULL, + 'f' : FaceType.FULL, + 'wf' : FaceType.WHOLE_FACE, + 'head' : FaceType.HEAD}[ self.options['face_type'] ] + + + place_model_on_cpu = len(devices) == 0 + models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name + + bgr_shape = nn.get4Dshape(resolution,resolution,3) + mask_shape = nn.get4Dshape(resolution,resolution,1) + + # Initializing model classes + self.model = XSegNet(name='XSeg', + resolution=resolution, + load_weights=not self.is_first_run(), + weights_file_root=self.get_model_root_path(), + training=True, + place_model_on_cpu=place_model_on_cpu, + optimizer=nn.RMSprop(lr=0.0001, lr_dropout=0.3, name='opt'), + data_format=nn.data_format) + + self.pretrain = self.options['pretrain'] + if self.pretrain_just_disabled: + self.set_iter(0) + + if self.is_training: + # Adjust batch size for multiple GPU + gpu_count = max(1, len(devices) ) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) + self.set_batch_size( gpu_count*bs_per_gpu) + + # Compute losses per GPU + gpu_pred_list = [] + + gpu_losses = [] + gpu_loss_gvs = [] + + for gpu_id in range(gpu_count): + with tf.device(f'/{devices[gpu_id].tf_dev_type}:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): + with tf.device(f'/CPU:0'): + # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + gpu_input_t = self.model.input_t [batch_slice,:,:,:] + gpu_target_t = self.model.target_t [batch_slice,:,:,:] + + # process model tensors + gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain) + gpu_pred_list.append(gpu_pred_t) + + + if self.pretrain: + # Structural loss + gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + # Pixel loss + gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3]) + else: + gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3]) + + gpu_losses += [gpu_loss] + + gpu_loss_gvs += [ nn.gradients ( gpu_loss, self.model.get_weights() ) ] + + + # Average losses and gradients, and create optimizer update ops + #with tf.device(f'/CPU:0'): # Temporary fix. Unknown bug with training freeze starts from 2.4.0, but 2.3.1 was ok + with tf.device (models_opt_device): + pred = tf.concat(gpu_pred_list, 0) + loss = tf.concat(gpu_losses, 0) + loss_gv_op = self.model.opt.get_update_op (nn.average_gv_list (gpu_loss_gvs)) + + + # Initializing training and view functions + if self.pretrain: + def train(input_np, target_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np}) + return l + else: + def train(input_np, target_np): + l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np }) + return l + self.train = train + + def view(input_np): + return nn.tf_sess.run ( [pred], feed_dict={self.model.input_t :input_np}) + self.view = view + + # initializing sample generators + cpu_count = min(multiprocessing.cpu_count(), 8) + src_dst_generators_count = cpu_count // 2 + src_generators_count = cpu_count // 2 + dst_generators_count = cpu_count // 2 + + if self.pretrain: + pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + uniform_yaw_distribution=False, + generators_count=cpu_count ) + self.set_training_data_generators ([pretrain_gen]) + else: + srcdst_generator = SampleGeneratorFaceXSeg([self.training_data_src_path, self.training_data_dst_path], + debug=self.is_debug(), + batch_size=self.get_batch_size(), + resolution=resolution, + face_type=self.face_type, + generators_count=src_dst_generators_count, + data_format=nn.data_format) + + src_generator = SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=src_generators_count, + raise_on_no_data=False ) + dst_generator = SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(), + sample_process_options=SampleProcessor.Options(random_flip=False), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':False, 'channel_type' : SampleProcessor.ChannelType.BGR, 'border_replicate':False, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution}, + ], + generators_count=dst_generators_count, + raise_on_no_data=False ) + + self.set_training_data_generators ([srcdst_generator, src_generator, dst_generator]) + + #override + def get_model_filename_list(self): + return self.model.model_filename_list + + #override + def onSave(self): + self.model.save_weights() + + #override + def onTrainOneIter(self): + image_np, target_np = self.generate_next_samples()[0] + loss = self.train (image_np, target_np) + + return ( ('loss', np.mean(loss) ), ) + + #override + def onGetPreview(self, samples, for_history=False): + n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + + if self.pretrain: + srcdst_samples, = samples + image_np, mask_np = srcdst_samples + else: + srcdst_samples, src_samples, dst_samples = samples + image_np, mask_np = srcdst_samples + + I, M, IM, = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([image_np,mask_np] + self.view (image_np) ) ] + M, IM, = [ np.repeat (x, (3,), -1) for x in [M, IM] ] + + green_bg = np.tile( np.array([0,1,0], dtype=np.float32)[None,None,...], (self.resolution,self.resolution,1) ) + + result = [] + st = [] + for i in range(n_samples): + if self.pretrain: + ar = I[i], IM[i] + else: + ar = I[i]*M[i]+0.5*I[i]*(1-M[i])+0.5*green_bg*(1-M[i]), IM[i], I[i]*IM[i]+0.5*I[i]*(1-IM[i]) + 0.5*green_bg*(1-IM[i]) + st.append ( np.concatenate ( ar, axis=1) ) + result += [ ('XSeg training faces', np.concatenate (st, axis=0 )), ] + + if not self.pretrain and len(src_samples) != 0: + src_np, = src_samples + + + D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([src_np] + self.view (src_np) ) ] + DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] + + st = [] + for i in range(n_samples): + ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) + st.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('XSeg src faces', np.concatenate (st, axis=0 )), ] + + if not self.pretrain and len(dst_samples) != 0: + dst_np, = dst_samples + + + D, DM, = [ np.clip(nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ] + DM, = [ np.repeat (x, (3,), -1) for x in [DM] ] + + st = [] + for i in range(n_samples): + ar = D[i], DM[i], D[i]*DM[i] + 0.5*D[i]*(1-DM[i]) + 0.5*green_bg*(1-DM[i]) + st.append ( np.concatenate ( ar, axis=1) ) + + result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ] + + return result + + def export_dfm (self): + output_path = self.get_strpath_storage_for_file(f'model.onnx') + io.log_info(f'Dumping .onnx to {output_path}') + tf = nn.tf + + with tf.device (nn.tf_default_device_name): + input_t = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face') + input_t = tf.transpose(input_t, (0,3,1,2)) + _, pred_t = self.model.flow(input_t) + pred_t = tf.transpose(pred_t, (0,2,3,1)) + + tf.identity(pred_t, name='out_mask') + + output_graph_def = tf.graph_util.convert_variables_to_constants( + nn.tf_sess, + tf.get_default_graph().as_graph_def(), + ['out_mask'] + ) + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + output_graph_def, + name='XSeg', + input_names=['in_face:0'], + output_names=['out_mask:0'], + opset=13, + output_path=output_path) + +Model = XSegModel \ No newline at end of file diff --git a/models/Model_XSeg/__init__.py b/models/Model_XSeg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0188f11aec7882710edf9d40586a00823f0d8c20 --- /dev/null +++ b/models/Model_XSeg/__init__.py @@ -0,0 +1 @@ +from .Model import Model diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..490e9c8e7b0080f06c2a8cf3311c1141f2967848 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,5 @@ +from .ModelBase import ModelBase + +def import_model(model_class_name): + module = __import__('Model_'+model_class_name, globals(), locals(), [], 1) + return getattr(module, 'Model') diff --git a/models/trash/models.zip b/models/trash/models.zip new file mode 100644 index 0000000000000000000000000000000000000000..6a4c5ae66e10bd70751a6b960d9cc4258ca14d19 --- /dev/null +++ b/models/trash/models.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8089b3122e894d7d09646e47259e6166fd652efacb7d9673782c916d6329a2fa +size 226142 diff --git a/requirements-colab.txt b/requirements-colab.txt new file mode 100644 index 0000000000000000000000000000000000000000..33546b2a754869de14af41f4754c4da64fc05a56 --- /dev/null +++ b/requirements-colab.txt @@ -0,0 +1,11 @@ +tqdm +numpy==1.19.3 +numexpr +h5py==2.10.0 +opencv-python==4.1.0.25 +ffmpeg-python==0.1.17 +scikit-image==0.14.2 +scipy==1.4.1 +colorama +tensorflow-gpu==2.4.0 +tf2onnx==1.9.3 \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt new file mode 100644 index 0000000000000000000000000000000000000000..b70520dadd3e4e1e7829801fe2c838c67daaa6cf --- /dev/null +++ b/requirements-cuda.txt @@ -0,0 +1,12 @@ +tqdm +numpy==1.19.3 +numexpr +h5py==2.10.0 +opencv-python==4.1.0.25 +ffmpeg-python==0.1.17 +scikit-image==0.14.2 +scipy==1.4.1 +colorama +tensorflow-gpu==2.4.0 +pyqt5 +tf2onnx==1.9.3 \ No newline at end of file diff --git a/samplelib/PackedFaceset.py b/samplelib/PackedFaceset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7ae1d48f196cdd4200ef8c6e5340b860e4ff2df --- /dev/null +++ b/samplelib/PackedFaceset.py @@ -0,0 +1,156 @@ +import pickle +import shutil +import struct +from pathlib import Path + +import samplelib.SampleLoader +from core.interact import interact as io +from samplelib import Sample +from core import pathex + +packed_faceset_filename = 'faceset.pak' + +class PackedFaceset(): + VERSION = 1 + + @staticmethod + def pack(samples_path): + samples_dat_path = samples_path / packed_faceset_filename + + if samples_dat_path.exists(): + io.log_info(f"{samples_dat_path} : file already exists !") + io.input("Press enter to continue and overwrite.") + + as_person_faceset = False + dir_names = pathex.get_all_dir_names(samples_path) + if len(dir_names) != 0: + as_person_faceset = io.input_bool(f"{len(dir_names)} subdirectories found, process as person faceset?", True) + + if as_person_faceset: + image_paths = [] + + for dir_name in dir_names: + image_paths += pathex.get_image_paths(samples_path / dir_name) + else: + image_paths = pathex.get_image_paths(samples_path) + + samples = samplelib.SampleLoader.load_face_samples(image_paths) + samples_len = len(samples) + + samples_configs = [] + for sample in io.progress_bar_generator (samples, "Processing"): + sample_filepath = Path(sample.filename) + sample.filename = sample_filepath.name + + if as_person_faceset: + sample.person_name = sample_filepath.parent.name + samples_configs.append ( sample.get_config() ) + samples_bytes = pickle.dumps(samples_configs, 4) + + of = open(samples_dat_path, "wb") + of.write ( struct.pack ("Q", PackedFaceset.VERSION ) ) + of.write ( struct.pack ("Q", len(samples_bytes) ) ) + of.write ( samples_bytes ) + + del samples_bytes #just free mem + del samples_configs + + sample_data_table_offset = of.tell() + of.write ( bytes( 8*(samples_len+1) ) ) #sample data offset table + + data_start_offset = of.tell() + offsets = [] + + for sample in io.progress_bar_generator(samples, "Packing"): + try: + if sample.person_name is not None: + sample_path = samples_path / sample.person_name / sample.filename + else: + sample_path = samples_path / sample.filename + + + with open(sample_path, "rb") as f: + b = f.read() + + offsets.append ( of.tell() - data_start_offset ) + of.write(b) + except: + raise Exception(f"error while processing sample {sample_path}") + + offsets.append ( of.tell() ) + + of.seek(sample_data_table_offset, 0) + for offset in offsets: + of.write ( struct.pack("Q", offset) ) + of.seek(0,2) + of.close() + + if io.input_bool(f"Delete original files?", True): + for filename in io.progress_bar_generator(image_paths, "Deleting files"): + Path(filename).unlink() + + if as_person_faceset: + for dir_name in io.progress_bar_generator(dir_names, "Deleting dirs"): + dir_path = samples_path / dir_name + try: + shutil.rmtree(dir_path) + except: + io.log_info (f"unable to remove: {dir_path} ") + + @staticmethod + def unpack(samples_path): + samples_dat_path = samples_path / packed_faceset_filename + if not samples_dat_path.exists(): + io.log_info(f"{samples_dat_path} : file not found.") + return + + samples = PackedFaceset.load(samples_path) + + for sample in io.progress_bar_generator(samples, "Unpacking"): + person_name = sample.person_name + if person_name is not None: + person_path = samples_path / person_name + person_path.mkdir(parents=True, exist_ok=True) + + target_filepath = person_path / sample.filename + else: + target_filepath = samples_path / sample.filename + + with open(target_filepath, "wb") as f: + f.write( sample.read_raw_file() ) + + samples_dat_path.unlink() + + @staticmethod + def path_contains(samples_path): + samples_dat_path = samples_path / packed_faceset_filename + return samples_dat_path.exists() + + @staticmethod + def load(samples_path): + samples_dat_path = samples_path / packed_faceset_filename + if not samples_dat_path.exists(): + return None + + f = open(samples_dat_path, "rb") + version, = struct.unpack("Q", f.read(8) ) + if version != PackedFaceset.VERSION: + raise NotImplementedError + + sizeof_samples_bytes, = struct.unpack("Q", f.read(8) ) + + samples_configs = pickle.loads ( f.read(sizeof_samples_bytes) ) + samples = [] + for sample_config in samples_configs: + sample_config = pickle.loads(pickle.dumps (sample_config)) + samples.append ( Sample (**sample_config) ) + + offsets = [ struct.unpack("Q", f.read(8) )[0] for _ in range(len(samples)+1) ] + data_start_offset = f.tell() + f.close() + + for i, sample in enumerate(samples): + start_offset, end_offset = offsets[i], offsets[i+1] + sample.set_filename_offset_size( str(samples_dat_path), data_start_offset+start_offset, end_offset-start_offset ) + + return samples diff --git a/samplelib/Sample.py b/samplelib/Sample.py new file mode 100644 index 0000000000000000000000000000000000000000..a379275b14749d42695997eeda7b85ed218a69eb --- /dev/null +++ b/samplelib/Sample.py @@ -0,0 +1,127 @@ +from enum import IntEnum +from pathlib import Path + +import cv2 +import numpy as np + +from core.cv2ex import * +from facelib import LandmarksProcessor +from core import imagelib +from core.imagelib import SegIEPolys + +class SampleType(IntEnum): + IMAGE = 0 #raw image + + FACE_BEGIN = 1 + FACE = 1 #aligned face unsorted + FACE_PERSON = 2 #aligned face person + FACE_TEMPORAL_SORTED = 3 #sorted by source filename + FACE_END = 3 + + QTY = 4 + +class Sample(object): + __slots__ = ['sample_type', + 'filename', + 'face_type', + 'shape', + 'landmarks', + 'seg_ie_polys', + 'xseg_mask', + 'xseg_mask_compressed', + 'eyebrows_expand_mod', + 'source_filename', + 'person_name', + 'pitch_yaw_roll', + '_filename_offset_size', + ] + + def __init__(self, sample_type=None, + filename=None, + face_type=None, + shape=None, + landmarks=None, + seg_ie_polys=None, + xseg_mask=None, + xseg_mask_compressed=None, + eyebrows_expand_mod=None, + source_filename=None, + person_name=None, + pitch_yaw_roll=None, + **kwargs): + + self.sample_type = sample_type if sample_type is not None else SampleType.IMAGE + self.filename = filename + self.face_type = face_type + self.shape = shape + self.landmarks = np.array(landmarks) if landmarks is not None else None + + if isinstance(seg_ie_polys, SegIEPolys): + self.seg_ie_polys = seg_ie_polys + else: + self.seg_ie_polys = SegIEPolys.load(seg_ie_polys) + + self.xseg_mask = xseg_mask + self.xseg_mask_compressed = xseg_mask_compressed + + if self.xseg_mask_compressed is None and self.xseg_mask is not None: + xseg_mask = np.clip( imagelib.normalize_channels(xseg_mask, 1)*255, 0, 255 ).astype(np.uint8) + ret, xseg_mask_compressed = cv2.imencode('.png', xseg_mask) + if not ret: + raise Exception("Sample(): unable to generate xseg_mask_compressed") + self.xseg_mask_compressed = xseg_mask_compressed + self.xseg_mask = None + + self.eyebrows_expand_mod = eyebrows_expand_mod if eyebrows_expand_mod is not None else 1.0 + self.source_filename = source_filename + self.person_name = person_name + self.pitch_yaw_roll = pitch_yaw_roll + + self._filename_offset_size = None + + def has_xseg_mask(self): + return self.xseg_mask is not None or self.xseg_mask_compressed is not None + + def get_xseg_mask(self): + if self.xseg_mask_compressed is not None: + xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED) + if len(xseg_mask.shape) == 2: + xseg_mask = xseg_mask[...,None] + return xseg_mask.astype(np.float32) / 255.0 + return self.xseg_mask + + def get_pitch_yaw_roll(self): + if self.pitch_yaw_roll is None: + self.pitch_yaw_roll = LandmarksProcessor.estimate_pitch_yaw_roll(self.landmarks, size=self.shape[1]) + return self.pitch_yaw_roll + + def set_filename_offset_size(self, filename, offset, size): + self._filename_offset_size = (filename, offset, size) + + def read_raw_file(self, filename=None): + if self._filename_offset_size is not None: + filename, offset, size = self._filename_offset_size + with open(filename, "rb") as f: + f.seek( offset, 0) + return f.read (size) + else: + with open(filename, "rb") as f: + return f.read() + + def load_bgr(self): + img = cv2_imread (self.filename, loader_func=self.read_raw_file).astype(np.float32) / 255.0 + return img + + def get_config(self): + return {'sample_type': self.sample_type, + 'filename': self.filename, + 'face_type': self.face_type, + 'shape': self.shape, + 'landmarks': self.landmarks.tolist(), + 'seg_ie_polys': self.seg_ie_polys.dump(), + 'xseg_mask' : self.xseg_mask, + 'xseg_mask_compressed' : self.xseg_mask_compressed, + 'eyebrows_expand_mod': self.eyebrows_expand_mod, + 'source_filename': self.source_filename, + 'person_name': self.person_name + } diff --git a/samplelib/SampleGeneratorBase.py b/samplelib/SampleGeneratorBase.py new file mode 100644 index 0000000000000000000000000000000000000000..6e6016058675729d07a640c3b4a99ace55cbdb3d --- /dev/null +++ b/samplelib/SampleGeneratorBase.py @@ -0,0 +1,35 @@ +from pathlib import Path + +''' +You can implement your own SampleGenerator +''' +class SampleGeneratorBase(object): + + + def __init__ (self, debug=False, batch_size=1): + self.debug = debug + self.batch_size = 1 if self.debug else batch_size + self.last_generation = None + self.active = True + + def set_active(self, is_active): + self.active = is_active + + def generate_next(self): + if not self.active and self.last_generation is not None: + return self.last_generation + self.last_generation = next(self) + return self.last_generation + + #overridable + def __iter__(self): + #implement your own iterator + return self + + def __next__(self): + #implement your own iterator + return None + + #overridable + def is_initialized(self): + return True \ No newline at end of file diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py new file mode 100644 index 0000000000000000000000000000000000000000..605d32742e286a03d307bcf9eb9ace9d91decfae --- /dev/null +++ b/samplelib/SampleGeneratorFace.py @@ -0,0 +1,144 @@ +import multiprocessing +import time +import traceback + +import cv2 +import numpy as np + +from core import mplib +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +''' +arg +output_sample_types = [ + [SampleProcessor.TypeFlags, size, (optional) {} opts ] , + ... + ] +''' +class SampleGeneratorFace(SampleGeneratorBase): + def __init__ (self, samples_path, debug=False, batch_size=1, + random_ct_samples_path=None, + sample_process_options=SampleProcessor.Options(), + output_sample_types=[], + uniform_yaw_distribution=False, + generators_count=4, + raise_on_no_data=True, + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + samples = SampleLoader.load (SampleType.FACE, samples_path) + self.samples_len = len(samples) + + if self.samples_len == 0: + if raise_on_no_data: + raise ValueError('No training data provided.') + else: + return + + if uniform_yaw_distribution: + samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ] + + grads = 128 + #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2 + grads_space = np.linspace (-1.2, 1.2,grads) + + yaws_sample_list = [None]*grads + for g in io.progress_bar_generator ( range(grads), "Sort by yaw"): + yaw = grads_space[g] + next_yaw = grads_space[g+1] if g < grads-1 else yaw + + yaw_samples = [] + for idx, pyr in samples_pyr: + s_yaw = -pyr[1] + if (g == 0 and s_yaw < next_yaw) or \ + (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \ + (g == grads-1 and s_yaw >= yaw): + yaw_samples += [ idx ] + if len(yaw_samples) > 0: + yaws_sample_list[g] = yaw_samples + + yaws_sample_list = [ y for y in yaws_sample_list if y is not None ] + + index_host = mplib.Index2DHost( yaws_sample_list ) + else: + index_host = mplib.IndexHost(self.samples_len) + + if random_ct_samples_path is not None: + ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) + ct_index_host = mplib.IndexHost( len(ct_samples) ) + else: + ct_samples = None + ct_index_host = None + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + if not self.initialized: + return [] + + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + samples, index_host, ct_samples, ct_index_host = param + + bs = self.batch_size + while True: + batches = None + + indexes = index_host.multi_get(bs) + ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None + + t = time.time() + for n_batch in range(bs): + sample_idx = indexes[n_batch] + sample = samples[sample_idx] + + ct_sample = None + if ct_samples is not None: + ct_sample = ct_samples[ct_indexes[n_batch]] + + try: + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + + for i in range(len(x)): + batches[i].append ( x[i] ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceAvatarOperator.py b/samplelib/SampleGeneratorFaceAvatarOperator.py new file mode 100644 index 0000000000000000000000000000000000000000..965da4e8a645ad4a245fab2e9cd02b3624eb7ffd --- /dev/null +++ b/samplelib/SampleGeneratorFaceAvatarOperator.py @@ -0,0 +1,202 @@ +import multiprocessing +import pickle +import time +import traceback +from enum import IntEnum + +import cv2 +import numpy as np + +from core import imagelib, mplib, pathex +from core.imagelib import sd +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) + +class SampleGeneratorFaceAvatarOperator(SampleGeneratorBase): + def __init__ (self, root_path, debug=False, batch_size=1, resolution=256, face_type=None, + generators_count=4, data_format="NHWC", + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + + + dataset_path = root_path / 'AvatarOperatorDataset' + if not dataset_path.exists(): + raise ValueError(f'Unable to find {dataset_path}') + + chains_dir_names = pathex.get_all_dir_names(dataset_path) + + samples = SampleLoader.load (SampleType.FACE, dataset_path, subdirs=True) + sample_idx_by_path = { sample.filename : i for i,sample in enumerate(samples) } + + kf_idxs = [] + + for chain_dir_name in chains_dir_names: + chain_root_path = dataset_path / chain_dir_name + + subchain_dir_names = pathex.get_all_dir_names(chain_root_path) + try: + subchain_dir_names.sort(key=int) + except: + raise Exception(f'{chain_root_path} must contain only numerical name of directories') + chain_samples = [] + + for subchain_dir_name in subchain_dir_names: + subchain_root = chain_root_path / subchain_dir_name + subchain_samples = [ sample_idx_by_path[image_path] for image_path in pathex.get_image_paths(subchain_root) \ + if image_path in sample_idx_by_path ] + + if len(subchain_samples) < 3: + raise Exception(f'subchain {subchain_dir_name} must contain at least 3 faces. If you delete this subchain, then th echain will be corrupted.') + + chain_samples += [ subchain_samples ] + + chain_samples_len = len(chain_samples) + for i in range(chain_samples_len-1): + kf_idxs += [ ( chain_samples[i+1][0], chain_samples[i][-1], chain_samples[i][:-1] ) ] + + for i in range(1,chain_samples_len): + kf_idxs += [ ( chain_samples[i-1][-1], chain_samples[i][0], chain_samples[i][1:] ) ] + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (samples, kf_idxs, resolution, face_type, data_format) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (samples, kf_idxs, resolution, face_type, data_format), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + samples, kf_idxs, resolution, face_type, data_format = param + + kf_idxs_len = len(kf_idxs) + + shuffle_idxs = [] + idxs = [*range(len(samples))] + + random_flip = True + rotation_range=[-10,10] + scale_range=[-0.05, 0.05] + tx_range=[-0.05, 0.05] + ty_range=[-0.05, 0.05] + + bs = self.batch_size + while True: + batches = [ [], [] , [], [], [], [] ] + + n_batch = 0 + while n_batch < bs: + try: + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + np.random.shuffle(shuffle_idxs) + idx = shuffle_idxs.pop() + + + key_idx, key_chain_idx, chain_idxs = kf_idxs[ np.random.randint(kf_idxs_len) ] + + key_sample = samples[key_idx] + key_chain_sample = samples[key_chain_idx] + chain_sample = samples[ chain_idxs[np.random.randint(len(chain_idxs)) ] ] + + #print('==========') + #print(key_sample.filename) + #print(key_chain_sample.filename) + #print(chain_sample.filename) + + sample = samples[idx] + + img = sample.load_bgr() + + key_img = key_sample.load_bgr() + key_chain_img = key_chain_sample.load_bgr() + chain_img = chain_sample.load_bgr() + + h,w,c = img.shape + + mask = LandmarksProcessor.get_image_hull_mask (img.shape, sample.landmarks) + mask = np.clip(mask, 0, 1) + + warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range ) + + if face_type == sample.face_type: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC ) + key_img = cv2.resize( key_img, (resolution, resolution), cv2.INTER_CUBIC ) + key_chain_img = cv2.resize( key_chain_img, (resolution, resolution), cv2.INTER_CUBIC ) + chain_img = cv2.resize( chain_img, (resolution, resolution), cv2.INTER_CUBIC ) + + mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_CUBIC ) + else: + mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + key_img = cv2.warpAffine( key_img, mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + key_chain_img = cv2.warpAffine( key_chain_img, mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + chain_img = cv2.warpAffine( chain_img, mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_CUBIC ) + + if len(mask.shape) == 2: + mask = mask[...,None] + + img_warped = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=True) + img_transformed = imagelib.warp_by_params (warp_params, img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True) + + mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) + + key_img = imagelib.warp_by_params (warp_params, key_img, can_warp=False, can_transform=False, can_flip=False, border_replicate=True) + key_chain_img = imagelib.warp_by_params (warp_params, key_chain_img, can_warp=False, can_transform=False, can_flip=False, border_replicate=True) + chain_img = imagelib.warp_by_params (warp_params, chain_img, can_warp=False, can_transform=False, can_flip=False, border_replicate=True) + + + img_warped = np.clip(img_warped.astype(np.float32), 0, 1) + img_transformed = np.clip(img_transformed.astype(np.float32), 0, 1) + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + mask = np.clip(mask, 0, 1) + + if data_format == "NCHW": + img_warped = np.transpose(img_warped, (2,0,1) ) + img_transformed = np.transpose(img_transformed, (2,0,1) ) + mask = np.transpose(mask, (2,0,1) ) + + key_img = np.transpose(key_img, (2,0,1) ) + key_chain_img = np.transpose(key_chain_img, (2,0,1) ) + chain_img = np.transpose(chain_img, (2,0,1) ) + + batches[0].append ( img_warped ) + batches[1].append ( img_transformed ) + batches[2].append ( mask ) + batches[3].append ( key_img ) + batches[4].append ( key_chain_img ) + batches[5].append ( chain_img ) + + n_batch += 1 + except: + io.log_err ( traceback.format_exc() ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceCelebAMaskHQ.py b/samplelib/SampleGeneratorFaceCelebAMaskHQ.py new file mode 100644 index 0000000000000000000000000000000000000000..b943b1f93ae3a5d2d03c73eea37a4f34a59b3e97 --- /dev/null +++ b/samplelib/SampleGeneratorFaceCelebAMaskHQ.py @@ -0,0 +1,269 @@ +import multiprocessing +import pickle +import time +import traceback +from enum import IntEnum + +import cv2 +import numpy as np + +from core import imagelib, mplib, pathex +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import SampleGeneratorBase + + +class MaskType(IntEnum): + none = 0, + cloth = 1, + ear_r = 2, + eye_g = 3, + hair = 4, + hat = 5, + l_brow = 6, + l_ear = 7, + l_eye = 8, + l_lip = 9, + mouth = 10, + neck = 11, + neck_l = 12, + nose = 13, + r_brow = 14, + r_ear = 15, + r_eye = 16, + skin = 17, + u_lip = 18 + + + +MaskType_to_name = { + int(MaskType.none ) : 'none', + int(MaskType.cloth ) : 'cloth', + int(MaskType.ear_r ) : 'ear_r', + int(MaskType.eye_g ) : 'eye_g', + int(MaskType.hair ) : 'hair', + int(MaskType.hat ) : 'hat', + int(MaskType.l_brow) : 'l_brow', + int(MaskType.l_ear ) : 'l_ear', + int(MaskType.l_eye ) : 'l_eye', + int(MaskType.l_lip ) : 'l_lip', + int(MaskType.mouth ) : 'mouth', + int(MaskType.neck ) : 'neck', + int(MaskType.neck_l) : 'neck_l', + int(MaskType.nose ) : 'nose', + int(MaskType.r_brow) : 'r_brow', + int(MaskType.r_ear ) : 'r_ear', + int(MaskType.r_eye ) : 'r_eye', + int(MaskType.skin ) : 'skin', + int(MaskType.u_lip ) : 'u_lip', +} + +MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() } + +class SampleGeneratorFaceCelebAMaskHQ(SampleGeneratorBase): + def __init__ (self, root_path, debug=False, batch_size=1, resolution=256, + generators_count=4, data_format="NHWC", + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + + dataset_path = root_path / 'CelebAMask-HQ' + if not dataset_path.exists(): + raise ValueError(f'Unable to find {dataset_path}') + + images_path = dataset_path /'CelebA-HQ-img' + if not images_path.exists(): + raise ValueError(f'Unable to find {images_path}') + + masks_path = dataset_path / 'CelebAMask-HQ-mask-anno' + if not masks_path.exists(): + raise ValueError(f'Unable to find {masks_path}') + + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + source_images_paths = pathex.get_image_paths(images_path, return_Path_class=True) + source_images_paths_len = len(source_images_paths) + mask_images_paths = pathex.get_image_paths(masks_path, subdirs=True, return_Path_class=True) + + if source_images_paths_len == 0 or len(mask_images_paths) == 0: + raise ValueError('No training data provided.') + + mask_file_id_hash = {} + + for filepath in io.progress_bar_generator(mask_images_paths, "Loading"): + stem = filepath.stem + + file_id, mask_type = stem.split('_', 1) + file_id = int(file_id) + + if file_id not in mask_file_id_hash: + mask_file_id_hash[file_id] = {} + + mask_file_id_hash[file_id][ MaskType_from_name[mask_type] ] = str(filepath.relative_to(masks_path)) + + source_file_id_set = set() + + for filepath in source_images_paths: + stem = filepath.stem + + file_id = int(stem) + source_file_id_set.update ( {file_id} ) + + for k in mask_file_id_hash.keys(): + if k not in source_file_id_set: + io.log_err (f"Corrupted dataset: {k} not in {images_path}") + + + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (images_path, masks_path, mask_file_id_hash, data_format), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + images_path, masks_path, mask_file_id_hash, data_format = param + + file_ids = list(mask_file_id_hash.keys()) + + shuffle_file_ids = [] + + resolution = 256 + random_flip = True + rotation_range=[-15,15] + scale_range=[-0.10, 0.95] + tx_range=[-0.3, 0.3] + ty_range=[-0.3, 0.3] + + random_bilinear_resize = (25,75) + motion_blur = (25, 5) + gaussian_blur = (25, 5) + + bs = self.batch_size + while True: + batches = None + + n_batch = 0 + while n_batch < bs: + try: + if len(shuffle_file_ids) == 0: + shuffle_file_ids = file_ids.copy() + np.random.shuffle(shuffle_file_ids) + + file_id = shuffle_file_ids.pop() + masks = mask_file_id_hash[file_id] + image_path = images_path / f'{file_id}.jpg' + + skin_path = masks.get(MaskType.skin, None) + hair_path = masks.get(MaskType.hair, None) + hat_path = masks.get(MaskType.hat, None) + #neck_path = masks.get(MaskType.neck, None) + + img = cv2_imread(image_path).astype(np.float32) / 255.0 + mask = cv2_imread(masks_path / skin_path)[...,0:1].astype(np.float32) / 255.0 + + if hair_path is not None: + hair_path = masks_path / hair_path + if hair_path.exists(): + hair = cv2_imread(hair_path)[...,0:1].astype(np.float32) / 255.0 + mask *= (1-hair) + + if hat_path is not None: + hat_path = masks_path / hat_path + if hat_path.exists(): + hat = cv2_imread(hat_path)[...,0:1].astype(np.float32) / 255.0 + mask *= (1-hat) + + #if neck_path is not None: + # neck_path = masks_path / neck_path + # if neck_path.exists(): + # neck = cv2_imread(neck_path)[...,0:1].astype(np.float32) / 255.0 + # mask = np.clip(mask+neck, 0, 1) + + warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range ) + + img = cv2.resize( img, (resolution,resolution), cv2.INTER_LANCZOS4 ) + h, s, v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) + h = ( h + np.random.randint(360) ) % 360 + s = np.clip ( s + np.random.random()-0.5, 0, 1 ) + v = np.clip ( v + np.random.random()/2-0.25, 0, 1 ) + img = np.clip( cv2.cvtColor(cv2.merge([h, s, v]), cv2.COLOR_HSV2BGR) , 0, 1 ) + + if motion_blur is not None: + chance, mb_max_size = motion_blur + chance = np.clip(chance, 0, 100) + + mblur_rnd_chance = np.random.randint(100) + mblur_rnd_kernel = np.random.randint(mb_max_size)+1 + mblur_rnd_deg = np.random.randint(360) + + if mblur_rnd_chance < chance: + img = imagelib.LinearMotionBlur (img, mblur_rnd_kernel, mblur_rnd_deg ) + + img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4) + + if gaussian_blur is not None: + chance, kernel_max_size = gaussian_blur + chance = np.clip(chance, 0, 100) + + gblur_rnd_chance = np.random.randint(100) + gblur_rnd_kernel = np.random.randint(kernel_max_size)*2+1 + + if gblur_rnd_chance < chance: + img = cv2.GaussianBlur(img, (gblur_rnd_kernel,) *2 , 0) + + if random_bilinear_resize is not None: + chance, max_size_per = random_bilinear_resize + chance = np.clip(chance, 0, 100) + pick_chance = np.random.randint(100) + resize_to = resolution - int( np.random.rand()* int(resolution*(max_size_per/100.0)) ) + img = cv2.resize (img, (resize_to,resize_to), cv2.INTER_LINEAR ) + img = cv2.resize (img, (resolution,resolution), cv2.INTER_LINEAR ) + + + mask = cv2.resize( mask, (resolution,resolution), cv2.INTER_LANCZOS4 )[...,None] + mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False, cv2_inter=cv2.INTER_LANCZOS4) + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + mask = np.clip(mask, 0, 1) + + if data_format == "NCHW": + img = np.transpose(img, (2,0,1) ) + mask = np.transpose(mask, (2,0,1) ) + + if batches is None: + batches = [ [], [] ] + + batches[0].append ( img ) + batches[1].append ( mask ) + + n_batch += 1 + except: + io.log_err ( traceback.format_exc() ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceDebug.py b/samplelib/SampleGeneratorFaceDebug.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdf5296e85f037b81f85a50a904fbbfc838f410 --- /dev/null +++ b/samplelib/SampleGeneratorFaceDebug.py @@ -0,0 +1,133 @@ +import multiprocessing +import pickle +import time +import traceback + +import cv2 +import numpy as np + +from core import mplib +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +''' +arg +output_sample_types = [ + [SampleProcessor.TypeFlags, size, (optional) {} opts ] , + ... + ] +''' +class SampleGeneratorFaceDebug(SampleGeneratorBase): + def __init__ (self, samples_path, debug=False, batch_size=1, + random_ct_samples_path=None, + sample_process_options=SampleProcessor.Options(), + output_sample_types=[], + add_sample_idx=False, + generators_count=4, + rnd_seed=None, + **kwargs): + + super().__init__(debug, batch_size) + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + self.add_sample_idx = add_sample_idx + + if rnd_seed is None: + rnd_seed = np.random.randint(0x80000000) + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + samples = SampleLoader.load (SampleType.FACE, samples_path) + self.samples_len = len(samples) + + if self.samples_len == 0: + raise ValueError('No training data provided.') + + if random_ct_samples_path is not None: + ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) + else: + ct_samples = None + + pickled_samples = pickle.dumps(samples, 4) + ct_pickled_samples = pickle.dumps(ct_samples, 4) if ct_samples is not None else None + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, ct_pickled_samples, rnd_seed) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, ct_pickled_samples, rnd_seed+i), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + pickled_samples, ct_pickled_samples, rnd_seed = param + + rnd_state = np.random.RandomState(rnd_seed) + + samples = pickle.loads(pickled_samples) + idxs = [*range(len(samples))] + shuffle_idxs = [] + + if ct_pickled_samples is not None: + ct_samples = pickle.loads(ct_pickled_samples) + ct_idxs = [*range(len(ct_samples))] + ct_shuffle_idxs = [] + else: + ct_samples = None + + + bs = self.batch_size + while True: + batches = None + + for n_batch in range(bs): + + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + rnd_state.shuffle(shuffle_idxs) + + sample_idx = shuffle_idxs.pop() + sample = samples[sample_idx] + + ct_sample = None + if ct_samples is not None: + if len(ct_shuffle_idxs) == 0: + ct_shuffle_idxs = ct_idxs.copy() + rnd_state.shuffle(ct_shuffle_idxs) + ct_sample_idx = ct_shuffle_idxs.pop() + ct_sample = ct_samples[ct_sample_idx] + + try: + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, rnd_state=rnd_state) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + if self.add_sample_idx: + batches += [ [] ] + i_sample_idx = len(batches)-1 + + for i in range(len(x)): + batches[i].append ( x[i] ) + + if self.add_sample_idx: + batches[i_sample_idx].append (sample_idx) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFacePerson.py b/samplelib/SampleGeneratorFacePerson.py new file mode 100644 index 0000000000000000000000000000000000000000..0fbd2c33e8229af5ce7866754597a6ce0f2c2259 --- /dev/null +++ b/samplelib/SampleGeneratorFacePerson.py @@ -0,0 +1,365 @@ +import copy +import multiprocessing +import traceback + +import cv2 +import numpy as np + +from core import mplib +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + + +class Index2DHost(): + """ + Provides random shuffled 2D indexes for multiprocesses + """ + def __init__(self, indexes2D): + self.sq = multiprocessing.Queue() + self.cqs = [] + self.clis = [] + self.thread = threading.Thread(target=self.host_thread, args=(indexes2D,) ) + self.thread.daemon = True + self.thread.start() + + def host_thread(self, indexes2D): + indexes_counts_len = len(indexes2D) + + idxs = [*range(indexes_counts_len)] + idxs_2D = [None]*indexes_counts_len + shuffle_idxs = [] + shuffle_idxs_2D = [None]*indexes_counts_len + for i in range(indexes_counts_len): + idxs_2D[i] = indexes2D[i] + shuffle_idxs_2D[i] = [] + + sq = self.sq + + while True: + while not sq.empty(): + obj = sq.get() + cq_id, cmd = obj[0], obj[1] + + if cmd == 0: #get_1D + count = obj[2] + + result = [] + for i in range(count): + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + np.random.shuffle(shuffle_idxs) + result.append(shuffle_idxs.pop()) + self.cqs[cq_id].put (result) + elif cmd == 1: #get_2D + targ_idxs,count = obj[2], obj[3] + result = [] + + for targ_idx in targ_idxs: + sub_idxs = [] + for i in range(count): + ar = shuffle_idxs_2D[targ_idx] + if len(ar) == 0: + ar = shuffle_idxs_2D[targ_idx] = idxs_2D[targ_idx].copy() + np.random.shuffle(ar) + sub_idxs.append(ar.pop()) + result.append (sub_idxs) + self.cqs[cq_id].put (result) + + time.sleep(0.001) + + def create_cli(self): + cq = multiprocessing.Queue() + self.cqs.append ( cq ) + cq_id = len(self.cqs)-1 + return Index2DHost.Cli(self.sq, cq, cq_id) + + # disable pickling + def __getstate__(self): + return dict() + def __setstate__(self, d): + self.__dict__.update(d) + + class Cli(): + def __init__(self, sq, cq, cq_id): + self.sq = sq + self.cq = cq + self.cq_id = cq_id + + def get_1D(self, count): + self.sq.put ( (self.cq_id,0, count) ) + + while True: + if not self.cq.empty(): + return self.cq.get() + time.sleep(0.001) + + def get_2D(self, idxs, count): + self.sq.put ( (self.cq_id,1,idxs,count) ) + + while True: + if not self.cq.empty(): + return self.cq.get() + time.sleep(0.001) + +''' +arg +output_sample_types = [ + [SampleProcessor.TypeFlags, size, (optional) {} opts ] , + ... + ] +''' +class SampleGeneratorFacePerson(SampleGeneratorBase): + def __init__ (self, samples_path, debug=False, batch_size=1, + sample_process_options=SampleProcessor.Options(), + output_sample_types=[], + person_id_mode=1, + **kwargs): + + super().__init__(debug, batch_size) + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + self.person_id_mode = person_id_mode + + raise NotImplementedError("Currently SampleGeneratorFacePerson is not implemented.") + + samples_host = SampleLoader.mp_host (SampleType.FACE, samples_path) + samples = samples_host.get_list() + self.samples_len = len(samples) + + if self.samples_len == 0: + raise ValueError('No training data provided.') + + unique_person_names = { sample.person_name for sample in samples } + persons_name_idxs = { person_name : [] for person_name in unique_person_names } + for i,sample in enumerate(samples): + persons_name_idxs[sample.person_name].append (i) + indexes2D = [ persons_name_idxs[person_name] for person_name in unique_person_names ] + index2d_host = Index2DHost(indexes2D) + + if self.debug: + self.generators_count = 1 + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) )] + else: + self.generators_count = np.clip(multiprocessing.cpu_count(), 2, 4) + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (samples_host.create_cli(), index2d_host.create_cli(),) ) for i in range(self.generators_count) ] + + self.generator_counter = -1 + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + samples, index2d_host, = param + bs = self.batch_size + + while True: + person_idxs = index2d_host.get_1D(bs) + samples_idxs = index2d_host.get_2D(person_idxs, 1) + + batches = None + for n_batch in range(bs): + person_id = person_idxs[n_batch] + sample_idx = samples_idxs[n_batch][0] + + sample = samples[ sample_idx ] + try: + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + + batches += [ [] ] + i_person_id = len(batches)-1 + + for i in range(len(x)): + batches[i].append ( x[i] ) + + batches[i_person_id].append ( np.array([person_id]) ) + + yield [ np.array(batch) for batch in batches] + + @staticmethod + def get_person_id_max_count(samples_path): + return SampleLoader.get_person_id_max_count(samples_path) + +""" +if self.person_id_mode==1: + samples_len = len(samples) + samples_idxs = [*range(samples_len)] + shuffle_idxs = [] + elif self.person_id_mode==2: + persons_count = len(samples) + + person_idxs = [] + for j in range(persons_count): + for i in range(j+1,persons_count): + person_idxs += [ [i,j] ] + + shuffle_person_idxs = [] + + samples_idxs = [None]*persons_count + shuffle_idxs = [None]*persons_count + + for i in range(persons_count): + samples_idxs[i] = [*range(len(samples[i]))] + shuffle_idxs[i] = [] + elif self.person_id_mode==3: + persons_count = len(samples) + + person_idxs = [ *range(persons_count) ] + shuffle_person_idxs = [] + + samples_idxs = [None]*persons_count + shuffle_idxs = [None]*persons_count + + for i in range(persons_count): + samples_idxs[i] = [*range(len(samples[i]))] + shuffle_idxs[i] = [] + +if self.person_id_mode==2: + if len(shuffle_person_idxs) == 0: + shuffle_person_idxs = person_idxs.copy() + np.random.shuffle(shuffle_person_idxs) + person_ids = shuffle_person_idxs.pop() + + + batches = None + for n_batch in range(self.batch_size): + + if self.person_id_mode==1: + if len(shuffle_idxs) == 0: + shuffle_idxs = samples_idxs.copy() + np.random.shuffle(shuffle_idxs) ### + + idx = shuffle_idxs.pop() + sample = samples[ idx ] + + try: + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if type(x) != tuple and type(x) != list: + raise Exception('SampleProcessor.process returns NOT tuple/list') + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + + batches += [ [] ] + i_person_id = len(batches)-1 + + for i in range(len(x)): + batches[i].append ( x[i] ) + + batches[i_person_id].append ( np.array([sample.person_id]) ) + + + elif self.person_id_mode==2: + person_id1, person_id2 = person_ids + + if len(shuffle_idxs[person_id1]) == 0: + shuffle_idxs[person_id1] = samples_idxs[person_id1].copy() + np.random.shuffle(shuffle_idxs[person_id1]) + + idx = shuffle_idxs[person_id1].pop() + sample1 = samples[person_id1][idx] + + if len(shuffle_idxs[person_id2]) == 0: + shuffle_idxs[person_id2] = samples_idxs[person_id2].copy() + np.random.shuffle(shuffle_idxs[person_id2]) + + idx = shuffle_idxs[person_id2].pop() + sample2 = samples[person_id2][idx] + + if sample1 is not None and sample2 is not None: + try: + x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) ) + + try: + x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) ) + + x1_len = len(x1) + if batches is None: + batches = [ [] for _ in range(x1_len) ] + batches += [ [] ] + i_person_id1 = len(batches)-1 + + batches += [ [] for _ in range(len(x2)) ] + batches += [ [] ] + i_person_id2 = len(batches)-1 + + for i in range(x1_len): + batches[i].append ( x1[i] ) + + for i in range(len(x2)): + batches[x1_len+1+i].append ( x2[i] ) + + batches[i_person_id1].append ( np.array([sample1.person_id]) ) + + batches[i_person_id2].append ( np.array([sample2.person_id]) ) + + elif self.person_id_mode==3: + if len(shuffle_person_idxs) == 0: + shuffle_person_idxs = person_idxs.copy() + np.random.shuffle(shuffle_person_idxs) + person_id = shuffle_person_idxs.pop() + + if len(shuffle_idxs[person_id]) == 0: + shuffle_idxs[person_id] = samples_idxs[person_id].copy() + np.random.shuffle(shuffle_idxs[person_id]) + + idx = shuffle_idxs[person_id].pop() + sample1 = samples[person_id][idx] + + if len(shuffle_idxs[person_id]) == 0: + shuffle_idxs[person_id] = samples_idxs[person_id].copy() + np.random.shuffle(shuffle_idxs[person_id]) + + idx = shuffle_idxs[person_id].pop() + sample2 = samples[person_id][idx] + + if sample1 is not None and sample2 is not None: + try: + x1, = SampleProcessor.process ([sample1], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample1.filename, traceback.format_exc() ) ) + + try: + x2, = SampleProcessor.process ([sample2], self.sample_process_options, self.output_sample_types, self.debug) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample2.filename, traceback.format_exc() ) ) + + x1_len = len(x1) + if batches is None: + batches = [ [] for _ in range(x1_len) ] + batches += [ [] ] + i_person_id1 = len(batches)-1 + + batches += [ [] for _ in range(len(x2)) ] + batches += [ [] ] + i_person_id2 = len(batches)-1 + + for i in range(x1_len): + batches[i].append ( x1[i] ) + + for i in range(len(x2)): + batches[x1_len+1+i].append ( x2[i] ) + + batches[i_person_id1].append ( np.array([sample1.person_id]) ) + + batches[i_person_id2].append ( np.array([sample2.person_id]) ) +""" diff --git a/samplelib/SampleGeneratorFaceSkinSegDataset - Copy.py b/samplelib/SampleGeneratorFaceSkinSegDataset - Copy.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef530a5b07af92c1d2fa773f7c0249287ec7740 --- /dev/null +++ b/samplelib/SampleGeneratorFaceSkinSegDataset - Copy.py @@ -0,0 +1,260 @@ +import multiprocessing +import pickle +import time +import traceback +from enum import IntEnum + +import cv2 +import numpy as np + +from core import imagelib, mplib, pathex +from core.imagelib import sd +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) + +class MaskType(IntEnum): + none = 0, + cloth = 1, + ear_r = 2, + eye_g = 3, + hair = 4, + hat = 5, + l_brow = 6, + l_ear = 7, + l_eye = 8, + l_lip = 9, + mouth = 10, + neck = 11, + neck_l = 12, + nose = 13, + r_brow = 14, + r_ear = 15, + r_eye = 16, + skin = 17, + u_lip = 18 + + + +MaskType_to_name = { + int(MaskType.none ) : 'none', + int(MaskType.cloth ) : 'cloth', + int(MaskType.ear_r ) : 'ear_r', + int(MaskType.eye_g ) : 'eye_g', + int(MaskType.hair ) : 'hair', + int(MaskType.hat ) : 'hat', + int(MaskType.l_brow) : 'l_brow', + int(MaskType.l_ear ) : 'l_ear', + int(MaskType.l_eye ) : 'l_eye', + int(MaskType.l_lip ) : 'l_lip', + int(MaskType.mouth ) : 'mouth', + int(MaskType.neck ) : 'neck', + int(MaskType.neck_l) : 'neck_l', + int(MaskType.nose ) : 'nose', + int(MaskType.r_brow) : 'r_brow', + int(MaskType.r_ear ) : 'r_ear', + int(MaskType.r_eye ) : 'r_eye', + int(MaskType.skin ) : 'skin', + int(MaskType.u_lip ) : 'u_lip', +} + +MaskType_from_name = { MaskType_to_name[k] : k for k in MaskType_to_name.keys() } + +class SampleGeneratorFaceSkinSegDataset(SampleGeneratorBase): + def __init__ (self, root_path, debug=False, batch_size=1, resolution=256, face_type=None, + generators_count=4, data_format="NHWC", + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + + + aligned_path = root_path /'aligned' + if not aligned_path.exists(): + raise ValueError(f'Unable to find {aligned_path}') + + obstructions_path = root_path / 'obstructions' + + obstructions_images_paths = pathex.get_image_paths(obstructions_path, image_extensions=['.png'], subdirs=True) + + samples = SampleLoader.load (SampleType.FACE, aligned_path, subdirs=True) + self.samples_len = len(samples) + + pickled_samples = pickle.dumps(samples, 4) + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, obstructions_images_paths, resolution, face_type, data_format) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, obstructions_images_paths, resolution, face_type, data_format), start_now=False ) \ + for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + pickled_samples, obstructions_images_paths, resolution, face_type, data_format = param + + samples = pickle.loads(pickled_samples) + + obstructions_images_paths_len = len(obstructions_images_paths) + shuffle_o_idxs = [] + o_idxs = [*range(obstructions_images_paths_len)] + + shuffle_idxs = [] + idxs = [*range(len(samples))] + + random_flip = True + rotation_range=[-10,10] + scale_range=[-0.05, 0.05] + tx_range=[-0.05, 0.05] + ty_range=[-0.05, 0.05] + + o_random_flip = True + o_rotation_range=[-180,180] + o_scale_range=[-0.5, 0.05] + o_tx_range=[-0.5, 0.5] + o_ty_range=[-0.5, 0.5] + + random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75 + motion_blur_chance, motion_blur_mb_max_size = 25, 5 + gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5 + + bs = self.batch_size + while True: + batches = [ [], [] ] + + n_batch = 0 + while n_batch < bs: + try: + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + np.random.shuffle(shuffle_idxs) + + idx = shuffle_idxs.pop() + + sample = samples[idx] + + img = sample.load_bgr() + h,w,c = img.shape + + mask = np.zeros ((h,w,1), dtype=np.float32) + sample.ie_polys.overlay_mask(mask) + + warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range ) + + if face_type == sample.face_type: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), cv2.INTER_LANCZOS4 ) + mask = cv2.resize( mask, (resolution, resolution), cv2.INTER_LANCZOS4 ) + else: + mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + + if len(mask.shape) == 2: + mask = mask[...,None] + + if obstructions_images_paths_len != 0: + # apply obstruction + if len(shuffle_o_idxs) == 0: + shuffle_o_idxs = o_idxs.copy() + np.random.shuffle(shuffle_o_idxs) + o_idx = shuffle_o_idxs.pop() + o_img = cv2_imread (obstructions_images_paths[o_idx]).astype(np.float32) / 255.0 + oh,ow,oc = o_img.shape + if oc == 4: + ohw = max(oh,ow) + scale = resolution / ohw + + #o_img = cv2.resize (o_img, ( int(ow*rate), int(oh*rate), ), cv2.INTER_CUBIC) + + + + + + mat = cv2.getRotationMatrix2D( (ow/2,oh/2), + np.random.uniform( o_rotation_range[0], o_rotation_range[1] ), + 1.0 ) + + mat += np.float32( [[0,0, -ow/2 ], + [0,0, -oh/2 ]]) + mat *= scale * np.random.uniform(1 +o_scale_range[0], 1 +o_scale_range[1]) + mat += np.float32( [[0, 0, resolution/2 + resolution*np.random.uniform( o_tx_range[0], o_tx_range[1] ) ], + [0, 0, resolution/2 + resolution*np.random.uniform( o_ty_range[0], o_ty_range[1] ) ] ]) + + + o_img = cv2.warpAffine( o_img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + + if o_random_flip and np.random.randint(10) < 4: + o_img = o_img[:,::-1,...] + + o_mask = o_img[...,3:4] + o_mask[o_mask>0] = 1.0 + + + o_mask = cv2.erode (o_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)), iterations = 1 ) + o_mask = cv2.GaussianBlur(o_mask, (5, 5) , 0)[...,None] + + img = img*(1-o_mask) + o_img[...,0:3]*o_mask + + o_mask[o_mask<0.5] = 0.0 + + + #import code + #code.interact(local=dict(globals(), **locals())) + mask *= (1-o_mask) + + + #cv2.imshow ("", np.clip(o_img*255, 0,255).astype(np.uint8) ) + #cv2.waitKey(0) + + + img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) + mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) + + + img = np.clip(img.astype(np.float32), 0, 1) + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + mask = np.clip(mask, 0, 1) + + + img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution])) + img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) + + if data_format == "NCHW": + img = np.transpose(img, (2,0,1) ) + mask = np.transpose(mask, (2,0,1) ) + + batches[0].append ( img ) + batches[1].append ( mask ) + + n_batch += 1 + except: + io.log_err ( traceback.format_exc() ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceTemporal.py b/samplelib/SampleGeneratorFaceTemporal.py new file mode 100644 index 0000000000000000000000000000000000000000..213747667df3798cef7c22244e76ac2f9133a026 --- /dev/null +++ b/samplelib/SampleGeneratorFaceTemporal.py @@ -0,0 +1,88 @@ +import multiprocessing +import pickle +import time +import traceback + +import cv2 +import numpy as np + +from core import mplib +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +class SampleGeneratorFaceTemporal(SampleGeneratorBase): + def __init__ (self, samples_path, debug, batch_size, + temporal_image_count=3, + sample_process_options=SampleProcessor.Options(), + output_sample_types=[], + generators_count=2, + **kwargs): + super().__init__(debug, batch_size) + + self.temporal_image_count = temporal_image_count + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = generators_count + + samples = SampleLoader.load (SampleType.FACE_TEMPORAL_SORTED, samples_path) + samples_len = len(samples) + if samples_len == 0: + raise ValueError('No training data provided.') + + mult_max = 1 + l = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) + index_host = mplib.IndexHost(l+1) + + pickled_samples = pickle.dumps(samples, 4) + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (pickled_samples, index_host.create_cli(),) ) for i in range(self.generators_count) ] + + self.generator_counter = -1 + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param): + mult_max = 1 + bs = self.batch_size + pickled_samples, index_host = param + samples = pickle.loads(pickled_samples) + + while True: + batches = None + + indexes = index_host.multi_get(bs) + + for n_batch in range(self.batch_size): + idx = indexes[n_batch] + + temporal_samples = [] + mult = np.random.randint(mult_max)+1 + for i in range( self.temporal_image_count ): + sample = samples[ idx+i*mult ] + try: + temporal_samples += SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)[0] + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(temporal_samples)) ] + + for i in range(len(temporal_samples)): + batches[i].append ( temporal_samples[i] ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceTest.py b/samplelib/SampleGeneratorFaceTest.py new file mode 100644 index 0000000000000000000000000000000000000000..21f9f69ee8fe32dcbde0ce257d898bc756e4e069 --- /dev/null +++ b/samplelib/SampleGeneratorFaceTest.py @@ -0,0 +1,179 @@ +import multiprocessing +import time +import traceback + +import cv2 +import numpy as np + +from core import mplib +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +''' +arg +output_sample_types = [ + [SampleProcessor.TypeFlags, size, (optional) {} opts ] , + ... + ] +''' +class SampleGeneratorFaceTest(SampleGeneratorBase): + def __init__ (self, samples_path, debug=False, batch_size=1, + random_ct_samples_path=None, + sample_process_options=SampleProcessor.Options(), + output_sample_types=[], + uniform_yaw_distribution=False, + generators_count=4, + raise_on_no_data=True, + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + samples = SampleLoader.load (SampleType.FACE, samples_path) + self.samples_len = len(samples) + + if self.samples_len == 0: + if raise_on_no_data: + raise ValueError('No training data provided.') + else: + return + + if uniform_yaw_distribution: + samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ] + + grads = 128 + #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2 + grads_space = np.linspace (-1.2, 1.2,grads) + + yaws_sample_list = [None]*grads + for g in io.progress_bar_generator ( range(grads), "Sort by yaw"): + yaw = grads_space[g] + next_yaw = grads_space[g+1] if g < grads-1 else yaw + + yaw_samples = [] + for idx, pyr in samples_pyr: + s_yaw = -pyr[1] + if (g == 0 and s_yaw < next_yaw) or \ + (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \ + (g == grads-1 and s_yaw >= yaw): + yaw_samples += [ idx ] + if len(yaw_samples) > 0: + yaws_sample_list[g] = yaw_samples + + yaws_sample_list = [ y for y in yaws_sample_list if y is not None ] + + index_host = mplib.Index2DHost( yaws_sample_list ) + else: + index_host = mplib.IndexHost(self.samples_len) + + if random_ct_samples_path is not None: + ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) + ct_index_host = mplib.IndexHost( len(ct_samples) ) + else: + ct_samples = None + ct_index_host = None + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None) )] + else: + self.generators = [] + self.comm_qs = [] + for i in range(self.generators_count): + comm_q = multiprocessing.Queue() + + gen = SubprocessGenerator ( self.batch_func, (comm_q, samples, index_host.create_cli(), ct_samples, ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) + + self.comm_qs.append(comm_q) + self.generators.append(gen) + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def send_start(self): + for comm_q in self.comm_qs: + comm_q.put( ('start', 0) ) + + + def set_face_scale(self, scale): + + for comm_q in self.comm_qs: + comm_q.put( ('face_scale', scale) ) + + + def __iter__(self): + return self + + def __next__(self): + if not self.initialized: + return [] + + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + comm_q, samples, index_host, ct_samples, ct_index_host = param + + bs = self.batch_size + face_scale = 1.0 + + while True: + if not comm_q.empty(): + cmd, param = comm_q.get() + if cmd == 'face_scale': + face_scale = param + + if cmd == 'start': + break + + while True: + + while not comm_q.empty(): + cmd, param = comm_q.get() + if cmd == 'face_scale': + face_scale = param + + batches = None + + indexes = index_host.multi_get(bs) + ct_indexes = ct_index_host.multi_get(bs) if ct_samples is not None else None + + t = time.time() + for n_batch in range(bs): + sample_idx = indexes[n_batch] + sample = samples[sample_idx] + + ct_sample = None + if ct_samples is not None: + ct_sample = ct_samples[ct_indexes[n_batch]] + + try: + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample, face_scale=face_scale) + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + + for i in range(len(x)): + batches[i].append ( x[i] ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorFaceXSeg.py b/samplelib/SampleGeneratorFaceXSeg.py new file mode 100644 index 0000000000000000000000000000000000000000..7e38e64717b69b79f677baa7b72feb19b1048f31 --- /dev/null +++ b/samplelib/SampleGeneratorFaceXSeg.py @@ -0,0 +1,297 @@ +import multiprocessing +import pickle +import time +import traceback +from enum import IntEnum + +import cv2 +import numpy as np +from pathlib import Path +from core import imagelib, mplib, pathex +from core.imagelib import sd +from core.cv2ex import * +from core.interact import interact as io +from core.joblib import Subprocessor, SubprocessGenerator, ThisThreadGenerator +from facelib import LandmarksProcessor +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleType) + +class SampleGeneratorFaceXSeg(SampleGeneratorBase): + def __init__ (self, paths, debug=False, batch_size=1, resolution=256, face_type=None, + generators_count=4, data_format="NHWC", + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + + samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] ) + seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run() + + if len(seg_sample_idxs) == 0: + seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run() + if len(seg_sample_idxs) == 0: + raise Exception(f"No segmented faces found.") + else: + io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.") + else: + io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.") + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = max(1, generators_count) + + args = (samples, seg_sample_idxs, resolution, face_type, data_format) + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, args )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, args, start_now=False ) for i in range(self.generators_count) ] + + SubprocessGenerator.start_in_parallel( self.generators ) + + self.generator_counter = -1 + + self.initialized = True + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + samples, seg_sample_idxs, resolution, face_type, data_format = param + + shuffle_idxs = [] + bg_shuffle_idxs = [] + + random_flip = True + rotation_range=[-10,10] + scale_range=[-0.05, 0.05] + tx_range=[-0.05, 0.05] + ty_range=[-0.05, 0.05] + + random_bilinear_resize_chance, random_bilinear_resize_max_size_per = 25,75 + sharpen_chance, sharpen_kernel_max_size = 25, 5 + motion_blur_chance, motion_blur_mb_max_size = 25, 5 + gaussian_blur_chance, gaussian_blur_kernel_max_size = 25, 5 + random_jpeg_compress_chance = 25 + + def gen_img_mask(sample): + img = sample.load_bgr() + h,w,c = img.shape + + if sample.seg_ie_polys.has_polys(): + mask = np.zeros ((h,w,1), dtype=np.float32) + sample.seg_ie_polys.overlay_mask(mask) + elif sample.has_xseg_mask(): + mask = sample.get_xseg_mask() + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + else: + raise Exception(f'no mask in sample {sample.filename}') + + if face_type == sample.face_type: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 ) + mask = cv2.resize( mask, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4 ) + else: + mat = LandmarksProcessor.get_transform_mat (sample.landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + mask = cv2.warpAffine( mask, mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LANCZOS4 ) + + if len(mask.shape) == 2: + mask = mask[...,None] + return img, mask + + bs = self.batch_size + while True: + batches = [ [], [] ] + + n_batch = 0 + while n_batch < bs: + try: + if len(shuffle_idxs) == 0: + shuffle_idxs = seg_sample_idxs.copy() + np.random.shuffle(shuffle_idxs) + sample = samples[shuffle_idxs.pop()] + img, mask = gen_img_mask(sample) + + if np.random.randint(2) == 0: + if len(bg_shuffle_idxs) == 0: + bg_shuffle_idxs = seg_sample_idxs.copy() + np.random.shuffle(bg_shuffle_idxs) + bg_sample = samples[bg_shuffle_idxs.pop()] + + bg_img, bg_mask = gen_img_mask(bg_sample) + + bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[-0.10, 0.10], tx_range=[-0.10, 0.10], ty_range=[-0.10, 0.10] ) + bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True) + bg_mask = imagelib.warp_by_params (bg_wp, bg_mask, can_warp=False, can_transform=True, can_flip=True, border_replicate=False) + bg_img = bg_img*(1-bg_mask) + if np.random.randint(2) == 0: + bg_img = imagelib.apply_random_hsv_shift(bg_img) + else: + bg_img = imagelib.apply_random_rgb_levels(bg_img) + + c_mask = 1.0 - (1-bg_mask) * (1-mask) + rnd = 0.15 + np.random.uniform()*0.85 + img = img*(c_mask) + img*(1-c_mask)*rnd + bg_img*(1-c_mask)*(1-rnd) + + warp_params = imagelib.gen_warp_params(resolution, random_flip, rotation_range=rotation_range, scale_range=scale_range, tx_range=tx_range, ty_range=ty_range ) + img = imagelib.warp_by_params (warp_params, img, can_warp=True, can_transform=True, can_flip=True, border_replicate=True) + mask = imagelib.warp_by_params (warp_params, mask, can_warp=True, can_transform=True, can_flip=True, border_replicate=False) + + img = np.clip(img.astype(np.float32), 0, 1) + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + mask = np.clip(mask, 0, 1) + + if np.random.randint(2) == 0: + # random face flare + krn = np.random.randint( resolution//4, resolution ) + krn = krn - krn % 2 + 1 + img = img + cv2.GaussianBlur(img*mask, (krn,krn), 0) + + if np.random.randint(2) == 0: + # random bg flare + krn = np.random.randint( resolution//4, resolution ) + krn = krn - krn % 2 + 1 + img = img + cv2.GaussianBlur(img*(1-mask), (krn,krn), 0) + + if np.random.randint(2) == 0: + img = imagelib.apply_random_hsv_shift(img, mask=sd.random_circle_faded ([resolution,resolution])) + else: + img = imagelib.apply_random_rgb_levels(img, mask=sd.random_circle_faded ([resolution,resolution])) + + if np.random.randint(2) == 0: + img = imagelib.apply_random_sharpen( img, sharpen_chance, sharpen_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + else: + img = imagelib.apply_random_motion_blur( img, motion_blur_chance, motion_blur_mb_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + img = imagelib.apply_random_gaussian_blur( img, gaussian_blur_chance, gaussian_blur_kernel_max_size, mask=sd.random_circle_faded ([resolution,resolution])) + + if np.random.randint(2) == 0: + img = imagelib.apply_random_nearest_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) + else: + img = imagelib.apply_random_bilinear_resize( img, random_bilinear_resize_chance, random_bilinear_resize_max_size_per, mask=sd.random_circle_faded ([resolution,resolution])) + img = np.clip(img, 0, 1) + + img = imagelib.apply_random_jpeg_compress( img, random_jpeg_compress_chance, mask=sd.random_circle_faded ([resolution,resolution])) + + if data_format == "NCHW": + img = np.transpose(img, (2,0,1) ) + mask = np.transpose(mask, (2,0,1) ) + + batches[0].append ( img ) + batches[1].append ( mask ) + + n_batch += 1 + except: + io.log_err ( traceback.format_exc() ) + + yield [ np.array(batch) for batch in batches] + +class SegmentedSampleFilterSubprocessor(Subprocessor): + #override + def __init__(self, samples, count_xseg_mask=False ): + self.samples = samples + self.samples_len = len(self.samples) + self.count_xseg_mask = count_xseg_mask + + self.idxs = [*range(self.samples_len)] + self.result = [] + super().__init__('SegmentedSampleFilterSubprocessor', SegmentedSampleFilterSubprocessor.Cli, 60) + + #override + def process_info_generator(self): + for i in range(multiprocessing.cpu_count()): + yield 'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask} + + #override + def on_clients_initialized(self): + io.progress_bar ("Filtering", self.samples_len) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def get_data(self, host_dict): + if len (self.idxs) > 0: + return self.idxs.pop(0) + + return None + + #override + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data) + + #override + def on_result (self, host_dict, data, result): + idx, is_ok = result + if is_ok: + self.result.append(idx) + io.progress_bar_inc(1) + def get_result(self): + return self.result + + class Cli(Subprocessor.Cli): + #overridable optional + def on_initialize(self, client_dict): + self.samples = client_dict['samples'] + self.count_xseg_mask = client_dict['count_xseg_mask'] + + def process_data(self, idx): + if self.count_xseg_mask: + return idx, self.samples[idx].has_xseg_mask() + else: + return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0 + +""" + bg_path = None + for path in paths: + bg_path = Path(path) / 'backgrounds' + if bg_path.exists(): + + break + if bg_path is None: + io.log_info(f'Random backgrounds will not be used. Place no face jpg images to aligned\backgrounds folder. ') + bg_pathes = None + else: + bg_pathes = pathex.get_image_paths(bg_path, image_extensions=['.jpg'], return_Path_class=True) + io.log_info(f'Using {len(bg_pathes)} random backgrounds from {bg_path}') + +if bg_pathes is not None: + bg_path = bg_pathes[ np.random.randint(len(bg_pathes)) ] + + bg_img = cv2_imread(bg_path) + if bg_img is not None: + bg_img = bg_img.astype(np.float32) / 255.0 + bg_img = imagelib.normalize_channels(bg_img, 3) + + bg_img = imagelib.random_crop(bg_img, resolution, resolution) + bg_img = cv2.resize(bg_img, (resolution, resolution), interpolation=cv2.INTER_LINEAR) + + if np.random.randint(2) == 0: + bg_img = imagelib.apply_random_hsv_shift(bg_img) + else: + bg_img = imagelib.apply_random_rgb_levels(bg_img) + + bg_wp = imagelib.gen_warp_params(resolution, True, rotation_range=[-180,180], scale_range=[0,0], tx_range=[0,0], ty_range=[0,0]) + bg_img = imagelib.warp_by_params (bg_wp, bg_img, can_warp=False, can_transform=True, can_flip=True, border_replicate=True) + + bg = img*(1-mask) + fg = img*mask + + c_mask = sd.random_circle_faded ([resolution,resolution]) + bg = ( bg_img*c_mask + bg*(1-c_mask) )*(1-mask) + + img = fg+bg + + else: +""" \ No newline at end of file diff --git a/samplelib/SampleGeneratorImage.py b/samplelib/SampleGeneratorImage.py new file mode 100644 index 0000000000000000000000000000000000000000..6e4df392f80a933347ac7612f5359a3ada12380e --- /dev/null +++ b/samplelib/SampleGeneratorImage.py @@ -0,0 +1,66 @@ +import traceback + +import cv2 +import numpy as np + +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +class SampleGeneratorImage(SampleGeneratorBase): + def __init__ (self, samples_path, debug, batch_size, sample_process_options=SampleProcessor.Options(), output_sample_types=[], raise_on_no_data=True, **kwargs): + super().__init__(debug, batch_size) + self.initialized = False + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + + samples = SampleLoader.load (SampleType.IMAGE, samples_path) + + if len(samples) == 0: + if raise_on_no_data: + raise ValueError('No training data provided.') + return + + self.generators = [ThisThreadGenerator ( self.batch_func, samples )] if self.debug else \ + [SubprocessGenerator ( self.batch_func, samples )] + + self.generator_counter = -1 + self.initialized = True + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, samples): + samples_len = len(samples) + + + idxs = [ *range(samples_len) ] + shuffle_idxs = [] + + while True: + + batches = None + for n_batch in range(self.batch_size): + + if len(shuffle_idxs) == 0: + shuffle_idxs = idxs.copy() + np.random.shuffle (shuffle_idxs) + + idx = shuffle_idxs.pop() + sample = samples[idx] + + x, = SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug) + + if batches is None: + batches = [ [] for _ in range(len(x)) ] + + for i in range(len(x)): + batches[i].append ( x[i] ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorImageTemporal.py b/samplelib/SampleGeneratorImageTemporal.py new file mode 100644 index 0000000000000000000000000000000000000000..ed3ab4ad1aa85c8d6b76bc37ee3392f021eadfa5 --- /dev/null +++ b/samplelib/SampleGeneratorImageTemporal.py @@ -0,0 +1,81 @@ +import traceback + +import cv2 +import numpy as np + +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + + +''' +output_sample_types = [ + [SampleProcessor.TypeFlags, size, (optional)random_sub_size] , + ... + ] +''' +class SampleGeneratorImageTemporal(SampleGeneratorBase): + def __init__ (self, samples_path, debug, batch_size, temporal_image_count, sample_process_options=SampleProcessor.Options(), output_sample_types=[], **kwargs): + super().__init__(debug, batch_size) + + self.temporal_image_count = temporal_image_count + self.sample_process_options = sample_process_options + self.output_sample_types = output_sample_types + + self.samples = SampleLoader.load (SampleType.IMAGE, samples_path) + + self.generator_samples = [ self.samples ] + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )] if self.debug else \ + [iter_utils.SubprocessGenerator ( self.batch_func, 0 )] + + self.generator_counter = -1 + + def __iter__(self): + return self + + def __next__(self): + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, generator_id): + samples = self.generator_samples[generator_id] + samples_len = len(samples) + if samples_len == 0: + raise ValueError('No training data provided.') + + mult_max = 4 + samples_sub_len = samples_len - ( (self.temporal_image_count)*mult_max - (mult_max-1) ) + + if samples_sub_len <= 0: + raise ValueError('Not enough samples to fit temporal line.') + + shuffle_idxs = [] + + while True: + + batches = None + for n_batch in range(self.batch_size): + + if len(shuffle_idxs) == 0: + shuffle_idxs = [ *range(samples_sub_len) ] + np.random.shuffle (shuffle_idxs) + + idx = shuffle_idxs.pop() + + temporal_samples = [] + mult = np.random.randint(mult_max)+1 + for i in range( self.temporal_image_count ): + sample = samples[ idx+i*mult ] + try: + temporal_samples += SampleProcessor.process ([sample], self.sample_process_options, self.output_sample_types, self.debug)[0] + except: + raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) ) + + if batches is None: + batches = [ [] for _ in range(len(temporal_samples)) ] + + for i in range(len(temporal_samples)): + batches[i].append ( temporal_samples[i] ) + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleGeneratorSAE.py b/samplelib/SampleGeneratorSAE.py new file mode 100644 index 0000000000000000000000000000000000000000..867ff812860ff813ce3220b95867a10f44e6b208 --- /dev/null +++ b/samplelib/SampleGeneratorSAE.py @@ -0,0 +1,297 @@ +import multiprocessing +import time +import traceback + +import cv2 +import numpy as np +import numpy.linalg as npla + +from core import mplib +from core import imagelib +from core.interact import interact as io +from core.joblib import SubprocessGenerator, ThisThreadGenerator +from core import mathlib +from facelib import LandmarksProcessor, FaceType +from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, + SampleType) + +class SampleGeneratorSAE(SampleGeneratorBase): + def __init__ (self, src_samples_path, dst_samples_path, + resolution, + face_type, + random_src_flip=False, + random_dst_flip=False, + ct_mode=None, + uniform_yaw_distribution=False, + data_format='NHWC', + debug=False, batch_size=1, + raise_on_no_data=True, + **kwargs): + + super().__init__(debug, batch_size) + self.initialized = False + self.resolution = resolution + self.face_type = face_type + self.random_src_flip = random_src_flip + self.random_dst_flip = random_dst_flip + self.ct_mode = ct_mode + self.data_format = data_format + + if self.debug: + self.generators_count = 1 + else: + self.generators_count = 8 + + src_samples = SampleLoader.load (SampleType.FACE, src_samples_path) + src_samples_len = len(src_samples) + + if src_samples_len == 0: + raise ValueError(f'No samples in {src_samples_path}') + + dst_samples = SampleLoader.load (SampleType.FACE, dst_samples_path) + dst_samples_len = len(dst_samples) + + if dst_samples_len == 0: + raise ValueError(f'No samples in {dst_samples_path}') + + if uniform_yaw_distribution: + src_index_host = self._filter_uniform_yaw(src_samples) + dst_index_host = self._filter_uniform_yaw(dst_samples) + else: + src_index_host = mplib.IndexHost(src_samples_len) + dst_index_host = mplib.IndexHost(dst_samples_len) + + ct_index_host = mplib.IndexHost(dst_samples_len) if ct_mode is not None else None + + self.comm_qs = [ multiprocessing.Queue() for i in range(self.generators_count) ] + + if self.debug: + self.generators = [ThisThreadGenerator ( self.batch_func, (self.comm_qs[0], src_samples, dst_samples, src_index_host.create_cli(), dst_index_host.create_cli(), ct_index_host.create_cli() if ct_index_host is not None else None) )] + else: + self.generators = [SubprocessGenerator ( self.batch_func, (self.comm_qs[i], src_samples, dst_samples, src_index_host.create_cli(), dst_index_host.create_cli(), ct_index_host.create_cli() if ct_index_host is not None else None), start_now=False ) \ + for i in range(self.generators_count) ] + + self.generator_counter = -1 + + self.initialized = True + + def start(self): + if not self.debug: + SubprocessGenerator.start_in_parallel( self.generators ) + + def _filter_uniform_yaw(self, samples): + samples_pyr = [ ( idx, sample.get_pitch_yaw_roll() ) for idx, sample in enumerate(samples) ] + + grads = 128 + #instead of math.pi / 2, using -1.2,+1.2 because actually maximum yaw for 2DFAN landmarks are -1.2+1.2 + grads_space = np.linspace (-1.2, 1.2,grads) + + yaws_sample_list = [None]*grads + for g in io.progress_bar_generator ( range(grads), "Sort by yaw"): + yaw = grads_space[g] + next_yaw = grads_space[g+1] if g < grads-1 else yaw + + yaw_samples = [] + for idx, pyr in samples_pyr: + s_yaw = -pyr[1] + if (g == 0 and s_yaw < next_yaw) or \ + (g < grads-1 and s_yaw >= yaw and s_yaw < next_yaw) or \ + (g == grads-1 and s_yaw >= yaw): + yaw_samples += [ idx ] + if len(yaw_samples) > 0: + yaws_sample_list[g] = yaw_samples + + yaws_sample_list = [ y for y in yaws_sample_list if y is not None ] + + return mplib.Index2DHost( yaws_sample_list ) + + def set_face_scale(self, scale): + for comm_q in self.comm_qs: + comm_q.put( ('face_scale', scale) ) + + + #overridable + def is_initialized(self): + return self.initialized + + def __iter__(self): + return self + + def __next__(self): + if not self.initialized: + return [] + + self.generator_counter += 1 + generator = self.generators[self.generator_counter % len(self.generators) ] + return next(generator) + + def batch_func(self, param ): + comm_q, src_samples, dst_samples, src_index_host, dst_index_host, ct_index_host = param + + batch_size = self.batch_size + resolution = self.resolution + face_type = self.face_type + data_format = self.data_format + random_src_flip = self.random_src_flip + random_dst_flip = self.random_dst_flip + ct_mode = self.ct_mode + + rotation_range=[-10,10] + scale_range=[-0.05, 0.05] + tx_range=[-0.05, 0.05] + ty_range=[-0.05, 0.05] + rnd_state = np.random + + face_scale = 1.0 + + hi_res = 1024 + + def gen_sample(sample, target_face_type, resolution, allow_flip=False, scale=1.0, ct_mode=None, ct_sample=None):#:, tx, ty, rotation, scale): + tx = rnd_state.uniform( tx_range[0], tx_range[1] ) + ty = rnd_state.uniform( ty_range[0], ty_range[1] ) + rotation = rnd_state.uniform( rotation_range[0], rotation_range[1] ) + scale = rnd_state.uniform(scale +scale_range[0], scale +scale_range[1]) + + flip = allow_flip and rnd_state.randint(10) < 4 + + face_type = sample.face_type + face_lmrks = sample.landmarks + face = sample.load_bgr() + h,w,c = face.shape + + if face_type == FaceType.HEAD: + hi_mat = LandmarksProcessor.get_transform_mat (face_lmrks, hi_res, FaceType.HEAD) + else: + hi_mat = LandmarksProcessor.get_transform_mat (face_lmrks, hi_res, FaceType.HEAD_FACE) + + hi_lmrks = LandmarksProcessor.transform_points(face_lmrks, hi_mat) + hi_warp_params = imagelib.gen_warp_params(hi_res) + face_warp_params = imagelib.gen_warp_params(resolution) + + hi_to_target_mat = LandmarksProcessor.get_transform_mat (hi_lmrks, resolution, target_face_type) + hi_to_target_mat = mathlib.transform_mat(hi_to_target_mat, resolution, tx, ty, rotation, scale) + + face_to_target_mat = LandmarksProcessor.get_transform_mat (face_lmrks, resolution, target_face_type) + face_to_target_mat = mathlib.transform_mat(face_to_target_mat, resolution, tx, ty, rotation, scale) + + warped_face = face + if ct_mode is not None: + ct_bgr = ct_sample.load_bgr() + ct_bgr = cv2.resize(ct_bgr, (w,h), interpolation=cv2.INTER_LINEAR ) + warped_face = imagelib.color_transfer (ct_mode, warped_face, ct_bgr) + + warped_face = cv2.warpAffine(warped_face, hi_mat, (hi_res,hi_res), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + warped_face = np.clip( imagelib.warp_by_params (hi_warp_params, warped_face, can_warp=True, can_transform=False, can_flip=False, border_replicate=cv2.BORDER_REPLICATE), 0, 1) + warped_face = cv2.warpAffine(warped_face, hi_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + + """ + if face_type != target_face_type: + ... + else: + if w != resolution: + face = cv2.resize(face, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) + """ + + # warped_face = cv2.warpAffine(warped_face, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + # warped_face = np.clip( imagelib.warp_by_params (face_warp_params, warped_face, can_warp=True, can_transform=False, can_flip=False, border_replicate=cv2.BORDER_REPLICATE), 0, 1) + + target_face = face + if ct_mode is not None: + target_face = imagelib.color_transfer (ct_mode, target_face, ct_bgr) + + target_face = cv2.warpAffine(target_face, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC ) + + + face_mask = sample.get_xseg_mask() + if face_mask is not None: + if face_mask.shape[0] != h or face_mask.shape[1] != w: + face_mask = cv2.resize(face_mask, (w,h), interpolation=cv2.INTER_CUBIC) + face_mask = imagelib.normalize_channels(face_mask, 1) + else: + face_mask = LandmarksProcessor.get_image_hull_mask (face.shape, face_lmrks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) + face_mask = np.clip(face_mask, 0, 1) + + target_face_mask = cv2.warpAffine(face_mask, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR ) + target_face_mask = imagelib.normalize_channels(target_face_mask, 1) + target_face_mask = np.clip(target_face_mask, 0, 1) + + em_mask = np.clip(LandmarksProcessor.get_image_eye_mask (face.shape, face_lmrks) + \ + LandmarksProcessor.get_image_mouth_mask (face.shape, face_lmrks), 0, 1) + + target_face_em = cv2.warpAffine(em_mask, face_to_target_mat, (resolution,resolution), borderMode=cv2.BORDER_CONSTANT, flags=cv2.INTER_LINEAR ) + target_face_em = imagelib.normalize_channels(target_face_em, 1) + + div = target_face_em.max() + if div != 0.0: + target_face_em = target_face_em / div + + target_face_em = target_face_em * target_face_mask + + # while True: + # cv2.imshow('', warped_face) + # cv2.waitKey(0) + + # cv2.imshow('', target_face) + # cv2.waitKey(0) + + # cv2.imshow('', target_face_mask) + # cv2.waitKey(0) + + # cv2.imshow('', target_face_em) + # cv2.waitKey(0) + # import code + # code.interact(local=dict(globals(), **locals())) + + if flip: + warped_face = warped_face[:,::-1,...] + target_face = target_face[:,::-1,...] + target_face_mask = target_face_mask[:,::-1,...] + target_face_em = target_face_em[:,::-1,...] + + return warped_face, target_face, target_face_mask, target_face_em + + + while True: + while not comm_q.empty(): + cmd, param = comm_q.get() + if cmd == 'face_scale': + face_scale = param + + batches = [ [], [], [], [], [], [] ,[] ,[] ] # + + src_indexes = src_index_host.multi_get(batch_size) + dst_indexes = dst_index_host.multi_get(batch_size) + + for n_batch in range(batch_size): + src_sample = src_samples[src_indexes[n_batch]] + dst_sample = dst_samples[dst_indexes[n_batch]] + + src_warped_face, src_target_face, src_target_face_mask, src_target_face_em = \ + gen_sample(src_sample, face_type, resolution, allow_flip=random_src_flip, scale=face_scale, ct_mode=ct_mode, ct_sample=dst_sample) + + dst_warped_face, dst_target_face, dst_target_face_mask, dst_target_face_em = \ + gen_sample(dst_sample, face_type, resolution, allow_flip=random_dst_flip, scale=face_scale) + + + + if data_format == "NCHW": + src_warped_face = np.transpose(src_warped_face, (2,0,1) ) + src_target_face = np.transpose(src_target_face, (2,0,1) ) + src_target_face_mask = np.transpose(src_target_face_mask, (2,0,1) ) + src_target_face_em = np.transpose(src_target_face_em, (2,0,1) ) + dst_warped_face = np.transpose(dst_warped_face, (2,0,1) ) + dst_target_face = np.transpose(dst_target_face, (2,0,1) ) + dst_target_face_mask = np.transpose(dst_target_face_mask, (2,0,1) ) + dst_target_face_em = np.transpose(dst_target_face_em, (2,0,1) ) + + batches[0].append(src_warped_face) + batches[1].append(src_target_face) + batches[2].append(src_target_face_mask) + batches[3].append(src_target_face_em) + batches[4].append(dst_warped_face) + batches[5].append(dst_target_face) + batches[6].append(dst_target_face_mask) + batches[7].append(dst_target_face_em) + + + yield [ np.array(batch) for batch in batches] diff --git a/samplelib/SampleLoader.py b/samplelib/SampleLoader.py new file mode 100644 index 0000000000000000000000000000000000000000..298935453f4037843b8790536abfac078c3a212f --- /dev/null +++ b/samplelib/SampleLoader.py @@ -0,0 +1,175 @@ +import multiprocessing +import operator +import pickle +import traceback +from pathlib import Path + +import samplelib.PackedFaceset +from core import pathex +from core.mplib import MPSharedList +from core.interact import interact as io +from core.joblib import Subprocessor +from DFLIMG import * +from facelib import FaceType, LandmarksProcessor + +from .Sample import Sample, SampleType + + +class SampleLoader: + samples_cache = dict() + @staticmethod + def get_person_id_max_count(samples_path): + samples = None + try: + samples = samplelib.PackedFaceset.load(samples_path) + except: + io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_path)}, {traceback.format_exc()}") + + if samples is None: + raise ValueError("packed faceset not found.") + persons_name_idxs = {} + for sample in samples: + persons_name_idxs[sample.person_name] = 0 + return len(list(persons_name_idxs.keys())) + + @staticmethod + def load(sample_type, samples_path, subdirs=False): + """ + Return MPSharedList of samples + """ + samples_cache = SampleLoader.samples_cache + + if str(samples_path) not in samples_cache.keys(): + samples_cache[str(samples_path)] = [None]*SampleType.QTY + + samples = samples_cache[str(samples_path)] + + if sample_type == SampleType.IMAGE: + if samples[sample_type] is None: + samples[sample_type] = [ Sample(filename=filename) for filename in io.progress_bar_generator( pathex.get_image_paths(samples_path, subdirs=subdirs), "Loading") ] + + elif sample_type == SampleType.FACE: + if samples[sample_type] is None: + try: + result = samplelib.PackedFaceset.load(samples_path) + except: + io.log_err(f"Error occured while loading samplelib.PackedFaceset.load {str(samples_dat_path)}, {traceback.format_exc()}") + + if result is not None: + io.log_info (f"Loaded {len(result)} packed faces from {samples_path}") + + if result is None: + result = SampleLoader.load_face_samples( pathex.get_image_paths(samples_path, subdirs=subdirs) ) + + samples[sample_type] = MPSharedList(result) + elif sample_type == SampleType.FACE_TEMPORAL_SORTED: + result = SampleLoader.load (SampleType.FACE, samples_path) + result = SampleLoader.upgradeToFaceTemporalSortedSamples(result) + samples[sample_type] = MPSharedList(result) + + return samples[sample_type] + + @staticmethod + def load_face_samples ( image_paths): + result = FaceSamplesLoaderSubprocessor(image_paths).run() + sample_list = [] + + for filename, data in result: + if data is None: + continue + ( face_type, + shape, + landmarks, + seg_ie_polys, + xseg_mask_compressed, + eyebrows_expand_mod, + source_filename ) = data + + sample_list.append( Sample(filename=filename, + sample_type=SampleType.FACE, + face_type=FaceType.fromString (face_type), + shape=shape, + landmarks=landmarks, + seg_ie_polys=seg_ie_polys, + xseg_mask_compressed=xseg_mask_compressed, + eyebrows_expand_mod=eyebrows_expand_mod, + source_filename=source_filename, + )) + return sample_list + + @staticmethod + def upgradeToFaceTemporalSortedSamples( samples ): + new_s = [ (s, s.source_filename) for s in samples] + new_s = sorted(new_s, key=operator.itemgetter(1)) + + return [ s[0] for s in new_s] + + +class FaceSamplesLoaderSubprocessor(Subprocessor): + #override + def __init__(self, image_paths ): + self.image_paths = image_paths + self.image_paths_len = len(image_paths) + self.idxs = [*range(self.image_paths_len)] + self.result = [None]*self.image_paths_len + super().__init__('FaceSamplesLoader', FaceSamplesLoaderSubprocessor.Cli, 60) + + #override + def on_clients_initialized(self): + io.progress_bar ("Loading samples", len (self.image_paths)) + + #override + def on_clients_finalized(self): + io.progress_bar_close() + + #override + def process_info_generator(self): + for i in range(min(multiprocessing.cpu_count(), 8) ): + yield 'CPU%d' % (i), {}, {} + + #override + def get_data(self, host_dict): + if len (self.idxs) > 0: + idx = self.idxs.pop(0) + return idx, self.image_paths[idx] + + return None + + #override + def on_data_return (self, host_dict, data): + self.idxs.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + idx, dflimg = result + self.result[idx] = (self.image_paths[idx], dflimg) + io.progress_bar_inc(1) + + #override + def get_result(self): + return self.result + + class Cli(Subprocessor.Cli): + #override + def process_data(self, data): + idx, filename = data + dflimg = DFLIMG.load (Path(filename)) + + if dflimg is None or not dflimg.has_data(): + self.log_err (f"FaceSamplesLoader: {filename} is not a dfl image file.") + data = None + else: + data = (dflimg.get_face_type(), + dflimg.get_shape(), + dflimg.get_landmarks(), + dflimg.get_seg_ie_polys(), + dflimg.get_xseg_mask_compressed(), + dflimg.get_eyebrows_expand_mod(), + dflimg.get_source_filename() ) + + return idx, data + + #override + def get_data_name (self, data): + #return string identificator of your data + return data[1] diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py new file mode 100644 index 0000000000000000000000000000000000000000..7432e75d547de3ef37967e3249d94e3949cba88b --- /dev/null +++ b/samplelib/SampleProcessor.py @@ -0,0 +1,258 @@ +import collections +import math +from enum import IntEnum + +import cv2 +import numpy as np + +from core import imagelib +from core.cv2ex import * +from core.imagelib import sd +from facelib import FaceType, LandmarksProcessor + + +class SampleProcessor(object): + class SampleType(IntEnum): + NONE = 0 + IMAGE = 1 + FACE_IMAGE = 2 + FACE_MASK = 3 + LANDMARKS_ARRAY = 4 + PITCH_YAW_ROLL = 5 + PITCH_YAW_ROLL_SIGMOID = 6 + + class ChannelType(IntEnum): + NONE = 0 + BGR = 1 #BGR + G = 2 #Grayscale + GGG = 3 #3xGrayscale + + class FaceMaskType(IntEnum): + NONE = 0 + FULL_FACE = 1 # mask all hull as grayscale + EYES = 2 # mask eyes hull as grayscale + EYES_MOUTH = 3 # eyes and mouse + + class Options(object): + def __init__(self, random_flip = True, rotation_range=[-10,10], scale_range=[-0.05, 0.05], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05] ): + self.random_flip = random_flip + self.rotation_range = rotation_range + self.scale_range = scale_range + self.tx_range = tx_range + self.ty_range = ty_range + + @staticmethod + def process (samples, sample_process_options, output_sample_types, debug, ct_sample=None): + SPST = SampleProcessor.SampleType + SPCT = SampleProcessor.ChannelType + SPFMT = SampleProcessor.FaceMaskType + + + outputs = [] + for sample in samples: + sample_rnd_seed = np.random.randint(0x80000000) + + sample_face_type = sample.face_type + sample_bgr = sample.load_bgr() + sample_landmarks = sample.landmarks + ct_sample_bgr = None + h,w,c = sample_bgr.shape + + def get_full_face_mask(): + xseg_mask = sample.get_xseg_mask() + if xseg_mask is not None: + if xseg_mask.shape[0] != h or xseg_mask.shape[1] != w: + xseg_mask = cv2.resize(xseg_mask, (w,h), interpolation=cv2.INTER_CUBIC) + xseg_mask = imagelib.normalize_channels(xseg_mask, 1) + return np.clip(xseg_mask, 0, 1) + else: + full_face_mask = LandmarksProcessor.get_image_hull_mask (sample_bgr.shape, sample_landmarks, eyebrows_expand_mod=sample.eyebrows_expand_mod ) + return np.clip(full_face_mask, 0, 1) + + def get_eyes_mask(): + eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks) + return np.clip(eyes_mask, 0, 1) + + def get_eyes_mouth_mask(): + eyes_mask = LandmarksProcessor.get_image_eye_mask (sample_bgr.shape, sample_landmarks) + mouth_mask = LandmarksProcessor.get_image_mouth_mask (sample_bgr.shape, sample_landmarks) + mask = eyes_mask + mouth_mask + return np.clip(mask, 0, 1) + + is_face_sample = sample_landmarks is not None + + if debug and is_face_sample: + LandmarksProcessor.draw_landmarks (sample_bgr, sample_landmarks, (0, 1, 0)) + + outputs_sample = [] + for opts in output_sample_types: + resolution = opts.get('resolution', 0) + sample_type = opts.get('sample_type', SPST.NONE) + channel_type = opts.get('channel_type', SPCT.NONE) + nearest_resize_to = opts.get('nearest_resize_to', None) + warp = opts.get('warp', False) + transform = opts.get('transform', False) + random_hsv_shift_amount = opts.get('random_hsv_shift_amount', 0) + normalize_tanh = opts.get('normalize_tanh', False) + ct_mode = opts.get('ct_mode', None) + data_format = opts.get('data_format', 'NHWC') + + rnd_seed_shift = opts.get('rnd_seed_shift', 0) + warp_rnd_seed_shift = opts.get('warp_rnd_seed_shift', rnd_seed_shift) + + rnd_state = np.random.RandomState (sample_rnd_seed+rnd_seed_shift) + warp_rnd_state = np.random.RandomState (sample_rnd_seed+warp_rnd_seed_shift) + + warp_params = imagelib.gen_warp_params(resolution, + sample_process_options.random_flip, + rotation_range=sample_process_options.rotation_range, + scale_range=sample_process_options.scale_range, + tx_range=sample_process_options.tx_range, + ty_range=sample_process_options.ty_range, + rnd_state=rnd_state, + warp_rnd_state=warp_rnd_state, + ) + + if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE: + border_replicate = False + elif sample_type == SPST.FACE_IMAGE: + border_replicate = True + + + border_replicate = opts.get('border_replicate', border_replicate) + borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT + + + if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: + if not is_face_sample: + raise ValueError("face_samples should be provided for sample_type FACE_*") + + if sample_type == SPST.FACE_IMAGE or sample_type == SPST.FACE_MASK: + face_type = opts.get('face_type', None) + face_mask_type = opts.get('face_mask_type', SPFMT.NONE) + + if face_type is None: + raise ValueError("face_type must be defined for face samples") + + if sample_type == SPST.FACE_MASK: + if face_mask_type == SPFMT.FULL_FACE: + img = get_full_face_mask() + elif face_mask_type == SPFMT.EYES: + img = get_eyes_mask() + elif face_mask_type == SPFMT.EYES_MOUTH: + mask = get_full_face_mask().copy() + mask[mask != 0.0] = 1.0 + img = get_eyes_mouth_mask()*mask + else: + img = np.zeros ( sample_bgr.shape[0:2]+(1,), dtype=np.float32) + + if sample_face_type == FaceType.MARK_ONLY: + raise NotImplementedError() + mat = LandmarksProcessor.get_transform_mat (sample_landmarks, warp_resolution, face_type) + img = cv2.warpAffine( img, mat, (warp_resolution, warp_resolution), flags=cv2.INTER_LINEAR ) + + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) + img = cv2.resize( img, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) + else: + if face_type != sample_face_type: + mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_LINEAR ) + else: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_LINEAR ) + + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate, cv2_inter=cv2.INTER_LINEAR) + + if face_mask_type == SPFMT.EYES_MOUTH: + div = img.max() + if div != 0.0: + img = img / div # normalize to 1.0 after warp + + if len(img.shape) == 2: + img = img[...,None] + + if channel_type == SPCT.G: + out_sample = img.astype(np.float32) + else: + raise ValueError("only channel_type.G supported for the mask") + + elif sample_type == SPST.FACE_IMAGE: + img = sample_bgr + + if face_type != sample_face_type: + mat = LandmarksProcessor.get_transform_mat (sample_landmarks, resolution, face_type) + img = cv2.warpAffine( img, mat, (resolution,resolution), borderMode=borderMode, flags=cv2.INTER_CUBIC ) + else: + if w != resolution: + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) + + # Apply random color transfer + if ct_mode is not None and ct_sample is not None: + if ct_sample_bgr is None: + ct_sample_bgr = ct_sample.load_bgr() + img = imagelib.color_transfer (ct_mode, img, cv2.resize( ct_sample_bgr, (resolution,resolution), interpolation=cv2.INTER_LINEAR ) ) + + if random_hsv_shift_amount != 0: + a = random_hsv_shift_amount + h_amount = max(1, int(360*a*0.5)) + img_h, img_s, img_v = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV)) + img_h = (img_h + rnd_state.randint(-h_amount, h_amount+1) ) % 360 + img_s = np.clip (img_s + (rnd_state.random()-0.5)*a, 0, 1 ) + img_v = np.clip (img_v + (rnd_state.random()-0.5)*a, 0, 1 ) + img = np.clip( cv2.cvtColor(cv2.merge([img_h, img_s, img_v]), cv2.COLOR_HSV2BGR) , 0, 1 ) + + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=border_replicate) + + img = np.clip(img.astype(np.float32), 0, 1) + + # Transform from BGR to desired channel_type + if channel_type == SPCT.BGR: + out_sample = img + elif channel_type == SPCT.G: + out_sample = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[...,None] + elif channel_type == SPCT.GGG: + out_sample = np.repeat ( np.expand_dims(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY),-1), (3,), -1) + + # Final transformations + if nearest_resize_to is not None: + out_sample = cv2_resize(out_sample, (nearest_resize_to,nearest_resize_to), interpolation=cv2.INTER_NEAREST) + + if not debug: + if normalize_tanh: + out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0) + if data_format == "NCHW": + out_sample = np.transpose(out_sample, (2,0,1) ) + elif sample_type == SPST.IMAGE: + img = sample_bgr + img = imagelib.warp_by_params (warp_params, img, warp, transform, can_flip=True, border_replicate=True) + img = cv2.resize( img, (resolution, resolution), interpolation=cv2.INTER_CUBIC ) + out_sample = img + + if data_format == "NCHW": + out_sample = np.transpose(out_sample, (2,0,1) ) + + + elif sample_type == SPST.LANDMARKS_ARRAY: + l = sample_landmarks + l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 ) + l = np.clip(l, 0.0, 1.0) + out_sample = l + elif sample_type == SPST.PITCH_YAW_ROLL or sample_type == SPST.PITCH_YAW_ROLL_SIGMOID: + pitch,yaw,roll = sample.get_pitch_yaw_roll() + if warp_params['flip']: + yaw = -yaw + + if sample_type == SPST.PITCH_YAW_ROLL_SIGMOID: + pitch = np.clip( (pitch / math.pi) / 2.0 + 0.5, 0, 1) + yaw = np.clip( (yaw / math.pi) / 2.0 + 0.5, 0, 1) + roll = np.clip( (roll / math.pi) / 2.0 + 0.5, 0, 1) + + out_sample = (pitch, yaw) + else: + raise ValueError ('expected sample_type') + + outputs_sample.append ( out_sample ) + outputs += [outputs_sample] + + return outputs + diff --git a/samplelib/__init__.py b/samplelib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c140ffc29d6df371cf8b703eda30136c6cc88d3 --- /dev/null +++ b/samplelib/__init__.py @@ -0,0 +1,13 @@ +from .Sample import Sample +from .Sample import SampleType +from .SampleLoader import SampleLoader +from .SampleProcessor import SampleProcessor +from .SampleGeneratorBase import SampleGeneratorBase +from .SampleGeneratorFace import SampleGeneratorFace +from .SampleGeneratorFacePerson import SampleGeneratorFacePerson +from .SampleGeneratorFaceTemporal import SampleGeneratorFaceTemporal +from .SampleGeneratorImage import SampleGeneratorImage +from .SampleGeneratorImageTemporal import SampleGeneratorImageTemporal +from .SampleGeneratorFaceCelebAMaskHQ import SampleGeneratorFaceCelebAMaskHQ +from .SampleGeneratorFaceXSeg import SampleGeneratorFaceXSeg +from .PackedFaceset import PackedFaceset \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..168feab1975426144c42461ce9f427a36744131a --- /dev/null +++ b/test.py @@ -0,0 +1,13066 @@ +""" +import os +os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" +os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_gtx_1060_6gb.0" +import keras +KL = keras.layers + +x = KL.Input ( (128,128,64) ) +label = KL.Input( (1,), dtype="int32") +y = x[:,:,:, label[0,0] ] + +import code +code.interact(local=dict(globals(), **locals())) +""" + +# import os +# os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" +# os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_gtx_1060_6gb.0" +# import keras +# K = keras.backend +# import numpy as np + +# shape = (64, 64, 3) +# def encflow(x): +# x = keras.layers.Conv2D(128, 5, strides=2, padding="same")(x) +# x = keras.layers.Conv2D(256, 5, strides=2, padding="same")(x) +# x = keras.layers.Dense(3)(keras.layers.Flatten()(x)) +# return x + +# def modelify(model_functor): +# def func(tensor): +# return keras.models.Model (tensor, model_functor(tensor)) +# return func + +# encoder = modelify (encflow)( keras.Input(shape) ) + +# inp = x = keras.Input(shape) +# code_t = encoder(x) +# loss = K.mean(code_t) + +# train_func = K.function ([inp],[loss], keras.optimizers.Adam().get_updates(loss, encoder.trainable_weights) ) +# train_func ([ np.zeros ( (1, 64, 64, 3) ) ]) + +# import code +# code.interact(local=dict(globals(), **locals())) + +########################## +""" +import os +os.environ['TF_CUDNN_WORKSPACE_LIMIT_IN_MB'] = '1024' +#os.environ['TF_CUDNN_USE_AUTOTUNE'] = '0' +import numpy as np +import tensorflow as tf +keras = tf.keras +KL = keras.layers +K = keras.backend + +bgr_shape = (128, 128, 3) +batch_size = 80#132 #max -tf.1.11.0-cuda 9 +#batch_size = 86 #max -tf.1.13.1-cuda 10 + +class PixelShuffler(keras.layers.Layer): + def __init__(self, size=(2, 2), data_format=None, **kwargs): + super(PixelShuffler, self).__init__(**kwargs) + self.size = size + + def call(self, inputs): + + input_shape = K.int_shape(inputs) + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) + + + batch_size, h, w, c = input_shape + if batch_size is None: + batch_size = -1 + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = K.reshape(inputs, (batch_size, h, w, rh, rw, oc)) + out = K.permute_dimensions(out, (0, 1, 3, 2, 4, 5)) + out = K.reshape(out, (batch_size, oh, ow, oc)) + return out + + def compute_output_shape(self, input_shape): + + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + + str(4) + + '; Received input shape:', str(input_shape)) + + + height = input_shape[1] * self.size[0] if input_shape[1] is not None else None + width = input_shape[2] * self.size[1] if input_shape[2] is not None else None + channels = input_shape[3] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[3]: + raise ValueError('channels of input and size are incompatible') + + return (input_shape[0], + height, + width, + channels) + + def get_config(self): + config = {'size': self.size} + base_config = super(PixelShuffler, self).get_config() + + return dict(list(base_config.items()) + list(config.items())) + +def upscale (dim): + def func(x): + return PixelShuffler()((KL.Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x))) + return func + +inp = KL.Input(bgr_shape) +x = inp +x = KL.Conv2D(128, 5, strides=2, padding='same')(x) +x = KL.Conv2D(256, 5, strides=2, padding='same')(x) +x = KL.Conv2D(512, 5, strides=2, padding='same')(x) +x = KL.Conv2D(1024, 5, strides=2, padding='same')(x) +x = KL.Dense(1024)(KL.Flatten()(x)) +x = KL.Dense(8 * 8 * 1024)(x) +x = KL.Reshape((8, 8, 1024))(x) +x = upscale(512)(x) +x = upscale(256)(x) +x = upscale(128)(x) +x = upscale(64)(x) +x = KL.Conv2D(3, 5, strides=1, padding='same')(x) + +model = keras.models.Model ([inp], [x]) +model.compile(optimizer=keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss='mae') + +training_data = np.zeros ( (batch_size,128,128,3) ) +loss = model.train_on_batch( [training_data], [training_data] ) +print ("FINE") + +import sys +sys.exit() +""" + + +import struct +import os +#os.environ["DFL_PLAIDML_BUILD"] = "1" +import pickle +import math +import sys +import argparse +from core import pathex +from core import osex +from facelib import LandmarksProcessor +from facelib import FaceType +from pathlib import Path +import numpy as np +from numpy import linalg as npla +import cv2 +import time +import multiprocessing +import threading +import traceback +from tqdm import tqdm +from DFLIMG import * +from core.cv2ex import * +import shutil +from core import imagelib +from core.interact import interact as io + +# Add path to use current litenn local repo +sys.path.insert(0, r'D:\DevelopPython\Projects\litenn\git_litenn') +import litenn as lnn +import litenn.core as lnc +from litenn.core import CLKernelHelper as lph + +def umeyama(src, dst, estimate_scale): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = np.dot(dst_demean.T, src_demean) / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = np.dot(U, V) + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + d[dim - 1] = s + else: + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d) + else: + scale = 1.0 + + T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T) + T[:dim, :dim] *= scale + + return T + +def random_transform(image, rotation_range=10, zoom_range=0.5, shift_range=0.05, random_flip=0): + h, w = image.shape[0:2] + rotation = np.random.uniform(-rotation_range, rotation_range) + scale = np.random.uniform(1 - zoom_range, 1 + zoom_range) + tx = np.random.uniform(-shift_range, shift_range) * w + ty = np.random.uniform(-shift_range, shift_range) * h + mat = cv2.getRotationMatrix2D((w // 2, h // 2), rotation, scale) + mat[:, 2] += (tx, ty) + result = cv2.warpAffine( + image, mat, (w, h), borderMode=cv2.BORDER_REPLICATE) + if np.random.random() < random_flip: + result = result[:, ::-1] + return result + +# get pair of random warped images from aligned face image +def random_warp(image, coverage=160, scale = 5, zoom = 1): + assert image.shape == (256, 256, 3) + range_ = np.linspace(128 - coverage//2, 128 + coverage//2, 5) + mapx = np.broadcast_to(range_, (5, 5)) + mapy = mapx.T + + mapx = mapx + np.random.normal(size=(5,5), scale=scale) + mapy = mapy + np.random.normal(size=(5,5), scale=scale) + + interp_mapx = cv2.resize(mapx, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32') + interp_mapy = cv2.resize(mapy, (80*zoom,80*zoom))[8*zoom:72*zoom,8*zoom:72*zoom].astype('float32') + + warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR) + + src_points = np.stack([mapx.ravel(), mapy.ravel() ], axis=-1) + dst_points = np.mgrid[0:65*zoom:16*zoom,0:65*zoom:16*zoom].T.reshape(-1,2) + mat = umeyama(src_points, dst_points, True)[0:2] + + target_image = cv2.warpAffine(image, mat, (64*zoom,64*zoom)) + + return warped_image, target_image + +def input_process(stdin_fd, sq, str): + sys.stdin = os.fdopen(stdin_fd) + try: + inp = input (str) + sq.put (True) + except: + sq.put (False) + +def input_in_time (str, max_time_sec): + sq = multiprocessing.Queue() + p = multiprocessing.Process(target=input_process, args=( sys.stdin.fileno(), sq, str)) + p.start() + t = time.time() + inp = False + while True: + if not sq.empty(): + inp = sq.get() + break + if time.time() - t > max_time_sec: + break + p.terminate() + sys.stdin = os.fdopen( sys.stdin.fileno() ) + return inp + + + +def subprocess(sq,cq): + prefetch = 2 + while True: + while prefetch > -1: + cq.put ( np.array([1]) ) #memory leak numpy==1.16.0 , but all fine in 1.15.4 + #cq.put ( [1] ) #no memory leak + prefetch -= 1 + + sq.get() #waiting msg from serv to continue posting + prefetch += 1 + + + +def get_image_hull_mask (image_shape, image_landmarks): + if len(image_landmarks) != 68: + raise Exception('get_image_hull_mask works only with 68 landmarks') + + hull_mask = np.zeros(image_shape[0:2]+(1,),dtype=np.float32) + + cv2.fillConvexPoly( hull_mask, cv2.convexHull( np.concatenate ( (image_landmarks[0:17], image_landmarks[48:], [image_landmarks[0]], [image_landmarks[8]], [image_landmarks[16]])) ), (1,) ) + cv2.fillConvexPoly( hull_mask, cv2.convexHull( np.concatenate ( (image_landmarks[27:31], [image_landmarks[33]]) ) ), (1,) ) + cv2.fillConvexPoly( hull_mask, cv2.convexHull( np.concatenate ( (image_landmarks[17:27], [image_landmarks[0]], [image_landmarks[27]], [image_landmarks[16]], [image_landmarks[33]])) ), (1,) ) + + return hull_mask + + +def umeyama(src, dst, estimate_scale): + """Estimate N-D similarity transformation with or without scaling. + Parameters + ---------- + src : (M, N) array + Source coordinates. + dst : (M, N) array + Destination coordinates. + estimate_scale : bool + Whether to estimate scaling factor. + Returns + ------- + T : (N + 1, N + 1) + The homogeneous similarity transformation matrix. The matrix contains + NaN values only if the problem is not well-conditioned. + References + ---------- + .. [1] "Least-squares estimation of transformation parameters between two + point patterns", Shinji Umeyama, PAMI 1991, DOI: 10.1109/34.88573 + """ + + num = src.shape[0] + dim = src.shape[1] + + # Compute mean of src and dst. + src_mean = src.mean(axis=0) + dst_mean = dst.mean(axis=0) + + # Subtract mean from src and dst. + src_demean = src - src_mean + dst_demean = dst - dst_mean + + # Eq. (38). + A = np.dot(dst_demean.T, src_demean) / num + + # Eq. (39). + d = np.ones((dim,), dtype=np.double) + if np.linalg.det(A) < 0: + d[dim - 1] = -1 + + T = np.eye(dim + 1, dtype=np.double) + + U, S, V = np.linalg.svd(A) + + # Eq. (40) and (43). + rank = np.linalg.matrix_rank(A) + if rank == 0: + return np.nan * T + elif rank == dim - 1: + if np.linalg.det(U) * np.linalg.det(V) > 0: + T[:dim, :dim] = np.dot(U, V) + else: + s = d[dim - 1] + d[dim - 1] = -1 + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) + d[dim - 1] = s + else: + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) + + if estimate_scale: + # Eq. (41) and (42). + scale = 1.0 / src_demean.var(axis=0).sum() * np.dot(S, d) + else: + scale = 1.0 + + T[:dim, dim] = dst_mean - scale * np.dot(T[:dim, :dim], src_mean.T) + T[:dim, :dim] *= scale + + return T + +#mean_face_x = np.array([ +#0.000213256, 0.0752622, 0.18113, 0.29077, 0.393397, 0.586856, 0.689483, 0.799124, +#0.904991, 0.98004, 0.490127, 0.490127, 0.490127, 0.490127, 0.36688, 0.426036, +#0.490127, 0.554217, 0.613373, 0.121737, 0.187122, 0.265825, 0.334606, 0.260918, +#0.182743, 0.645647, 0.714428, 0.793132, 0.858516, 0.79751, 0.719335, 0.254149, +#0.340985, 0.428858, 0.490127, 0.551395, 0.639268, 0.726104, 0.642159, 0.556721, +#0.490127, 0.423532, 0.338094, 0.290379, 0.428096, 0.490127, 0.552157, 0.689874, +#0.553364, 0.490127, 0.42689 ]) +# +#mean_face_y = np.array([ +#0.106454, 0.038915, 0.0187482, 0.0344891, 0.0773906, 0.0773906, 0.0344891, +#0.0187482, 0.038915, 0.106454, 0.203352, 0.307009, 0.409805, 0.515625, 0.587326, +#0.609345, 0.628106, 0.609345, 0.587326, 0.216423, 0.178758, 0.179852, 0.231733, +#0.245099, 0.244077, 0.231733, 0.179852, 0.178758, 0.216423, 0.244077, 0.245099, +#0.780233, 0.745405, 0.727388, 0.742578, 0.727388, 0.745405, 0.780233, 0.864805, +#0.902192, 0.909281, 0.902192, 0.864805, 0.784792, 0.778746, 0.785343, 0.778746, +#0.784792, 0.824182, 0.831803, 0.824182 ]) +# +#landmarks_2D = np.stack( [ mean_face_x, mean_face_y ], axis=1 ) + + +#alignments = [] +# +#aligned_path_image_paths = pathex.get_image_paths("D:\\DeepFaceLab\\workspace issue\\data_dst\\aligned") +#for filepath in tqdm(aligned_path_image_paths, desc="Collecting alignments", ascii=True ): +# filepath = Path(filepath) +# +# if filepath.suffix == '.png': +# dflimg = DFLPNG.load( str(filepath), print_on_no_embedded_data=True ) +# elif filepath.suffix == '.jpg': +# dflimg = DFLJPG.load ( str(filepath), print_on_no_embedded_data=True ) +# else: +# print ("%s is not a dfl image file" % (filepath.name) ) +# +# #source_filename_stem = Path( dflimg.get_source_filename() ).stem +# #if source_filename_stem not in alignments.keys(): +# # alignments[ source_filename_stem ] = [] +# +# #alignments[ source_filename_stem ].append (dflimg.get_source_landmarks()) +# alignments.append (dflimg.get_source_landmarks()) +import string + +def tdict2kw_conv2d ( w, b=None ): + if b is not None: + return [ np.transpose(w.numpy(), [2,3,1,0]), b.numpy() ] + else: + return [ np.transpose(w.numpy(), [2,3,1,0])] + +def tdict2kw_depconv2d ( w, b=None ): + if b is not None: + return [ np.transpose(w.numpy(), [2,3,0,1]), b.numpy() ] + else: + return [ np.transpose(w.numpy(), [2,3,0,1]) ] + +def tdict2kw_bn2d( d, name_prefix ): + return [ d[name_prefix+'.weight'].numpy(), + d[name_prefix+'.bias'].numpy(), + d[name_prefix+'.running_mean'].numpy(), + d[name_prefix+'.running_var'].numpy() ] + + +def t2kw_conv2d (src): + if src.bias is not None: + return [ np.transpose(src.weight.data.cpu().numpy(), [2,3,1,0]), src.bias.data.cpu().numpy() ] + else: + return [ np.transpose(src.weight.data.cpu().numpy(), [2,3,1,0])] + + +def t2kw_bn2d(src): + return [ src.weight.data.cpu().numpy(), src.bias.data.cpu().numpy(), src.running_mean.cpu().numpy(), src.running_var.cpu().numpy() ] + + +import scipy as sp + +def color_transfer_mkl(x0, x1): + eps = np.finfo(float).eps + + h,w,c = x0.shape + h1,w1,c1 = x1.shape + + x0 = x0.reshape ( (h*w,c) ) + x1 = x1.reshape ( (h1*w1,c1) ) + + a = np.cov(x0.T) + b = np.cov(x1.T) + + Da2, Ua = np.linalg.eig(a) + Da = np.diag(np.sqrt(Da2.clip(eps, None))) + + C = np.dot(np.dot(np.dot(np.dot(Da, Ua.T), b), Ua), Da) + + Dc2, Uc = np.linalg.eig(C) + Dc = np.diag(np.sqrt(Dc2.clip(eps, None))) + + Da_inv = np.diag(1./(np.diag(Da))) + + t = np.dot(np.dot(np.dot(np.dot(np.dot(np.dot(Ua, Da_inv), Uc), Dc), Uc.T), Da_inv), Ua.T) + + mx0 = np.mean(x0, axis=0) + mx1 = np.mean(x1, axis=0) + + result = np.dot(x0-mx0, t) + mx1 + return np.clip ( result.reshape ( (h,w,c) ), 0, 1) + +def color_transfer_idt(i0, i1, bins=256, n_rot=20): + relaxation = 1 / n_rot + h,w,c = i0.shape + h1,w1,c1 = i1.shape + + i0 = i0.reshape ( (h*w,c) ) + i1 = i1.reshape ( (h1*w1,c1) ) + + n_dims = c + + d0 = i0.T + d1 = i1.T + + for i in range(n_rot): + + r = sp.stats.special_ortho_group.rvs(n_dims).astype(np.float32) + + d0r = np.dot(r, d0) + d1r = np.dot(r, d1) + d_r = np.empty_like(d0) + + for j in range(n_dims): + + lo = min(d0r[j].min(), d1r[j].min()) + hi = max(d0r[j].max(), d1r[j].max()) + + p0r, edges = np.histogram(d0r[j], bins=bins, range=[lo, hi]) + p1r, _ = np.histogram(d1r[j], bins=bins, range=[lo, hi]) + + cp0r = p0r.cumsum().astype(np.float32) + cp0r /= cp0r[-1] + + cp1r = p1r.cumsum().astype(np.float32) + cp1r /= cp1r[-1] + + f = np.interp(cp0r, cp1r, edges[1:]) + + d_r[j] = np.interp(d0r[j], edges[1:], f, left=0, right=bins) + + d0 = relaxation * np.linalg.solve(r, (d_r - d0r)) + d0 + + return np.clip ( d0.T.reshape ( (h,w,c) ), 0, 1) + +from core import imagelib + + +def color_transfer_mix(img_src,img_trg): + img_src = (img_src*255.0).astype(np.uint8) + img_trg = (img_trg*255.0).astype(np.uint8) + + img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB) + img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2LAB) + + rct_light = np.clip ( imagelib.linear_color_transfer(img_src_lab[...,0:1].astype(np.float32)/255.0, + img_trg_lab[...,0:1].astype(np.float32)/255.0 )[...,0]*255.0, + 0, 255).astype(np.uint8) + + img_src_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_src_lab = cv2.cvtColor(img_src_lab, cv2.COLOR_LAB2BGR) + + img_trg_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_trg_lab = cv2.cvtColor(img_trg_lab, cv2.COLOR_LAB2BGR) + + img_rct = imagelib.color_transfer_sot( img_src_lab.astype(np.float32), img_trg_lab.astype(np.float32) ) + img_rct = np.clip(img_rct, 0, 255).astype(np.uint8) + + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_BGR2LAB) + img_rct[...,0] = rct_light + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_LAB2BGR) + + + return (img_rct / 255.0).astype(np.float32) + + + +def color_transfer_mix2(img_src,img_trg): + img_src = (img_src*255.0).astype(np.uint8) + img_trg = (img_trg*255.0).astype(np.uint8) + + img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2YUV) + img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2YUV) + + rct_light = np.clip ( imagelib.linear_color_transfer(img_src_lab[...,0:1].astype(np.float32)/255.0, + img_trg_lab[...,0:1].astype(np.float32)/255.0 )[...,0]*255.0, + 0, 255).astype(np.uint8) + + img_src_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_src_lab = cv2.cvtColor(img_src_lab, cv2.COLOR_YUV2BGR) + + img_trg_lab[...,0] = (np.ones_like (rct_light)*100).astype(np.uint8) + img_trg_lab = cv2.cvtColor(img_trg_lab, cv2.COLOR_YUV2BGR) + + img_rct = imagelib.color_transfer_sot( img_src_lab.astype(np.float32), img_trg_lab.astype(np.float32) ) + img_rct = np.clip(img_rct, 0, 255).astype(np.uint8) + + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_BGR2YUV) + img_rct[...,0] = rct_light + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_YUV2BGR) + + + return (img_rct / 255.0).astype(np.float32) + + +def nd_cor(pts1, pts2): + dtype = pts1.dtype + + for iter in range(10): + + dir = np.random.normal(size=2).astype(dtype) + dir /= npla.norm(dir) + + proj_pts1 = pts1*dir + proj_pts2 = pts2*dir + id_pts1 = np.argsort (proj_pts1) + id_pts2 = np.argsort (proj_pts2) + + + + +def fist(pts_src, pts_dst): + + rot = np.eye (3,3,dtype=np.float32) + trans = np.zeros (2, dtype=np.float32) + scaling = 1 + + for iter in range(10): + + center1 = np.zeros (2, dtype=np.float32) + center2 = np.zeros (2, dtype=np.float32) + + +landmarks_2D = np.array([ +[ 0.000213256, 0.106454 ], #17 +[ 0.0752622, 0.038915 ], #18 +[ 0.18113, 0.0187482 ], #19 +[ 0.29077, 0.0344891 ], #20 +[ 0.393397, 0.0773906 ], #21 +[ 0.586856, 0.0773906 ], #22 +[ 0.689483, 0.0344891 ], #23 +[ 0.799124, 0.0187482 ], #24 +[ 0.904991, 0.038915 ], #25 +[ 0.98004, 0.106454 ], #26 +[ 0.490127, 0.203352 ], #27 +[ 0.490127, 0.307009 ], #28 +[ 0.490127, 0.409805 ], #29 +[ 0.490127, 0.515625 ], #30 +[ 0.36688, 0.587326 ], #31 +[ 0.426036, 0.609345 ], #32 +[ 0.490127, 0.628106 ], #33 +[ 0.554217, 0.609345 ], #34 +[ 0.613373, 0.587326 ], #35 +[ 0.121737, 0.216423 ], #36 +[ 0.187122, 0.178758 ], #37 +[ 0.265825, 0.179852 ], #38 +[ 0.334606, 0.231733 ], #39 +[ 0.260918, 0.245099 ], #40 +[ 0.182743, 0.244077 ], #41 +[ 0.645647, 0.231733 ], #42 +[ 0.714428, 0.179852 ], #43 +[ 0.793132, 0.178758 ], #44 +[ 0.858516, 0.216423 ], #45 +[ 0.79751, 0.244077 ], #46 +[ 0.719335, 0.245099 ], #47 +[ 0.254149, 0.780233 ], #48 +[ 0.340985, 0.745405 ], #49 +[ 0.428858, 0.727388 ], #50 +[ 0.490127, 0.742578 ], #51 +[ 0.551395, 0.727388 ], #52 +[ 0.639268, 0.745405 ], #53 +[ 0.726104, 0.780233 ], #54 +[ 0.642159, 0.864805 ], #55 +[ 0.556721, 0.902192 ], #56 +[ 0.490127, 0.909281 ], #57 +[ 0.423532, 0.902192 ], #58 +[ 0.338094, 0.864805 ], #59 +[ 0.290379, 0.784792 ], #60 +[ 0.428096, 0.778746 ], #61 +[ 0.490127, 0.785343 ], #62 +[ 0.552157, 0.778746 ], #63 +[ 0.689874, 0.784792 ], #64 +[ 0.553364, 0.824182 ], #65 +[ 0.490127, 0.831803 ], #66 +[ 0.42689 , 0.824182 ] #67 +], dtype=np.float32) + + +""" + +( .Config() + .sample_host('src_samples', path) + + .index_generator('i1', 'src_samples' ) + + .batch(16) + .warp_params('w1', ...) + .branch( (.Branch() + .load_sample('src_samples', 'i1') + + ) + ) +) + + +""" + + +def _compute_fans(shape, data_format='channels_last'): + """Computes the number of input and output units for a weight shape. + # Arguments + shape: Integer shape tuple. + data_format: Image data format to use for convolution kernels. + Note that all kernels in Keras are standardized on the + `channels_last` ordering (even when inputs are set + to `channels_first`). + # Returns + A tuple of scalars, `(fan_in, fan_out)`. + # Raises + ValueError: in case of invalid `data_format` argument. + """ + if len(shape) == 2: + fan_in = shape[0] + fan_out = shape[1] + elif len(shape) in {3, 4, 5}: + # Assuming convolution kernels (1D, 2D or 3D). + # TH kernel shape: (depth, input_depth, ...) + # TF kernel shape: (..., input_depth, depth) + if data_format == 'channels_first': + receptive_field_size = np.prod(shape[2:]) + fan_in = shape[1] * receptive_field_size + fan_out = shape[0] * receptive_field_size + elif data_format == 'channels_last': + receptive_field_size = np.prod(shape[:-2]) + fan_in = shape[-2] * receptive_field_size + fan_out = shape[-1] * receptive_field_size + else: + raise ValueError('Invalid data_format: ' + data_format) + else: + # No specific assumptions. + fan_in = np.sqrt(np.prod(shape)) + fan_out = np.sqrt(np.prod(shape)) + return fan_in, fan_out + +def _create_basis(filters, size, floatx, eps_std): + if size == 1: + return np.random.normal(0.0, eps_std, (filters, size)) + + nbb = filters // size + 1 + li = [] + for i in range(nbb): + a = np.random.normal(0.0, 1.0, (size, size)) + a = _symmetrize(a) + u, _, v = np.linalg.svd(a) + li.extend(u.T.tolist()) + p = np.array(li[:filters], dtype=floatx) + return p + +def _symmetrize(a): + return a + a.T - np.diag(a.diagonal()) + +def _scale_filters(filters, variance): + c_var = np.var(filters) + p = np.sqrt(variance / c_var) + return filters * p + +def CAGenerateWeights ( shape, floatx, data_format, eps_std=0.05, seed=None ): + if seed is not None: + np.random.seed(seed) + + fan_in, fan_out = _compute_fans(shape, data_format) + variance = 2 / fan_in + + rank = len(shape) + if rank == 3: + row, stack_size, filters_size = shape + + transpose_dimensions = (2, 1, 0) + kernel_shape = (row,) + correct_ifft = lambda shape, s=[None]: np.fft.irfft(shape, s[0]) + correct_fft = np.fft.rfft + + elif rank == 4: + row, column, stack_size, filters_size = shape + + transpose_dimensions = (2, 3, 1, 0) + kernel_shape = (row, column) + correct_ifft = np.fft.irfft2 + correct_fft = np.fft.rfft2 + + elif rank == 5: + x, y, z, stack_size, filters_size = shape + + transpose_dimensions = (3, 4, 0, 1, 2) + kernel_shape = (x, y, z) + correct_fft = np.fft.rfftn + correct_ifft = np.fft.irfftn + else: + raise ValueError('rank unsupported') + + kernel_fourier_shape = correct_fft(np.zeros(kernel_shape)).shape + + init = [] + for i in range(filters_size): + basis = _create_basis(stack_size, np.prod(kernel_fourier_shape), floatx, eps_std) + basis = basis.reshape((stack_size,) + kernel_fourier_shape) + + filters = [correct_ifft(x, kernel_shape) + + np.random.normal(0, eps_std, kernel_shape) + for x in basis] + + init.append(filters) + + # Format of array is now: filters, stack, row, column + init = np.array(init) + init = _scale_filters(init, variance) + return init.transpose(transpose_dimensions) + +from datetime import datetime +class timeit: + + def __enter__(self): + self.t = datetime.now().timestamp() + def __exit__(self, a,b,c): + print(f'timeit!: {datetime.now().timestamp()-self.t}') + +import scipy +from ctypes import * + +from core.joblib import Subprocessor + +class CTComputerSubprocessor(Subprocessor): + class Cli(Subprocessor.Cli): + def process_data(self, data): + idx, src_path, dst_path = data + src_path = Path(src_path) + dst_path = Path(dst_path) + + src_uint8 = cv2_imread(src_path) + dst_uint8 = cv2_imread(dst_path) + + src_dflimg = DFLIMG.load(src_path) + dst_dflimg = DFLIMG.load(dst_path) + if src_dflimg is None or dst_dflimg is None: + return idx, [0,0,0,0,0,0] + + src_uint8 = src_uint8*LandmarksProcessor.get_image_hull_mask( src_uint8.shape, src_dflimg.get_landmarks() ) + dst_uint8 = dst_uint8*LandmarksProcessor.get_image_hull_mask( dst_uint8.shape, dst_dflimg.get_landmarks() ) + + src = src_uint8.astype(np.float32) / 255.0 + dst = dst_uint8.astype(np.float32) / 255.0 + + src_rct = imagelib.reinhard_color_transfer(src_uint8, dst_uint8).astype(np.float32) / 255.0 + src_lct = np.clip( imagelib.linear_color_transfer (src, dst), 0.0, 1.0 ) + src_mkl = imagelib.color_transfer_mkl (src, dst) + src_idt = imagelib.color_transfer_idt (src, dst) + src_sot = imagelib.color_transfer_sot (src, dst) + + dst_mean = np.mean(dst, axis=(0,1) ) + src_mean = np.mean(src, axis=(0,1) ) + src_rct_mean = np.mean(src_rct, axis=(0,1) ) + src_lct_mean = np.mean(src_lct, axis=(0,1) ) + src_mkl_mean = np.mean(src_mkl, axis=(0,1) ) + src_idt_mean = np.mean(src_idt, axis=(0,1) ) + src_sot_mean = np.mean(src_sot, axis=(0,1) ) + + dst_std = np.sqrt ( np.var(dst, axis=(0,1) ) + 1e-5 ) + src_std = np.sqrt ( np.var(src, axis=(0,1) ) + 1e-5 ) + src_rct_std = np.sqrt ( np.var(src_rct, axis=(0,1) ) + 1e-5 ) + src_lct_std = np.sqrt ( np.var(src_lct, axis=(0,1) ) + 1e-5 ) + src_mkl_std = np.sqrt ( np.var(src_mkl, axis=(0,1) ) + 1e-5 ) + src_idt_std = np.sqrt ( np.var(src_idt, axis=(0,1) ) + 1e-5 ) + src_sot_std = np.sqrt ( np.var(src_sot, axis=(0,1) ) + 1e-5 ) + + def_mean_sum = np.sum( np.square(src_mean-dst_mean) ) + rct_mean_sum = np.sum( np.square(src_rct_mean-dst_mean) ) + lct_mean_sum = np.sum( np.square(src_lct_mean-dst_mean) ) + mkl_mean_sum = np.sum( np.square(src_mkl_mean-dst_mean) ) + idt_mean_sum = np.sum( np.square(src_idt_mean-dst_mean) ) + sot_mean_sum = np.sum( np.square(src_sot_mean-dst_mean) ) + + def_std_sum = np.sum( np.square(src_std-dst_std) ) + rct_std_sum = np.sum( np.square(src_rct_std-dst_std) ) + lct_std_sum = np.sum( np.square(src_lct_std-dst_std) ) + mkl_std_sum = np.sum( np.square(src_mkl_std-dst_std) ) + idt_std_sum = np.sum( np.square(src_idt_std-dst_std) ) + sot_std_sum = np.sum( np.square(src_sot_std-dst_std) ) + + return idx, [def_mean_sum+def_std_sum, + rct_mean_sum+rct_std_sum, + lct_mean_sum+lct_std_sum, + mkl_mean_sum+mkl_std_sum, + idt_mean_sum+idt_std_sum, + sot_mean_sum+sot_std_sum + ] + + def __init__(self, src_paths, dst_paths ): + self.src_paths = src_paths + self.src_paths_idxs = [*range(len(self.src_paths))] + self.dst_paths = dst_paths + self.result = [None]*len(self.src_paths) + super().__init__('CTComputerSubprocessor', CTComputerSubprocessor.Cli, 60) + + def process_info_generator(self): + + for i in range(multiprocessing.cpu_count()): + yield 'CPU%d' % (i), {}, {} + + def on_clients_initialized(self): + io.progress_bar ("Computing", len (self.src_paths_idxs)) + + def on_clients_finalized(self): + io.progress_bar_close() + + def get_data(self, host_dict): + if len (self.src_paths_idxs) > 0: + idx = self.src_paths_idxs.pop(0) + src_path = self.src_paths [idx] + dst_path = self.dst_paths [np.random.randint(len(self.dst_paths))] + return idx, src_path, dst_path + return None + + #override + def on_data_return (self, host_dict, data): + self.src_paths_idxs.insert(0, data[0]) + + #override + def on_result (self, host_dict, data, result): + idx, data = result + self.result[idx] = data + io.progress_bar_inc(1) + + #override + def get_result(self): + return {0:'none', + 1:'rct', + 2:'lct', + 3:'mkl', + 4:'idt', + 5:'sot' + }[np.argmin(np.mean(np.array(self.result), 0))] + +from samplelib import * +from skimage.transform import rescale + + + +#np.seterr(divide='ignore', invalid='ignore') + +def mls_affine_deformation_1pt(p, q, v, alpha=1): + ''' Calculate the affine deformation of one point. + This function is used to test the algorithm. + ''' + ctrls = p.shape[0] + np.seterr(divide='ignore') + w = 1.0 / np.sum((p - v) ** 2, axis=1) ** alpha + w[w == np.inf] = 2**31-1 + pstar = np.sum(p.T * w, axis=1) / np.sum(w) + qstar = np.sum(q.T * w, axis=1) / np.sum(w) + phat = p - pstar + qhat = q - qstar + reshaped_phat1 = phat.reshape(ctrls, 2, 1) + reshaped_phat2 = phat.reshape(ctrls, 1, 2) + reshaped_w = w.reshape(ctrls, 1, 1) + pTwp = np.sum(reshaped_phat1 * reshaped_w * reshaped_phat2, axis=0) + try: + inv_pTwp = np.linalg.inv(pTwp) + except np.linalg.linalg.LinAlgError: + if np.linalg.det(pTwp) < 1e-8: + new_v = v + qstar - pstar + return new_v + else: + raise + mul_left = v - pstar + mul_right = np.sum(reshaped_phat1 * reshaped_w * qhat[:, np.newaxis, :], axis=0) + new_v = np.dot(np.dot(mul_left, inv_pTwp), mul_right) + qstar + return new_v + +def mls_affine_deformation(image, p, q, alpha=1.0, density=1.0): + ''' Affine deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Precompute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + pTwp = np.sum(reshaped_phat1 * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol] + try: + inv_pTwp = np.linalg.inv(pTwp.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(pTwp.transpose(2, 3, 0, 1)) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol] + adjoint = pTwp[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol] + adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol] + inv_pTwp = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + mul_left = reshaped_v - pstar # [2, grow, gcol] + reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2] + mul_right = reshaped_w * reshaped_phat1 # [ctrls, 2, 1, grow, gcol] + reshaped_mul_right =mul_right.transpose(0, 3, 4, 1, 2) # [ctrls, grow, gcol, 2, 1] + A = np.matmul(np.matmul(reshaped_mul_left, inv_pTwp), reshaped_mul_right) # [ctrls, grow, gcol, 1, 1] + reshaped_A = A.reshape(ctrls, 1, grow, gcol) # [ctrls, 1, grow, gcol] + + # Calculate q + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + + # Get final image transfomer -- 3-D array + transformers = np.sum(reshaped_A * qhat, axis=0) + qstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = np.ones_like(image) * 255 + new_gridY, new_gridX = np.meshgrid((np.arange(gcol) / density).astype(np.int16), + (np.arange(grow) / density).astype(np.int16)) + transformed_image[tuple(transformers.astype(np.int16))] = image[new_gridX, new_gridY] # [grow, gcol] + + return transformed_image + +def mls_affine_deformation_inv(image, p, q, alpha=1.0, density=1.0): + ''' Affine inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + + reshaped_phat = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + pTwq = np.sum(reshaped_phat * reshaped_w * reshaped_qhat, axis=0) # [2, 2, grow, gcol] + try: + inv_pTwq = np.linalg.inv(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(pTwq.transpose(2, 3, 0, 1)) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(1, 1, grow, gcol) # [1, 1, grow, gcol] + adjoint = pTwq[[[1, 0], [1, 0]], [[1, 1], [0, 0]], :, :] # [2, 2, grow, gcol] + adjoint[[0, 1], [1, 0], :, :] = -adjoint[[0, 1], [1, 0], :, :] # [2, 2, grow, gcol] + inv_pTwq = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + mul_left = reshaped_v - qstar # [2, grow, gcol] + reshaped_mul_left = mul_left.reshape(1, 2, grow, gcol).transpose(2, 3, 0, 1) # [grow, gcol, 1, 2] + mul_right = np.sum(reshaped_phat * reshaped_w * reshaped_phat2, axis=0) # [2, 2, grow, gcol] + reshaped_mul_right =mul_right.transpose(2, 3, 0, 1) # [grow, gcol, 2, 2] + temp = np.matmul(np.matmul(reshaped_mul_left, inv_pTwq), reshaped_mul_right) # [grow, gcol, 1, 2] + reshaped_temp = temp.reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol] + + # Get final image transfomer -- 3-D array + transformers = reshaped_temp + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformed_image + + + + + + +def mls_similarity_deformation(image, p, q, alpha=1.0, density=1.0): + ''' Similarity deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + sum_w = np.sum(w, axis=0) # [grow, gcol] + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / sum_w # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) * + reshaped_phat1.transpose(0, 3, 4, 1, 2), + reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1] + reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] + reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_left = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + vpstar = reshaped_v - pstar # [2, grow, gcol] + reshaped_vpstar = vpstar.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + neg_vpstar_verti = vpstar[[1, 0],...] # [2, grow, gcol] + neg_vpstar_verti[1,...] = -neg_vpstar_verti[1,...] + reshaped_neg_vpstar_verti = neg_vpstar_verti.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + mul_right = np.concatenate((reshaped_vpstar, reshaped_neg_vpstar_verti), axis=1) # [2, 2, grow, gcol] + reshaped_mul_right = mul_right.reshape(1, 2, 2, grow, gcol) # [1, 2, 2, grow, gcol] + A = np.matmul((reshaped_w * mul_left).transpose(0, 3, 4, 1, 2), + reshaped_mul_right.transpose(0, 3, 4, 1, 2)) # [ctrls, grow, gcol, 2, 2] + + # Calculate q + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol).transpose(0, 3, 4, 1, 2) # [ctrls, grow, gcol, 1, 2] + + # Get final image transfomer -- 3-D array + temp = np.sum(np.matmul(reshaped_qhat, A), axis=0).transpose(2, 3, 0, 1) # [1, 2, grow, gcol] + reshaped_temp = temp.reshape(2, grow, gcol) # [2, grow, gcol] + transformers = reshaped_temp / reshaped_mu + qstar # [2, grow, gcol] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = np.ones_like(image) * 255 + new_gridY, new_gridX = np.meshgrid((np.arange(gcol) / density).astype(np.int16), + (np.arange(grow) / density).astype(np.int16)) + transformed_image[tuple(transformers.astype(np.int16))] = image[new_gridX, new_gridY] # [grow, gcol] + + return transformed_image + + +def mls_similarity_deformation_inv(image, p, q, alpha=1.0, density=1.0): + ''' Similarity inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + + mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) * + reshaped_phat1.transpose(0, 3, 4, 1, 2), + reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1] + reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] + reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol] + Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2), + mul_right.transpose(0, 3, 4, 1, 2)), + axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1] + Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1] + Delta_verti[...,0,:] = -Delta_verti[...,0,:] + B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2] + try: + inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(B) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1] + adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2] + adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2] + inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol] + + v_minus_qstar_mul_mu = (reshaped_v - qstar) * reshaped_mu # [2, grow, gcol] + + # Get final image transfomer -- 3-D array + reshaped_v_minus_qstar_mul_mu = v_minus_qstar_mul_mu.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + transformers = np.matmul(reshaped_v_minus_qstar_mul_mu.transpose(2, 3, 0, 1), + inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformed_image + + +def mls_rigid_deformation(image, p, q, alpha=1.0, density=1.0): + ''' Rigid deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + sum_w = np.sum(w, axis=0) # [grow, gcol] + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / sum_w # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + reshaped_phat = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] + reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_left = np.concatenate((reshaped_phat, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + vpstar = reshaped_v - pstar # [2, grow, gcol] + reshaped_vpstar = vpstar.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + neg_vpstar_verti = vpstar[[1, 0],...] # [2, grow, gcol] + neg_vpstar_verti[1,...] = -neg_vpstar_verti[1,...] + reshaped_neg_vpstar_verti = neg_vpstar_verti.reshape(2, 1, grow, gcol) # [2, 1, grow, gcol] + mul_right = np.concatenate((reshaped_vpstar, reshaped_neg_vpstar_verti), axis=1) # [2, 2, grow, gcol] + reshaped_mul_right = mul_right.reshape(1, 2, 2, grow, gcol) # [1, 2, 2, grow, gcol] + A = np.matmul((reshaped_w * mul_left).transpose(0, 3, 4, 1, 2), + reshaped_mul_right.transpose(0, 3, 4, 1, 2)) # [ctrls, grow, gcol, 2, 2] + + # Calculate q + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [2, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol).transpose(0, 3, 4, 1, 2) # [ctrls, grow, gcol, 1, 2] + + # Get final image transfomer -- 3-D array + temp = np.sum(np.matmul(reshaped_qhat, A), axis=0).transpose(2, 3, 0, 1) # [1, 2, grow, gcol] + reshaped_temp = temp.reshape(2, grow, gcol) # [2, grow, gcol] + norm_reshaped_temp = np.linalg.norm(reshaped_temp, axis=0, keepdims=True) # [1, grow, gcol] + norm_vpstar = np.linalg.norm(vpstar, axis=0, keepdims=True) # [1, grow, gcol] + transformers = reshaped_temp / norm_reshaped_temp * norm_vpstar + qstar # [2, grow, gcol] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = np.ones_like(image) * 255 + new_gridY, new_gridX = np.meshgrid((np.arange(gcol) / density).astype(np.int16), + (np.arange(grow) / density).astype(np.int16)) + transformed_image[tuple(transformers.astype(np.int16))] = image[new_gridX, new_gridY] # [grow, gcol] + + return transformed_image + +def mls_rigid_deformation_inv(image, p, q, alpha=1.0, density=1.0): + ''' Rigid inverse deformation + ### Params: + * image - ndarray: original image + * p - ndarray: an array with size [n, 2], original control points + * q - ndarray: an array with size [n, 2], final control points + * alpha - float: parameter used by weights + * density - float: density of the grids + ### Return: + A deformed image. + ''' + height = image.shape[0] + width = image.shape[1] + # Change (x, y) to (row, col) + q = q[:, [1, 0]] + p = p[:, [1, 0]] + + # Make grids on the original image + gridX = np.linspace(0, width, num=int(width*density), endpoint=False) + gridY = np.linspace(0, height, num=int(height*density), endpoint=False) + vy, vx = np.meshgrid(gridX, gridY) + grow = vx.shape[0] # grid rows + gcol = vx.shape[1] # grid cols + ctrls = p.shape[0] # control points + + # Compute + reshaped_p = p.reshape(ctrls, 2, 1, 1) # [ctrls, 2, 1, 1] + reshaped_q = q.reshape((ctrls, 2, 1, 1)) # [ctrls, 2, 1, 1] + reshaped_v = np.vstack((vx.reshape(1, grow, gcol), vy.reshape(1, grow, gcol))) # [2, grow, gcol] + + w = 1.0 / np.sum((reshaped_p - reshaped_v) ** 2, axis=1)**alpha # [ctrls, grow, gcol] + w[w == np.inf] = 2**31 - 1 + pstar = np.sum(w * reshaped_p.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + phat = reshaped_p - pstar # [ctrls, 2, grow, gcol] + qstar = np.sum(w * reshaped_q.transpose(1, 0, 2, 3), axis=1) / np.sum(w, axis=0) # [2, grow, gcol] + qhat = reshaped_q - qstar # [ctrls, 2, grow, gcol] + reshaped_phat1 = phat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_phat2 = phat.reshape(ctrls, 2, 1, grow, gcol) # [ctrls, 2, 1, grow, gcol] + reshaped_qhat = qhat.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + reshaped_w = w.reshape(ctrls, 1, 1, grow, gcol) # [ctrls, 1, 1, grow, gcol] + + mu = np.sum(np.matmul(reshaped_w.transpose(0, 3, 4, 1, 2) * + reshaped_phat1.transpose(0, 3, 4, 1, 2), + reshaped_phat2.transpose(0, 3, 4, 1, 2)), axis=0) # [grow, gcol, 1, 1] + reshaped_mu = mu.reshape(1, grow, gcol) # [1, grow, gcol] + neg_phat_verti = phat[:, [1, 0],...] # [ctrls, 2, grow, gcol] + neg_phat_verti[:, 1,...] = -neg_phat_verti[:, 1,...] + reshaped_neg_phat_verti = neg_phat_verti.reshape(ctrls, 1, 2, grow, gcol) # [ctrls, 1, 2, grow, gcol] + mul_right = np.concatenate((reshaped_phat1, reshaped_neg_phat_verti), axis=1) # [ctrls, 2, 2, grow, gcol] + mul_left = reshaped_qhat * reshaped_w # [ctrls, 1, 2, grow, gcol] + Delta = np.sum(np.matmul(mul_left.transpose(0, 3, 4, 1, 2), + mul_right.transpose(0, 3, 4, 1, 2)), + axis=0).transpose(0, 1, 3, 2) # [grow, gcol, 2, 1] + Delta_verti = Delta[...,[1, 0],:] # [grow, gcol, 2, 1] + Delta_verti[...,0,:] = -Delta_verti[...,0,:] + B = np.concatenate((Delta, Delta_verti), axis=3) # [grow, gcol, 2, 2] + try: + inv_B = np.linalg.inv(B) # [grow, gcol, 2, 2] + flag = False + except np.linalg.linalg.LinAlgError: + flag = True + det = np.linalg.det(B) # [grow, gcol] + det[det < 1e-8] = np.inf + reshaped_det = det.reshape(grow, gcol, 1, 1) # [grow, gcol, 1, 1] + adjoint = B[:,:,[[1, 0], [1, 0]], [[1, 1], [0, 0]]] # [grow, gcol, 2, 2] + adjoint[:,:,[0, 1], [1, 0]] = -adjoint[:,:,[0, 1], [1, 0]] # [grow, gcol, 2, 2] + inv_B = (adjoint / reshaped_det).transpose(2, 3, 0, 1) # [2, 2, grow, gcol] + + vqstar = reshaped_v - qstar # [2, grow, gcol] + reshaped_vqstar = vqstar.reshape(1, 2, grow, gcol) # [1, 2, grow, gcol] + + # Get final image transfomer -- 3-D array + temp = np.matmul(reshaped_vqstar.transpose(2, 3, 0, 1), + inv_B).reshape(grow, gcol, 2).transpose(2, 0, 1) # [2, grow, gcol] + norm_temp = np.linalg.norm(temp, axis=0, keepdims=True) # [1, grow, gcol] + norm_vqstar = np.linalg.norm(vqstar, axis=0, keepdims=True) # [1, grow, gcol] + transformers = temp / norm_temp * norm_vqstar + pstar # [2, grow, gcol] + + # Correct the points where pTwp is singular + if flag: + blidx = det == np.inf # bool index + transformers[0][blidx] = vx[blidx] + qstar[0][blidx] - pstar[0][blidx] + transformers[1][blidx] = vy[blidx] + qstar[1][blidx] - pstar[1][blidx] + + # Removed the points outside the border + transformers[transformers < 0] = 0 + transformers[0][transformers[0] > height - 1] = 0 + transformers[1][transformers[1] > width - 1] = 0 + + # Mapping original image + transformed_image = image[tuple(transformers.astype(np.int16))] # [grow, gcol] + + # Rescale image + transformed_image = rescale(transformed_image, scale=1.0 / density, mode='reflect') + + return transformed_image + +def color_transfer_sot(src,trg, steps=10, batch_size=5, reg_sigmaXY=16.0, reg_sigmaV=5.0): + """ + Color Transform via Sliced Optimal Transfer + ported by @iperov from https://github.com/dcoeurjo/OTColorTransfer + + src - any float range any channel image + dst - any float range any channel image, same shape as src + steps - number of solver steps + batch_size - solver batch size + reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 + reg_sigmaV - sigmaV of filter + + return value - clip it manually + """ + if not np.issubdtype(src.dtype, np.floating): + raise ValueError("src value must be float") + if not np.issubdtype(trg.dtype, np.floating): + raise ValueError("trg value must be float") + + if len(src.shape) != 3: + raise ValueError("src shape must have rank 3 (h,w,c)") + + if src.shape != trg.shape: + raise ValueError("src and trg shapes must be equal") + + h,w,c = src.shape + src_orig = src + src_dtype = src.dtype + + trg = trg.reshape( (1, h*w, c) ) + src = src.copy().reshape( (1, h*w, c) ) + + idx_offsets = np.tile( np.array( [[h*w]]), (batch_size, 1) )*np.arange(0,batch_size)[:,None] + + for step in range (steps): + dir = np.random.normal( size=(batch_size,c) ).astype(src_dtype) + dir /= npla.norm(dir, axis=1, keepdims=True) + + + projsource = np.sum( src * dir[:,None,:], axis=-1 ) + projtarget = np.sum( trg * dir[:,None,:], axis=-1 ) + + idSource = np.argsort (projsource) + idx_offsets + idTarget = np.argsort (projtarget) + idx_offsets + + x = projtarget.reshape ( (batch_size*h*w) )[idTarget] - \ + projsource.reshape ( (batch_size*h*w) )[idSource] + + q = x[:,:,None] * dir[:,None,:] + src += np.mean( x[:,:,None] * dir[:,None,:], axis=0, keepdims=True ) + + import code + code.interact(local=dict(globals(), **locals())) + + src = src.reshape( src_orig.shape ) + + if reg_sigmaXY != 0.0: + src_diff = src-src_orig + src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY ) + if len(src_diff_filt.shape) == 2: + src_diff_filt = src_diff_filt[...,None] + src = src + src_diff_filt + return src + +import numpy as np +import time +from skimage.color import rgb2grey, rgb2lab +from skimage.filters import laplace +from scipy.ndimage.filters import convolve + +class Inpainter(): + def __init__(self, image, mask, patch_size=9, diff_algorithm='sq', plot_progress=False): + self.image = image.astype('uint8') + self.mask = mask.astype('uint8') + # 进行光滑处理消除噪声 + self.mask = cv2.GaussianBlur(self.mask, (3, 3), 1.5) + self.mask = (self.mask > 0).astype('uint8') + self.fill_image = np.copy(self.image) + self.fill_range = np.copy(self.mask) + self.patch_size = patch_size + # 信誉度 + self.confidence = (self.mask == 0).astype('float') + self.height = self.mask.shape[0] + self.width = self.mask.shape[1] + self.total_fill_pixel = self.fill_range.sum() + + self.diff_algorithm = diff_algorithm + self.plot_progress = plot_progress + # 初始化成员变量 + + # 边界矩阵 + self.front = None + self.D = None + # 优先级 + self.priority = None + # 边界等照度线 + self.isophote = None + # 目标点 + self.target_point = None + # 灰度图片 + self.gray_image = None + + def inpaint(self): + while self.fill_range.sum() != 0: + self._get_front() + self.gray_image = cv2.cvtColor( + self.fill_image, cv2.COLOR_RGB2GRAY).astype('float')/255 + self._log() + + if self.plot_progress: + self._plot_image() + + self._update_priority() + target_point = self._get_target_point() + self.target_point = target_point + best_patch_range = self._get_best_patch_range(target_point) + self._fill_image(target_point, best_patch_range) + + return self.fill_image + + # 打印日志 + + def _log(self): + progress_rate = 1-self.fill_range.sum()/self.total_fill_pixel + progress_rate *= 100 + print('填充进度为%.2f' % progress_rate, '%') + + # 动态显示图片更新情况 + def _plot_image(self): + fill_range = 1-self.fill_range + fill_range = fill_range[:, :, np.newaxis].repeat(3, axis=2) + + image = self.fill_image*fill_range + + # 空洞填充为白色 + white_reginon = (self.fill_range-self.front)*255 + white_reginon = white_reginon[:, :, np.newaxis].repeat(3, axis=2) + image += white_reginon + + plt.clf() + plt.imshow(image) + plt.draw() + plt.pause(0.001) + + # 填充图片 + def _fill_image(self, target_point, source_patch_range): + target_patch_range = self._get_patch_range(target_point) + # 获取待填充点的位置 + fill_point_positions = np.where(self._patch_data( + self.fill_range, target_patch_range) > 0) + + # 更新填充点的信誉度 + target_confidence = self._patch_data( + self.confidence, target_patch_range) + target_confidence[fill_point_positions[0], fill_point_positions[1]] =\ + self.confidence[target_point[0], target_point[1]] + + # 更新待填充点像素 + source_patch = self._patch_data(self.fill_image, source_patch_range) + target_patch = self._patch_data(self.fill_image, target_patch_range) + target_patch[fill_point_positions[0], fill_point_positions[1]] =\ + source_patch[fill_point_positions[0], fill_point_positions[1]] + + # 更新剩余填充点 + target_fill_range = self._patch_data( + self.fill_range, target_patch_range) + target_fill_range[:] = 0 + + # 获取最佳匹配图片块的范围 + def _get_best_patch_range(self, template_point): + diff_method_name = '_'+self.diff_algorithm+'_diff' + diff_method = getattr(self, diff_method_name) + + template_patch_range = self._get_patch_range(template_point) + patch_height = template_patch_range[0][1]-template_patch_range[0][0] + patch_width = template_patch_range[1][1]-template_patch_range[1][0] + + best_patch_range = None + best_diff = float('inf') + lab_image = cv2.cvtColor(self.fill_image, cv2.COLOR_RGB2Lab) + # lab_image=np.copy(self.fill_image) + + for x in range(self.height-patch_height+1): + for y in range(self.width-patch_width+1): + source_patch_range = [ + [x, x+patch_height], + [y, y+patch_width] + ] + if self._patch_data(self.fill_range, source_patch_range).sum() != 0: + continue + diff = diff_method( + lab_image, template_patch_range, source_patch_range) + + if diff < best_diff: + best_diff = diff + best_patch_range = source_patch_range + + return best_patch_range + + # 使用平方差比较算法计算两个区域的区别 + def _sq_diff(self, img, template_patch_range, source_patch_range): + mask = 1-self._patch_data(self.fill_range, template_patch_range) + mask = mask[:, :, np.newaxis].repeat(3, axis=2) + template_patch = self._patch_data(img, template_patch_range)*mask + source_patch = self._patch_data(img, source_patch_range)*mask + + return ((template_patch-source_patch)**2).sum() + + # 加入欧拉距离作为考量 + def _sq_with_eucldean_diff(self, img, template_patch_range, source_patch_range): + sq_diff = self._sq_diff(img, template_patch_range, source_patch_range) + eucldean_distance = np.sqrt((template_patch_range[0][0]-source_patch_range[0][0])**2 + + (template_patch_range[1][0]-source_patch_range[1][0])**2) + return sq_diff+eucldean_distance + + def _sq_with_gradient_diff(self, img, template_patch_range, source_patch_range): + sq_diff = self._sq_diff(img, template_patch_range, source_patch_range) + target_isophote = np.copy( + self.isophote[self.target_point[0], self.target_point[1]]) + target_isophote_val = np.sqrt( + target_isophote[0]**2+target_isophote[1]**2) + gray_source_patch = self._patch_data(self.gray_image, source_patch_range) + source_patch_gradient = np.nan_to_num(np.gradient(gray_source_patch)) + source_patch_val = np.sqrt( + source_patch_gradient[0]**2+source_patch_gradient[1]**2) + patch_max_pos = np.unravel_index( + source_patch_val.argmax(), + source_patch_val.shape + ) + source_isophote = np.array([-source_patch_gradient[1, patch_max_pos[0], patch_max_pos[1]], + source_patch_gradient[0, patch_max_pos[0], patch_max_pos[1]]]) + source_isophote_val = source_patch_val.max() + + # 计算两者之间的cos(theta) + dot_product = abs( + source_isophote[0]*target_isophote[0]+source_isophote[1] * target_isophote[1]) + norm = source_isophote_val*target_isophote_val + cos_theta = 0 + if norm != 0: + cos_theta = dot_product/norm + val_diff = abs(source_isophote_val-target_isophote_val) + return sq_diff-cos_theta+val_diff + + def _sq_with_gradient_eucldean_diff(self,img,template_patch_range,source_patch_range): + sq_with_gradient=self._sq_with_gradient_diff(img,template_patch_range,source_patch_range) + eucldean_distance = np.sqrt((template_patch_range[0][0]-source_patch_range[0][0])**2 + + (template_patch_range[1][0]-source_patch_range[1][0])**2) + return sq_with_gradient+eucldean_distance + + # 获取目标点的位置 + + def _get_target_point(self): + return np.unravel_index(self.priority.argmax(), self.priority.shape) + + # 使用Laplace算子求边界 + def _get_front(self): + self.front = (cv2.Laplacian(self.fill_range, -1) > 0).astype('uint8') + + def _update_priority(self): + self._update_front_confidence() + self._update_D() + self.priority = self.confidence*self.D*self.front + + # 更新D + def _update_D(self): + normal = self._get_normal() + isophote = self._get_isophote() + self.isophote = isophote + self.D = abs(normal[:, :, 0]*isophote[:, :, 0]**2 + + normal[:, :, 1]*isophote[:, :, 1]**2)+0.001 + # 更新边界点的信誉度 + + def _update_front_confidence(self): + new_confidence = np.copy(self.confidence) + front_positions = np.argwhere(self.front == 1) + for point in front_positions: + patch_range = self._get_patch_range(point) + sum_patch_confidence = self._patch_data( + self.confidence, patch_range).sum() + area = (patch_range[0][1]-patch_range[0][0]) * \ + (patch_range[1][1]-patch_range[1][0]) + new_confidence[point[0], point[1]] = sum_patch_confidence/area + + self.confidence = new_confidence + + # 获取边界上法线的单位向量 + def _get_normal(self): + x_normal = cv2.Scharr(self.fill_range, cv2.CV_64F, 1, 0) + y_normal = cv2.Scharr(self.fill_range, cv2.CV_64F, 0, 1) + normal = np.dstack([x_normal, y_normal]) + norm = np.sqrt(x_normal**2+y_normal**2).reshape(self.height, + self.width, 1).repeat(2, axis=2) + norm[norm == 0] = 1 + unit_normal = normal/norm + return unit_normal + + # 获取patch周围的等照度线 + def _get_isophote(self): + gray_image = np.copy(self.gray_image) + gray_image[self.fill_range == 1] = None + gradient = np.nan_to_num(np.array(np.gradient(gray_image))) + gradient_val = np.sqrt(gradient[0]**2 + gradient[1]**2) + max_gradient = np.zeros([self.height, self.width, 2]) + front_positions = np.argwhere(self.front == 1) + for point in front_positions: + patch = self._get_patch_range(point) + patch_y_gradient = self._patch_data(gradient[0], patch) + patch_x_gradient = self._patch_data(gradient[1], patch) + patch_gradient_val = self._patch_data(gradient_val, patch) + patch_max_pos = np.unravel_index( + patch_gradient_val.argmax(), + patch_gradient_val.shape + ) + # 旋转90度 + max_gradient[point[0], point[1], 0] = \ + -patch_y_gradient[patch_max_pos] + max_gradient[point[0], point[1], 1] = \ + patch_x_gradient[patch_max_pos] + + return max_gradient + + # 获取图片块的范围 + def _get_patch_range(self, point): + half_patch_size = (self.patch_size-1)//2 + patch_range = [ + [ + max(0, point[0]-half_patch_size), + min(point[0]+half_patch_size+1, self.height) + ], + [ + max(0, point[1]-half_patch_size), + min(point[1]+half_patch_size+1, self.width) + ] + ] + return patch_range + + # 获取patch中的数据 + @staticmethod + def _patch_data(img, patch_range): + return img[patch_range[0][0]:patch_range[0][1], patch_range[1][0]:patch_range[1][1]] + +def Patch(im, taillecadre, point): + """ + Permet de calculer les deux points extreme du patch + Voici le patch avec les 4 points + 1 _________ 2 + | | + | | + 3|________|4 + """ + px, py = point + xsize, ysize, c = im.shape + x3 = max(px - taillecadre, 0) + y3 = max(py - taillecadre, 0) + x2 = min(px + taillecadre, ysize - 1) + y2 = min(py + taillecadre, xsize - 1) + return((x3, y3),(x2, y2)) + +def patch_complet(x, y, xsize, ysize, original): + for i in range(xsize): + for j in range(ysize): + if original[x+i,y+j]==0: + return(False) + return(True) + +def crible(xsize,ysize,x1,y1,masque): + compteur=0 + cibles,ciblem=[],[] + for i in range(xsize): + for j in range(ysize): + if masque[y1+i, x1+j] == 0: + compteur += 1 + cibles+=[(i, j)] + else: + ciblem+=[(i, j)] + return (compteur,cibles,ciblem,xsize,ysize) + +def calculPatch(dOmega, cibleIndex, im, original, masque, taillecadre): + mini = minvar = sys.maxsize + sourcePatch,sourcePatche = [],[] + p = dOmega[cibleIndex] + patch = Patch(im, taillecadre, p) + x1, y1 = patch[0] + x2, y2 = patch[1] + Xsize, Ysize, c = im.shape + compteur,cibles,ciblem,xsize,ysize=crible(y2-y1+1,x2-x1+1,x1,y1,masque) + for x in range(Xsize - xsize): + for y in range(Ysize - ysize): + if patch_complet(x, y, xsize, ysize, original): + sourcePatch+=[(x, y)] + for (y, x) in sourcePatch: + R = V = B = ssd = 0 + for (i, j) in cibles: + ima = im[y+i,x+j] + omega = im[y1+i,x1+j] + for k in range(3): + difference = float(ima[k]) - float(omega[k]) + ssd += difference**2 + R += ima[0] + V += ima[1] + B += ima[2] + ssd /= compteur + if ssd < mini: + variation = 0 + for (i, j) in ciblem: + ima = im[y+i,x+j] + differenceR = ima[0] - R/compteur + differenceV = ima[1] - V/compteur + differenceB = ima[2] - B/compteur + variation += differenceR**2 + differenceV**2 + differenceB**2 + if ssd < mini or variation < minvar: + minvar = variation + mini = ssd + pointPatch = (x, y) + return(ciblem, pointPatch) + +Lap = np.array([[ 1., 1., 1.],[ 1., -8., 1.],[ 1., 1., 1.]]) +kerx = np.array([[ 0., 0., 0.], [-1., 0., 1.], [ 0., 0., 0.]]) +kery = np.array([[ 0., -1., 0.], [ 0., 0., 0.], [ 0., 1., 0.]]) + +def calculConfiance(confiance, im, taillecadre, masque, dOmega): + """Permet de calculer la confiance définie dans l'article""" + for k in range(len(dOmega)): + px, py = dOmega[k] + patch = Patch(im, taillecadre, dOmega[k]) + x3, y3 = patch[0] + x2, y2 = patch[1] + compteur = 0 + taille_psi_p = ((x2-x3+1) * (y2-y3+1)) + for x in range(x3, x2 + 1): + for y in range(y3, y2 + 1): + if masque[y, x] == 0: # intersection avec not Omega + compteur += confiance[y, x] + confiance[py, px] = compteur / taille_psi_p + return(confiance) + +def calculData(dOmega, normale, data, gradientX, gradientY, confiance): + """Permet de calculer data définie dans l'article""" + for k in range(len(dOmega)): + x, y = dOmega[k] + NX, NY = normale[k] + data[y, x] = (((gradientX[y, x] * NX)**2 + (gradientY[y, x] * NY)**2)**0.5) / 255. + return(data) + + +def calculPriority(im, taillecadre, masque, dOmega, normale, data, gradientX, gradientY, confiance): + """Permet de calculer la priorité du patch""" + C = calculConfiance(confiance, im, taillecadre, masque, dOmega) + D = calculData(dOmega, normale, data, gradientX, gradientY, confiance) + index = 0 + maxi = 0 + for i in range(len(dOmega)): + x, y = dOmega[i] + P = C[y,x]*D[y,x] + if P > maxi: + maxi = P + index = i + return(C, D, index) +def update(im, gradientX, gradientY, confiance, source, masque, dOmega, point, list, index, taillecadre): + p = dOmega[index] + px, py = p + patch = Patch(im, taillecadre, p) + x1, y1 = patch[0] + x2, y2 = patch[1] + px, py = point + for (i, j) in list: + im[y1+i, x1+j] = im[py+i, px+j] + confiance[y1+i, x1+j] = confiance[py, px] + source[y1+i, x1+j] = 1 + masque[y1+i, x1+j] = 0 + return(im, gradientX, gradientY, confiance, source, masque) + +Lap = np.array([[ 1., 1., 1.],[ 1., -8., 1.],[ 1., 1., 1.]]) +kerx = np.array([[ 0., 0., 0.], [-1., 0., 1.], [ 0., 0., 0.]]) +kery = np.array([[ 0., -1., 0.], [ 0., 0., 0.], [ 0., 1., 0.]]) + +def IdentifyTheFillFront(masque, source): + """ Identifie le front de remplissage """ + dOmega = [] + normale = [] + lap = cv2.filter2D(masque, cv2.CV_32F, Lap) + GradientX = cv2.filter2D(source, cv2.CV_32F, kerx) + GradientY = cv2.filter2D(source, cv2.CV_32F, kery) + xsize, ysize = lap.shape + for x in range(xsize): + for y in range(ysize): + if lap[x, y] > 0: + dOmega+=[(y, x)] + dx = GradientX[x, y] + dy = GradientY[x, y] + N = (dy**2 + dx**2)**0.5 + if N != 0: + normale+=[(dy/N, -dx/N)] + else: + normale+=[(dy, -dx)] + return(dOmega, normale) + +def inpaint(image, masque, taillecadre): + xsize, ysize, channels = image.shape # meme taille pour filtre et image + + #on verifie les tailles + + x, y = masque.shape + + if x != xsize or y != ysize: + print("La taille de l'image et du filtre doivent être les même") + exit() + + tau = 170 #valeur pour séparer les valeurs du masque + omega=[] + confiance = np.copy(masque) + masque = np.copy(masque) + for x in range(xsize): + for y in range(ysize): + v=masque[x,y] + if v s_rf[1] ): + s[rf] = (layers_count, sum_st, layers) + + if val == 0: + break + + x = sorted(list(s.keys())) + q=x[np.abs(np.array(x)-target_patch_size).argmin()] + return s[q][2] + + import code + code.interact(local=dict(globals(), **locals())) + +luma_quant = np.array([ + [16, 11, 10, 16, 24, 40, 51, 61], + [12, 12, 14, 19, 26, 58, 60, 55], + [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], + [18, 22, 37, 56, 68, 109, 103, 77], + [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], + [72, 92, 95, 98, 112, 100, 103, 99]], np.float32) +luma_quant = np.ones_like(luma_quant) + +chroma_quant = np.array([ + [17, 18, 24, 27, 99, 99, 99, 99], + [18, 21, 26, 66, 99, 99, 99, 99], + [24, 26, 56, 99, 99, 99, 99, 99], + [47, 66, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99], + [99, 99, 99, 99, 99, 99, 99, 99]], np.float32) +chroma_quant = np.ones_like(chroma_quant) +#chroma_quant = luma_quant + + +def dct_compress(img, quality=1): + + # if quality < 50: + # S = 5000/quality + # else: + # S = 200 - 2*quality + + # luma_quant = np.floor( (S*luma_quant_base.copy() + 50) / 100) + # luma_quant[luma_quant == 0] = 1 + # chroma_quant = np.floor((S*chroma_quant_base.copy() + 50) / 100) + # chroma_quant[chroma_quant == 0] = 1 + + img = img.copy() + h,w,c = img.shape + + out_img = np.empty( (h,w,c), dtype=np.float32) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + #img[...,1] = cv2.resize(cv2.resize(img[...,1], (w//2, h//2)), (w,h)) + #img[...,2] = cv2.resize(cv2.resize(img[...,2], (w//2, h//2)), (w,h)) + + + for j in range(h // 8): + for i in range(w // 8): + for k in range(c): + tile = img[j*8:(j+1)*8,i*8:(i+1)*8,k].astype(np.float32) + #tile -= 128.0 + + tile = cv2.dct(tile) + + if k == 0: + tile = tile / luma_quant + else: + tile = tile / chroma_quant + + if i == 21: + import code + code.interact(local=dict(globals(), **locals())) + tile = tile.astype(np.int32).astype(np.float32) + #tile = np.clip(tile, 0, 255).astype(np.uint8) + + out_img[j*8:(j+1)*8,i*8:(i+1)*8,k] = tile + return out_img + +def dct_decompress(img): + img = img.copy() + + + h,w,c = img.shape + out_img = np.empty( (h,w,c), dtype=np.uint8) + #img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + #img[...,1] = cv2.resize(cv2.resize(img[...,1], (w//2, h//2)), (w,h)) + #img[...,2] = cv2.resize(cv2.resize(img[...,2], (w//2, h//2)), (w,h)) + + + for j in range(h // 8): + for i in range(w // 8): + for k in range(c): + tile = img[j*8:(j+1)*8,i*8:(i+1)*8,k].astype(np.float32) + # if i == 21: + # import code + # code.interact(local=dict(globals(), **locals())) + if k == 0: + tile *= luma_quant + else: + tile *= chroma_quant + + tile = cv2.idct(tile) + + #tile = tile+128.0 + + tile = np.clip(tile, 0, 255).astype(np.uint8) + + out_img[j*8:(j+1)*8,i*8:(i+1)*8,k] = tile + + out_img = cv2.cvtColor(out_img, cv2.COLOR_YCrCb2BGR) + return out_img + + +def dct8x8(img): + h,w = img.shape + if h % 8 != 0 or w % 8 != 0: + raise ValueError('img size must be divisible by 8') + + out_h = h//8 + out_w = w//8 + + tiles = [] + for j in range(h // 8): + for i in range(w // 8): + tile = img[j*8:(j+1)*8,i*8:(i+1)*8].astype(np.float32) + tile -= 128.0 + tile = cv2.dct(tile) + tile = np.round(tile) + tile[tile == -0.0] = 0.0 + tiles.append( tile ) + + tiles = np.array(tiles).reshape( (out_h,out_w,64)) + + return tiles + +def idct8x8(tiles): + h,w,c = tiles.shape + out_h, out_w = h*8, w*8 + + out_img = np.empty( (out_h,out_w), dtype=np.uint8) + + for j in range(h): + for i in range(w): + tile = tiles[j,i].reshape( (8,8)) + tile = cv2.idct(tile) + tile = np.clip(tile+128.0, 0, 255).astype(np.uint8) + + out_img[j*8:(j+1)*8,i*8:(i+1)*8] = tile + return out_img + +def bgr2dct(img): + h,w,c = img.shape + + if c != 3: + raise ValueError('img must have 3 channels') + + out_img = np.empty( (h,w,c), dtype=np.float32) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + Y, Cr, Cb = cv2.split(img) + + Cr = cv2.resize(Cr, (w//2, h//2)) + Cb = cv2.resize(Cb, (w//2, h//2)) + + return [ dct8x8(x) for x in [Y, Cr, Cb] ] + +def dct2bgr( Y, Cr, Cb ): + h,w,c = Y.shape + out_h = h * 8 + out_w = w * 8 + + Y, Cr, Cb = [ idct8x8(x) for x in [Y, Cr, Cb] ] + + Cr = cv2.resize(Cr, (out_w, out_h)) + Cb = cv2.resize(Cb, (out_w, out_h)) + + img = cv2.merge( [Y,Cr,Cb]) + return cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR) +import numpy as np +import cv2 + +def dct8x8(img): + h,w = img.shape + if h % 8 != 0 or w % 8 != 0: + raise ValueError('img size must be divisible by 8') + + out_h = h//8 + out_w = w//8 + + tiles = [] + for j in range(h // 8): + for i in range(w // 8): + tile = img[j*8:(j+1)*8,i*8:(i+1)*8].astype(np.float32) + tile -= 128.0 + tile = cv2.dct(tile) + tile = np.round(tile) + tile[tile == -0.0] = 0.0 + tiles.append( tile ) + + tiles = np.array(tiles).reshape( (out_h,out_w,64)) + + return tiles + +def idct8x8(tiles): + h,w,c = tiles.shape + out_h, out_w = h*8, w*8 + + out_img = np.empty( (out_h,out_w), dtype=np.uint8) + + for j in range(h): + for i in range(w): + tile = tiles[j,i].reshape( (8,8)) + tile = cv2.idct(tile) + tile += 128.0 + tile = np.clip(tile, 0, 255).astype(np.uint8) + + out_img[j*8:(j+1)*8,i*8:(i+1)*8] = tile + return out_img + +def bgr2dct(img): + h,w,c = img.shape + + if c != 3: + raise ValueError('img must have 3 channels') + + out_img = np.empty( (h,w,c), dtype=np.float32) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb) + Y, Cr, Cb = cv2.split(img) + + Cr = cv2.resize(Cr, (w//2, h//2)) + Cb = cv2.resize(Cb, (w//2, h//2)) + Cr = cv2.resize(Cr, (w, h)) + Cb = cv2.resize(Cb, (w, h)) + + return np.concatenate([ dct8x8(x) for x in [Y, Cr, Cb] ], -1) + +def dct2bgr(img): + + h,w,c = img.shape + out_h = h * 8 + out_w = w * 8 + + Y = img[...,0:64] + Cr = img[...,64:128] + Cb = img[...,128:] + + Y, Cr, Cb = [ idct8x8(x) for x in [Y, Cr, Cb] ] + + #Cr = cv2.resize(Cr, (out_w, out_h)) + #Cb = cv2.resize(Cb, (out_w, out_h)) + + img = cv2.merge( [Y,Cr,Cb]) + return cv2.cvtColor(img, cv2.COLOR_YCrCb2BGR) + + +import cv2 +import scipy +#import trimesh +import numpy as np +from scipy.spatial import ConvexHull +#from cv2.ximgproc import createGuidedFilter + + + + + + + + +# Global position of light source. +gx = 0.0 +gy = 0.0 + +def run(image, mask, ambient_intensity, light_intensity, light_source_height, gamma_correction, stroke_density_clipping, light_color_red, light_color_green, light_color_blue, enabling_multiple_channel_effects): + + # Some pre-processing to resize images and remove input JPEG artifacts. + raw_image = min_resize(image, 512) + #raw_image = image + raw_image = raw_image.astype(np.float32) + unmasked_image = raw_image.copy() + + if mask is not None: + alpha = np.mean(d_resize(mask, raw_image.shape).astype(np.float32) / 255.0, axis=2, keepdims=True) + raw_image = unmasked_image * alpha + + # Compute the convex-hull-like palette. + h, w, c = raw_image.shape + flattened_raw_image = raw_image.reshape((h * w, c)) + raw_image_center = np.mean(flattened_raw_image, axis=0) + hull = ConvexHull(flattened_raw_image) + + + #import code + #code.interact(local=dict(globals(), **locals())) + # Estimate the stroke density map. + intersector = trimesh.Trimesh(faces=hull.simplices, vertices=hull.points).ray + start = np.tile(raw_image_center[None, :], [h * w, 1]) + direction = flattened_raw_image - start + print('Begin ray intersecting ...') + index_tri, index_ray, locations = intersector.intersects_id(start, direction, return_locations=True, multiple_hits=True) + + print('Intersecting finished.') + intersections = np.zeros(shape=(h * w, c), dtype=np.float32) + intersection_count = np.zeros(shape=(h * w, 1), dtype=np.float32) + CI = index_ray.shape[0] + for c in range(CI): + i = index_ray[c] + intersection_count[i] += 1 + intersections[i] += locations[c] + intersections = (intersections + 1e-10) / (intersection_count + 1e-10) + intersections = intersections.reshape((h, w, 3)) + intersection_count = intersection_count.reshape((h, w)) + intersections[intersection_count < 1] = raw_image[intersection_count < 1] + intersection_distance = np.sqrt(np.sum(np.square(intersections - raw_image_center[None, None, :]), axis=2, keepdims=True)) + pixel_distance = np.sqrt(np.sum(np.square(raw_image - raw_image_center[None, None, :]), axis=2, keepdims=True)) + stroke_density = ((1.0 - np.abs(1.0 - pixel_distance / intersection_distance)) * stroke_density_clipping).clip(0, 1) * 255 + + # A trick to improve the quality of the stroke density map. + # It uses guided filter to remove some possible artifacts. + # You can remove these codes if you like sharper effects. + guided_filter = createGuidedFilter(pixel_distance.clip(0, 255).astype(np.uint8), 1, 0.01) + for _ in range(4): + stroke_density = guided_filter.filter(stroke_density) + + # Visualize the estimated stroke density. + cv2.imwrite(r'D:\stroke_density.png', stroke_density.clip(0, 255).astype(np.uint8)) + + # Then generate the lighting effects + raw_image = unmasked_image.copy() + lighting_effect = np.stack([ + generate_lighting_effects(stroke_density, raw_image[:, :, 0]), + generate_lighting_effects(stroke_density, raw_image[:, :, 1]), + generate_lighting_effects(stroke_density, raw_image[:, :, 2]) + ], axis=2) + + + import code + code.interact(local=dict(globals(), **locals())) + + + # Using a simple user interface to display results. + + def update_mouse(event, x, y, flags, param): + global gx + global gy + gx = - float(x % w) / float(w) * 2.0 + 1.0 + gy = - float(y % h) / float(h) * 2.0 + 1.0 + return + + light_source_color = np.array([light_color_blue, light_color_green, light_color_red]) + + global gx + global gy + + while True: + light_source_location = np.array([[[light_source_height, gy, gx]]], dtype=np.float32) + light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location))) + final_effect = np.sum(lighting_effect * light_source_direction, axis=3).clip(0, 1) + if not enabling_multiple_channel_effects: + final_effect = np.mean(final_effect, axis=2, keepdims=True) + rendered_image = (ambient_intensity + final_effect * light_intensity) * light_source_color * raw_image + rendered_image = ((rendered_image / 255.0) ** gamma_correction) * 255.0 + canvas = np.concatenate([raw_image, rendered_image], axis=1).clip(0, 255).astype(np.uint8) + + #import code + #code.interact(local=dict(globals(), **locals())) + + cv2.imshow('Move your mouse on the canvas to play!', canvas) + cv2.setMouseCallback('Move your mouse on the canvas to play!', update_mouse) + cv2.waitKey(10) + + + +class RGBRelighter: + """ + Generating Digital Painting Lighting Effects via RGB-space Geometry + from https://github.com/lllyasviel/PaintingLight + + + can raise error during construction + """ + + def __init__(self, img : np.ndarray): + stroke_density_clipping = 0.1 + + # Compute the convex-hull-like palette. + def_img = self._def_img = img + DH,DW,DC = def_img.shape + + + img = RGBRelighter._min_resize(img, 128).astype(np.float32) + h, w, c = img.shape + img_fl = img.reshape((h * w, c)) + img_fl_mean = np.mean(img_fl, axis=0) + + hull = ConvexHull(img_fl) + + # Estimate the stroke density map. + intersector = trimesh.Trimesh(faces=hull.simplices, vertices=hull.points).ray + start = np.tile(img_fl_mean[None, :], [h * w, 1]) + direction = img_fl - start + + index_tri, index_ray, locations = intersector.intersects_id(start, direction, return_locations=True, multiple_hits=True) + + intersections = np.zeros(shape=(h * w, c), dtype=np.float32) + intersection_count = np.zeros(shape=(h * w, 1), dtype=np.float32) + CI = index_ray.shape[0] + for c in range(CI): + i = index_ray[c] + intersection_count[i] += 1 + intersections[i] += locations[c] + + intersections = (intersections + 1e-10) / (intersection_count + 1e-10) + intersections = intersections.reshape((h, w, 3)) + intersection_count = intersection_count.reshape((h, w)) + intersections[intersection_count < 1] = img[intersection_count < 1] + intersection_distance = np.sqrt(np.sum(np.square(intersections - img_fl_mean[None, None, :]), axis=2, keepdims=True)) + pixel_distance = np.sqrt(np.sum(np.square(img - img_fl_mean[None, None, :]), axis=2, keepdims=True)) + stroke_density = ((1.0 - np.abs(1.0 - pixel_distance / intersection_distance)) * stroke_density_clipping).clip(0, 1) * 255 + + guided_filter = createGuidedFilter(pixel_distance.clip(0, 255).astype(np.uint8), 1, 0.01) + for _ in range(4): + stroke_density = guided_filter.filter(stroke_density) + cv2.imshow('stroke_density', stroke_density / stroke_density.max() ) + cv2.waitKey(0) + + stroke_density = cv2.resize(stroke_density, (DW, DH), interpolation=cv2.INTER_LANCZOS4) + + # Then generate the lighting effects + lighting_effect = np.stack([ + RGBRelighter._generate_lighting_effects(stroke_density, def_img[:, :, 0]), + RGBRelighter._generate_lighting_effects(stroke_density, def_img[:, :, 1]), + RGBRelighter._generate_lighting_effects(stroke_density, def_img[:, :, 2]) + ], axis=2) + + self._lighting_effect = lighting_effect + + + def relighted(self, light_pos, + light_intensity = 1.0, + light_color_bgr = (1.0,1.0,1.0), + light_source_height = 1.0, + ambient_intensity = 0.45, + gamma_correction = 1.0, + enabling_multiple_channel_effects = False): + """ + light_pos (X,Y) X,Y:[-1.0 ... +1.0] + """ + light_source_color = np.array(light_color_bgr, np.float32) + + light_source_location = np.array([[[light_source_height, light_pos[1], light_pos[0 ]]]], dtype=np.float32) + light_source_direction = light_source_location / np.sqrt(np.sum(np.square(light_source_location))) + final_effect = np.sum(self._lighting_effect * light_source_direction, axis=3).clip(0, 1) + + if not enabling_multiple_channel_effects: + final_effect = np.mean(final_effect, axis=2, keepdims=True) + + rendered_image = (ambient_intensity + final_effect * light_intensity) * light_source_color * self._def_img + rendered_image = ((rendered_image / 255.0) ** gamma_correction) * 255.0 + return rendered_image.clip(0, 255).astype(np.uint8) + + # Some image resizing tricks. + def _min_resize(x, m): + if x.shape[0] < x.shape[1]: + s0 = m + s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1])) + else: + s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0])) + s1 = m + new_max = min(s1, s0) + raw_max = min(x.shape[0], x.shape[1]) + return cv2.resize(x, (s1, s0), interpolation=cv2.INTER_LANCZOS4) + + # Some image resizing tricks. + def _d_resize(x, d, fac=1.0): + new_min = min(int(d[1] * fac), int(d[0] * fac)) + raw_min = min(x.shape[0], x.shape[1]) + if new_min < raw_min: + interpolation = cv2.INTER_AREA + else: + interpolation = cv2.INTER_LANCZOS4 + y = cv2.resize(x, (int(d[1] * fac), int(d[0] * fac)), interpolation=interpolation) + return y + + + def _get_image_gradient(dist): + cols = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, 0, +1], [-2, 0, +2], [-1, 0, +1]])) + rows = cv2.filter2D(dist, cv2.CV_32F, np.array([[-1, -2, -1], [0, 0, 0], [+1, +2, +1]])) + return cols, rows + + def _generate_lighting_effects(stroke_density, content): + h512 = content + h256 = cv2.pyrDown(h512) + h128 = cv2.pyrDown(h256) + h64 = cv2.pyrDown(h128) + h32 = cv2.pyrDown(h64) + h16 = cv2.pyrDown(h32) + c512, r512 = RGBRelighter._get_image_gradient(h512) + c256, r256 = RGBRelighter._get_image_gradient(h256) + c128, r128 = RGBRelighter._get_image_gradient(h128) + c64, r64 = RGBRelighter._get_image_gradient(h64) + c32, r32 = RGBRelighter._get_image_gradient(h32) + c16, r16 = RGBRelighter._get_image_gradient(h16) + c = c16 + c = RGBRelighter._d_resize(cv2.pyrUp(c), c32.shape) * 4.0 + c32 + c = RGBRelighter._d_resize(cv2.pyrUp(c), c64.shape) * 4.0 + c64 + c = RGBRelighter._d_resize(cv2.pyrUp(c), c128.shape) * 4.0 + c128 + c = RGBRelighter._d_resize(cv2.pyrUp(c), c256.shape) * 4.0 + c256 + c = RGBRelighter._d_resize(cv2.pyrUp(c), c512.shape) * 4.0 + c512 + r = r16 + r = RGBRelighter._d_resize(cv2.pyrUp(r), r32.shape) * 4.0 + r32 + r = RGBRelighter._d_resize(cv2.pyrUp(r), r64.shape) * 4.0 + r64 + r = RGBRelighter._d_resize(cv2.pyrUp(r), r128.shape) * 4.0 + r128 + r = RGBRelighter._d_resize(cv2.pyrUp(r), r256.shape) * 4.0 + r256 + r = RGBRelighter._d_resize(cv2.pyrUp(r), r512.shape) * 4.0 + r512 + coarse_effect_cols = c + coarse_effect_rows = r + EPS = 1e-10 + max_effect = np.max((coarse_effect_cols**2 + coarse_effect_rows**2)**0.5) + coarse_effect_cols = (coarse_effect_cols + EPS) / (max_effect + EPS) + coarse_effect_rows = (coarse_effect_rows + EPS) / (max_effect + EPS) + stroke_density_scaled = (stroke_density.astype(np.float32) / 255.0).clip(0, 1) + coarse_effect_cols *= (1.0 - stroke_density_scaled ** 2.0 + 1e-10) ** 0.5 + coarse_effect_rows *= (1.0 - stroke_density_scaled ** 2.0 + 1e-10) ** 0.5 + refined_result = np.stack([stroke_density_scaled, coarse_effect_rows, coarse_effect_cols], axis=2) + return refined_result + + +# from core import pathex +# from core.cv2ex import * +# from core.interact import interact as io +# from core.joblib import Subprocessor +# from DFLIMG import * +# from facelib import LandmarksProcessor, FaceType + + +# class FacesetRelighterSubprocessor(Subprocessor): + +# #override +# def __init__(self, image_paths): +# self.image_paths = image_paths + +# super().__init__('FacesetRelighter', FacesetRelighterSubprocessor.Cli, 600) + +# #override +# def on_clients_initialized(self): +# io.progress_bar (None, len (self.image_paths)) + +# #override +# def on_clients_finalized(self): +# io.progress_bar_close() + +# #override +# def process_info_generator(self): +# base_dict = {} + +# for device_idx in range( min(8, multiprocessing.cpu_count()) ): +# client_dict = base_dict.copy() +# device_name = f'CPU #{device_idx}' +# client_dict['device_name'] = device_name +# yield device_name, {}, client_dict + +# #override +# def get_data(self, host_dict): +# if len (self.image_paths) > 0: +# return self.image_paths.pop(0) + +# #override +# def on_data_return (self, host_dict, data): +# self.image_paths.insert(0, data) + +# #override +# def on_result (self, host_dict, data, result): +# io.progress_bar_inc(1) + +# #override +# def get_result(self): +# return None + +# class Cli(Subprocessor.Cli): + +# #override +# def on_initialize(self, client_dict): +# ... + +# #override +# def process_data(self, filepath): +# try: +# dflimg = DFLIMG.load (filepath) +# if dflimg is not None and dflimg.has_data(): + +# rd = {0: (-1.0, -1.0), +# 1: (0.0, -1.0), +# 2: (1.0, -1.0), +# 3: (-1.0, 0.0), +# 4: (1.0, 0.0), +# 5: (-1.0, 1.0), +# 6: (0.0, 1.0), +# 7: (1.0, 1.0), +# } + +# dfl_dict = dflimg.get_dict() + +# img = cv2_imread(filepath) + +# try: +# relighter = RGBRelighter(img) +# except: +# return + +# for i in rd: +# output_filepath = filepath.parent / (filepath.stem + f'_relighted{i}' + filepath.suffix) +# img_r = relighter.relighted(rd[i], light_intensity=1.5) +# cv2_imwrite ( str(output_filepath), img_r, [int(cv2.IMWRITE_JPEG_QUALITY), 100] ) + +# dflimg = DFLIMG.load (output_filepath) +# dflimg.set_dict(dfl_dict) +# dflimg.save() + +# except: +# self.log_err (f"Exception occured while processing file {filepath}. Error: {traceback.format_exc()}") + +# def relight_folder (dirpath): + + +# image_paths = [Path(x) for x in pathex.get_image_paths( dirpath ) ] +# result = FacesetRelighterSubprocessor (image_paths).run() + + +# import code +# code.interact(local=dict(globals(), **locals())) + +# io.log_info ( f"Processing to {output_dirpath_parts}") + +# output_images_paths = pathex.get_image_paths(output_dirpath) +# if len(output_images_paths) > 0: +# for filename in output_images_paths: +# Path(filename).unlink() + + +# result = FacesetRelighterSubprocessor ( image_paths, output_dirpath, image_size).run() + +# is_merge = io.input_bool (f"\r\nMerge {output_dirpath_parts} to {dirpath_parts} ?", True) +# if is_merge: +# io.log_info (f"Copying processed files to {dirpath_parts}") + +# for (filepath, output_filepath) in result: +# try: +# shutil.copy (output_filepath, filepath) +# except: +# pass + +# io.log_info (f"Removing {output_dirpath_parts}") +# shutil.rmtree(output_dirpath) + + +def main(): + # import tensorflow as tf + # interpreter = tf.lite.Interpreter(r'D:\DevelopPython\test\FaceMeshOrig.tflite') + # interpreter.allocate_tensors() + + # # Get input and output tensors. + # input_details = interpreter.get_input_details() + # output_details = interpreter.get_output_details() + + # # Test model on random input data. + # input_shape = input_details[0]['shape'] + # input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) + # interpreter.set_tensor(input_details[0]['index'], input_data) + + # interpreter.invoke() + # output_data = interpreter.get_tensor(output_details[0]['index']) + + # #input().fill(3.) + # #interpreter.allocate_tensors() + + # # while True: + # # with timeit(): + # # interpreter.invoke() + import tensorflow as tf + interpreter = tf.lite.Interpreter(r'D:\DevelopPython\test\FaceMeshOrig.tflite') + # interpreter.allocate_tensors() + + # # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + import tf2onnx + with tf.device("/CPU:0"): + model_proto, _ = tf2onnx.convert._convert_common( + None, + tflite_path=r'D:\DevelopPython\test\FaceMeshOrig.tflite', + name='SAEHD', + input_names=['input_1'], + output_names=['conv2d_21'], + opset=9, + output_path=r'D:\DevelopPython\test\FaceMeshOrig.onnx') + + import code + code.interact(local=dict(globals(), **locals())) + + + + imagepaths = pathex.get_image_paths(r'E:\FakeFaceVideoSources\Datasets\GenericXseg source\aligned ffhq') + imgs = [] + for filename in tqdm(imagepaths, ascii=True): + imgs.append( cv2_imread(filename) ) + + mean_img = np.mean(imgs, 0).astype(np.float32) / 255.0 + + cv2.imshow('', mean_img) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + img = cv2_imread(imagepaths[0]).astype(np.float32) + + for y_offset in range(0,1): + for x_offset in range(-10,10): + aff = np.array([ [1,0,x_offset], + [0,1,y_offset]], np.float32) + x = cv2.warpAffine(img, aff, (1024,1024), flags=cv2.INTER_LANCZOS4) + + print(f'{x_offset},{y_offset} : {(((x-mean_img))**2).mean()}') + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.BestGPU(), data_format="NCHW" ) + tf = nn.tf + tf_sess = nn.tf_sess + + t1 = tf.get_variable ("t1", (16384,16384), dtype=tf.float32) + t2 = tf.get_variable ("t2", (16384,16384), dtype=tf.float32) + loss = tf.matmul(t1,t2) + + nn.batch_set_value([( t1, np.random.randint( 2**8, size=(16384,16384) ).astype(np.float32))] ) + nn.batch_set_value([( t2, np.random.randint( 2**8, size=(16384,16384) ).astype(np.float32))] ) + + for i in range(100): + t = time.time() + q = nn.tf_sess.run ( [ loss ] ) + print(f'time {time.time()-t} ') + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + + from core.leras import nn + nn.initialize_main_env() + #nn.initialize( device_config=nn.DeviceConfig.GPUIndexes([0]), data_format="NCHW" ) + nn.initialize( device_config=nn.DeviceConfig.CPU(), data_format="NCHW" ) + tf = nn.tf + tf_sess = nn.tf_sess + + + # with tf.gfile.GFile(r'D:\DevelopPython\test\FaceMesh.pb', 'rb') as f: + # graph_def = tf.GraphDef() + # graph_def.ParseFromString(f.read()) + + # with tf.Graph().as_default() as graph: + # tf.import_graph_def(graph_def, name="") + + #input_tensor = graph.get_tensor_by_name(INPUT_TENSOR_NAME) + #output_tensor = graph.get_tensor_by_name(OUTPUT_TENSOR_NAME) + x = tf.lite.Interpreter(r'D:\DevelopPython\test\BlazeFace.tflite') + + #input = x.tensor(x.get_input_details()[0]["index"]) + #output = interpreter.tensor(interpreter.get_output_details()[0]["index"]) + + #input().fill(3.) + #interpreter.allocate_tensors() + + # while True: + # with timeit(): + # interpreter.invoke() + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.GPUIndexes([1]), data_format="NCHW" ) + tf = nn.tf + tf_sess = nn.tf_sess + + import onnx + from onnx_tf.backend import prepare + + onnx_model = onnx.load(r'D:\DevelopPPP\projects\DeepFaceLive\github_project\xlib\onnxruntime\FaceMesh\FaceMesh.onnx') + tf_rep = prepare(onnx_model) + tf_rep.export_graph(r'D:\1\FaceMesh.pb') + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.GPUIndexes([1]), data_format="NCHW" ) + tf = nn.tf + tf_sess = nn.tf_sess + + + + + img = cv2.imread(r'D:\DevelopPython\test\00007.jpg') + img = img.astype(np.float32) / 255.0 + img = cv2.resize(img, (256,256)) + + img_np = img[None,...].transpose( (0,3,1,2) ) + + + class LaplacianPyramidConv(nn.LayerBase): + def __init__(self, ch=3, L=3): + super().__init__() + self.L = L + kernel = np.float32([ [1., 4., 6., 4., 1], + [4., 16., 24., 16., 4.], + [6., 24., 36., 24., 6.], + [4., 16., 24., 16., 4.], + [1., 4., 6., 4., 1.]])[...,None,None] / 256.0 + + kernel = np.tile(kernel, (1,1,ch,1) ) + self.kernel = tf.constant (kernel) + + def pyramid_decomp(self, inp_t): + current = inp_t + pyr = [] + for _ in range(self.L): + filtered = tf.nn.depthwise_conv2d(current, self.kernel, strides=[1,1,1,1], padding='SAME', data_format=nn.data_format) + down = filtered[...,::2,::2] + up = tf.nn.depthwise_conv2d( nn.upsample2d(down), self.kernel*4, strides=[1,1,1,1], padding='SAME', data_format=nn.data_format) + diff = current - up + pyr.insert(0, diff) + current = down + + pyr.insert(0, current) + return pyr + + def pyramid_comp(self, pyrs): + img_t = pyrs[0] + + for level_t in pyrs[1:]: + up_t = tf.nn.depthwise_conv2d( nn.upsample2d(img_t), self.kernel*4, strides=[1,1,1,1], padding='SAME', data_format=nn.data_format) + img_t = up_t + level_t + return img_t + + class TransHighFreq(nn.ModelBase): + def on_build(self, in_ch=3, ch=64): + self.conv1 = nn.Conv2D( in_ch*3, ch, kernel_size=3, strides=1, padding='SAME') + + self.conv2 = nn.Conv2D( ch, ch, kernel_size=1, strides=1, padding='SAME') + + def forward(self, inp): + x = inp + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = self.conv2(x) + return x + + class TransHighFreqBlock(nn.ModelBase): + def on_build(self, ch=3): + self.conv1 = nn.Conv2D( ch, 16, kernel_size=1, strides=1, padding='SAME') + self.conv2 = nn.Conv2D( 16, ch, kernel_size=1, strides=1, padding='SAME') + + def forward(self, inp): + x = inp + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = self.conv2(x) + return x + + lap_pyr_conv = LaplacianPyramidConv(L=2) + + res = 256 + bs = 1 + + inp_t = tf.placeholder(tf.float32, (bs,3,res,res) ) + + pyrs_t = lap_pyr_conv.pyramid_decomp(inp_t) + + out_t = lap_pyr_conv.pyramid_comp(pyrs_t) + + out_np = nn.tf_sess.run (out_t, feed_dict={inp_t:img_np }) + + #import code + #code.interact(local=dict(globals(), **locals())) + + + for x in [out_np]: + + x = x.transpose( (0,2,3,1))[0] + + #x -= x.min() + #x /= x.max() + + cv2.imshow('',x) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + + + #import tflite + import tensorflow as tf + #model = tflite.Interpreter(model_path=r'F:\DeepFaceLabCUDA9.2SSE\_internal\model.tflite') + interpreter = tf.lite.Interpreter(model_path=r'F:\DeepFaceLabCUDA9.2SSE\_internal\model.tflite') + #interpreter.allocate_tensors() + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + import code + code.interact(local=dict(globals(), **locals())) + + + img = cv2.imread(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned\00000.jpg') + img = img.astype(np.float32) / 255.0 + + + while True: + z = imagelib.apply_random_overlay_triangle(img, max_alpha=0.25) + + cv2.imshow('', z) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + img = cv2.imread(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned\00000.jpg') + img = img.astype(np.float32) / 255.0 + + while True: + with timeit(): + img_r = apply_random_relight(img) + cv2.imshow('', img_r) + cv2.waitKey(1) + import code + code.interact(local=dict(globals(), **locals())) + + + + img = cv2.imread(r'D:\DevelopPython\test\00000.png') + + relighter = RGBRelighter(img) + + img_r = relighter.relighted( (-1.0,-1.0) ) + + while True: + cv2.imshow('', img) + cv2.waitKey(0) + cv2.imshow('', img_r) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + relight_folder(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned') + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + + #image = cv2.imread(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned\00000.jpg') + #image = cv2.imread(r'D:\1.jpg') + image = cv2.imread(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned1\00638.jpg') + mask = None + + ambient_intensity = 0.45 + light_intensity = 1.0 + light_source_height = 1.0 + gamma_correction = 1.0 + stroke_density_clipping = 0.1 + enabling_multiple_channel_effects = False + + light_color_red = 1.0 + light_color_green = 1.0 + light_color_blue = 1.0 + + run(image, mask, ambient_intensity, light_intensity, light_source_height, + gamma_correction, stroke_density_clipping, light_color_red, light_color_green, + light_color_blue, enabling_multiple_channel_effects) + + import code + code.interact(local=dict(globals(), **locals())) + #======================================= + + + + from numpy import linalg as npla + + def vector2_dot(a,b): + return a[...,0]*b[...,0]+a[...,1]*b[...,1] + + def vector2_dot2(a): + return a[...,0]*a[...,0]+a[...,1]*a[...,1] + + def vector2_cross(a,b): + return a[...,0]*b[...,1]-a[...,1]*b[...,0] + + def sd_bezier( wh, A, B, C ): + """ + returns drawn bezier in [h,w,1] output range float32, + every pixel contains signed distance to bezier line + + wh [w,h] resolution + A,B,C points [x,y] + """ + + width,height = wh + + A = np.float32(A) + B = np.float32(B) + C = np.float32(C) + + + pos = np.empty( (height,width,2), dtype=np.float32 ) + pos[...,0] = np.arange(width)[:,None] + pos[...,1] = np.arange(height)[None,:] + + + a = B-A + b = A - 2.0*B + C + c = a * 2.0 + d = A - pos + + b_dot = vector2_dot(b,b) + if b_dot == 0.0: + return np.zeros( (height,width), dtype=np.float32 ) + + kk = 1.0 / b_dot + + kx = kk * vector2_dot(a,b) + ky = kk * (2.0*vector2_dot(a,a)+vector2_dot(d,b))/3.0; + kz = kk * vector2_dot(d,a); + + res = 0.0; + sgn = 0.0; + + p = ky - kx*kx; + + p3 = p*p*p; + q = kx*(2.0*kx*kx - 3.0*ky) + kz; + h = q*q + 4.0*p3; + + hp_sel = h >= 0.0 + + hp_p = h[hp_sel] + hp_p = np.sqrt(hp_p) + + hp_x = ( np.stack( (hp_p,-hp_p), -1) -q[hp_sel,None] ) / 2.0 + hp_uv = np.sign(hp_x) * np.power( np.abs(hp_x), [1.0/3.0, 1.0/3.0] ) + hp_t = np.clip( hp_uv[...,0] + hp_uv[...,1] - kx, 0.0, 1.0 ) + + hp_t = hp_t[...,None] + hp_q = d[hp_sel]+(c+b*hp_t)*hp_t + hp_res = vector2_dot2(hp_q) + hp_sgn = vector2_cross(c+2.0*b*hp_t,hp_q) + + hl_sel = h < 0.0 + + hl_q = q[hl_sel] + hl_p = p[hl_sel] + hl_z = np.sqrt(-hl_p) + hl_v = np.arccos( hl_q / (hl_p*hl_z*2.0)) / 3.0 + + hl_m = np.cos(hl_v) + hl_n = np.sin(hl_v)*1.732050808; + + hl_t = np.clip( np.stack( (hl_m+hl_m,-hl_n-hl_m,hl_n-hl_m), -1)*hl_z[...,None]-kx, 0.0, 1.0 ); + + hl_d = d[hl_sel] + + hl_qx = hl_d+(c+b*hl_t[...,0:1])*hl_t[...,0:1] + + hl_dx = vector2_dot2(hl_qx) + hl_sx = vector2_cross(c+2.0*b*hl_t[...,0:1], hl_qx) + + hl_qy = hl_d+(c+b*hl_t[...,1:2])*hl_t[...,1:2] + hl_dy = vector2_dot2(hl_qy) + hl_sy = vector2_cross(c+2.0*b*hl_t[...,1:2],hl_qy); + + hl_dx_l_dy = hl_dx=hl_dy + + hl_res = np.empty_like(hl_dx) + hl_res[hl_dx_l_dy] = hl_dx[hl_dx_l_dy] + hl_res[hl_dx_ge_dy] = hl_dy[hl_dx_ge_dy] + + hl_sgn = np.empty_like(hl_sx) + hl_sgn[hl_dx_l_dy] = hl_sx[hl_dx_l_dy] + hl_sgn[hl_dx_ge_dy] = hl_sy[hl_dx_ge_dy] + + res = np.empty( (height, width), np.float32 ) + res[hp_sel] = hp_res + res[hl_sel] = hl_res + + sgn = np.empty( (height, width), np.float32 ) + sgn[hp_sel] = hp_sgn + sgn[hl_sel] = hl_sgn + + sgn = np.sign(sgn) + res = np.sqrt(res)*sgn + + return res[...,None] + + def random_bezier_split_faded( wh, ): + width, height = wh + + degA = np.random.randint(360) + degB = np.random.randint(360) + degC = np.random.randint(360) + + deg_2_rad = math.pi / 180.0 + + center = np.float32([width / 2.0, height / 2.0]) + + radius = max(width, height) + + A = center + radius*np.float32([ math.sin( degA * deg_2_rad), math.cos( degA * deg_2_rad) ] ) + B = center + np.random.randint(radius)*np.float32([ math.sin( degB * deg_2_rad), math.cos( degB * deg_2_rad) ] ) + C = center + radius*np.float32([ math.sin( degC * deg_2_rad), math.cos( degC * deg_2_rad) ] ) + + x = sd_bezier( (width,height), A, B, C ) + + x = x / (1+np.random.randint(radius)) + 0.5 + + x = np.clip(x, 0, 1) + return x + + while True: + + x = random_bezier_split_faded( [256,256]) + + cv2.imshow('', x) + cv2.waitKey(100) + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.CPU() )# device_config=nn.DeviceConfig.GPUIndexes([1]) ) + tf = nn.tf + + def load_pb(path_to_pb): + with tf.gfile.GFile(path_to_pb, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='') + return graph + + graph = load_pb (r"D:\DevelopPython\test\opencv_face_detector_uint8.pb") + + + # input0 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame0:0') + # input1 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame1:0') + # input2 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame2:0') + # input3 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame3:0') + # input4 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame4:0') + # output = graph.get_tensor_by_name('VideoSR_Unet/Out4X/output/add:0') + #filepath = r'D:\DevelopPython\test\00000.jpg' + #img = cv2.imread(filepath).astype(np.float32) / 255.0 + #inp_img = img *2 - 1 + #inp_img = cv2.resize (inp_img, (192,192) ) + + sess = tf.Session(graph=graph, config=nn.tf_sess_config) + + import code + code.interact(local=dict(globals(), **locals())) + + #====================================== + + + class CenterFace(object): + def __init__(self): + self.net = cv2.dnn.readNetFromONNX(r"D:\DevelopPython\test\centerface.onnx") + self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = 0, 0, 0, 0 + + def __call__(self, img, threshold=0.5): + h,w,c = img.shape + self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = self.transform(h, w) + return self.inference_opencv(img, threshold) + + def inference_opencv(self, img, threshold): + blob = cv2.dnn.blobFromImage(img, scalefactor=1.0, size=(self.img_w_new, self.img_h_new), mean=(0, 0, 0), swapRB=True, crop=False) + self.net.setInput(blob) + heatmap, scale, offset, lms = self.net.forward(["537", "538", "539", '540']) + return self.postprocess(heatmap, lms, offset, scale, threshold) + + def transform(self, h, w): + img_h_new, img_w_new = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32) + scale_h, scale_w = img_h_new / h, img_w_new / w + return img_h_new, img_w_new, scale_h, scale_w + + def postprocess(self, heatmap, lms, offset, scale, threshold): + dets, lms = self.decode(heatmap, scale, offset, lms, (self.img_h_new, self.img_w_new), threshold=threshold) + + if len(dets) > 0: + dets[:, 0:4:2], dets[:, 1:4:2] = dets[:, 0:4:2] / self.scale_w, dets[:, 1:4:2] / self.scale_h + lms[:, 0:10:2], lms[:, 1:10:2] = lms[:, 0:10:2] / self.scale_w, lms[:, 1:10:2] / self.scale_h + else: + dets = np.empty(shape=[0, 5], dtype=np.float32) + lms = np.empty(shape=[0, 10], dtype=np.float32) + return dets, lms + + def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1): + heatmap = np.squeeze(heatmap) + scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :] + offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :] + c0, c1 = np.where(heatmap > threshold) + + boxes, lms = [], [] + if len(c0) > 0: + for i in range(len(c0)): + s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4 + o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]] + s = heatmap[c0[i], c1[i]] + x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(0, (c0[i] + o0 + 0.5) * 4 - s0 / 2) + x1, y1 = min(x1, size[1]), min(y1, size[0]) + boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s]) + + lm = [] + for j in range(5): + lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1) + lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1) + lms.append(lm) + boxes = np.asarray(boxes, dtype=np.float32) + keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3) + boxes = boxes[keep, :] + + lms = np.asarray(lms, dtype=np.float32) + lms = lms[keep, :] + return boxes, lms + + def nms(self, boxes, scores, nms_thresh): + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = np.argsort(scores)[::-1] + num_detections = boxes.shape[0] + suppressed = np.zeros((num_detections,), dtype=np.bool) + + keep = [] + for _i in range(num_detections): + i = order[_i] + if suppressed[i]: + continue + keep.append(i) + + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + + for _j in range(_i + 1, num_detections): + j = order[_j] + if suppressed[j]: + continue + + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0, xx2 - xx1 + 1) + h = max(0, yy2 - yy1 + 1) + + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= nms_thresh: + suppressed[j] = True + + return keep + + + + import torch + import torch.nn as tnn + import torch.nn.functional as F + import torchvision as tv + + import onnx + img = cv2_imread(r"D:\DevelopPython\test\linus0.jpg") + + g = onnx.load(r"D:\DevelopPython\test\centerface.onnx").graph + weights = { i.name : i for i in g.initializer } + + #net = cv2.dnn.readNetFromONNX(r"D:\DevelopPython\test\centerface.onnx") + #net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV) + #net.setPreferableTarget(cv2.dnn.DNN_TARGET_) + #blob = cv2.dnn.blobFromImage(img, scalefactor=1.0, size=(512, 512), mean=(0, 0, 0), swapRB=True, crop=False) + #net.setInput(blob) + + + + + def g_get_weight(s): + w = weights[s] + raw_data = w.raw_data + n_floats = len(raw_data) //4 + + f = struct.unpack('f'*n_floats, raw_data) + + n = np.array(f, dtype=np.float32).reshape(w.dims) + return n + import code + code.interact(local=dict(globals(), **locals())) + + s = "" + for node in g.node: + s += f'{node.op_type} ' + if node.op_type == 'Conv' or node.op_type == 'ConvTranspose': + + w = weights[node.input[1]] + out_ch, in_ch, _,_ = w.dims + + s += f'{in_ch}->{out_ch} ' + + + for attr in node.attribute: + if attr.name == 'kernel_shape': + s += f'k={attr.ints} ' + elif attr.name == 'pads': + s += f'pads={attr.ints} ' + elif attr.name == 'strides': + s += f'strides={attr.ints} ' + + + s += f': {node.input} {node.output} \n' + + #Path(r'D:\graph.txt').write_text(s) + + + + # + class CenterFaceTNN(tnn.Module): + def __init__(self): + super().__init__() + self.conv_363 = tnn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) + self.bn_364 = tnn.BatchNorm2d(32) + + self.dconv_366 = tnn.Conv2d(32, 32, 3, padding=1, groups=32, bias=False) + self.bn_367 = tnn.BatchNorm2d(32) + self.conv_369 = tnn.Conv2d(32, 16, 1, padding=0, bias=False) + self.bn_370 = tnn.BatchNorm2d(16) + + self.conv_371 = tnn.Conv2d(16, 96, 1, padding=0, bias=False) + self.bn_372 = tnn.BatchNorm2d(96) + self.dconv_374 = tnn.Conv2d(96, 96, 3, stride=2, padding=1, groups=96, bias=False) + self.bn_375 = tnn.BatchNorm2d(96) + self.conv_377 = tnn.Conv2d(96, 24, 1, padding=0, bias=False) + self.bn_378 = tnn.BatchNorm2d(24) + + self.conv_379 = tnn.Conv2d(24, 144, 1, padding=0, bias=False) + self.bn_380 = tnn.BatchNorm2d(144) + self.dconv_382 = tnn.Conv2d(144, 144, 3, padding=1, groups=144, bias=False) + self.bn_383 = tnn.BatchNorm2d(144) + self.conv_385 = tnn.Conv2d(144, 24, 1, padding=0, bias=False) + self.bn_386 = tnn.BatchNorm2d(24) + self.conv_388 = tnn.Conv2d(24, 144, 1, padding=0, bias=False) + self.bn_389 = tnn.BatchNorm2d(144) + self.dconv_391 = tnn.Conv2d(144, 144, 3, stride=2, padding=1, groups=144, bias=False) + self.bn_392 = tnn.BatchNorm2d(144) + self.conv_394 = tnn.Conv2d(144, 32, 1, padding=0, bias=False) + self.bn_395 = tnn.BatchNorm2d(32) + self.conv_396 = tnn.Conv2d(32, 192, 1, padding=0, bias=False) + self.bn_397 = tnn.BatchNorm2d(192) + self.dconv_399 = tnn.Conv2d(192, 192, 3, padding=1, groups=192, bias=False) + self.bn_400 = tnn.BatchNorm2d(192) + self.conv_402 = tnn.Conv2d(192, 32, 1, padding=0, bias=False) + self.bn_403 = tnn.BatchNorm2d(32) + self.conv_405 = tnn.Conv2d(32, 192, 1, padding=0, bias=False) + self.bn_406 = tnn.BatchNorm2d(192) + self.dconv_408 = tnn.Conv2d(192, 192, 3, padding=1, groups=192, bias=False) + self.bn_409 = tnn.BatchNorm2d(192) + self.conv_411 = tnn.Conv2d(192, 32, 1, padding=0, bias=False) + self.bn_412 = tnn.BatchNorm2d(32) + self.conv_414 = tnn.Conv2d(32, 192, 1, padding=0, bias=False) + self.bn_415 = tnn.BatchNorm2d(192) + self.dconv_417 = tnn.Conv2d(192, 192, 3, stride=2, padding=1, groups=192, bias=False) + self.bn_418 = tnn.BatchNorm2d(192) + self.conv_420 = tnn.Conv2d(192, 64, 1, padding=0, bias=False) + self.bn_421 = tnn.BatchNorm2d(64) + self.conv_422 = tnn.Conv2d(64, 384, 1, padding=0, bias=False) + self.bn_423 = tnn.BatchNorm2d(384) + self.dconv_425 = tnn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False) + self.bn_426 = tnn.BatchNorm2d(384) + self.conv_428 = tnn.Conv2d(384, 64, 1, padding=0, bias=False) + self.bn_429 = tnn.BatchNorm2d(64) + self.conv_431 = tnn.Conv2d(64, 384, 1, padding=0, bias=False) + self.bn_432 = tnn.BatchNorm2d(384) + self.dconv_434 = tnn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False) + self.bn_435 = tnn.BatchNorm2d(384) + self.conv_437 = tnn.Conv2d(384, 64, 1, padding=0, bias=False) + self.bn_438 = tnn.BatchNorm2d(64) + self.conv_440 = tnn.Conv2d(64, 384, 1, padding=0, bias=False) + self.bn_441 = tnn.BatchNorm2d(384) + self.dconv_443 = tnn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False) + self.bn_444 = tnn.BatchNorm2d(384) + self.conv_446 = tnn.Conv2d(384, 64, 1, padding=0, bias=False) + self.bn_447 = tnn.BatchNorm2d(64) + + self.conv_449 = tnn.Conv2d(64, 384, 1, padding=0, bias=False) + self.bn_450 = tnn.BatchNorm2d(384) + self.dconv_452 = tnn.Conv2d(384, 384, 3, padding=1, groups=384, bias=False) + self.bn_453 = tnn.BatchNorm2d(384) + self.conv_455 = tnn.Conv2d(384, 96, 1, padding=0, bias=False) + self.bn_456 = tnn.BatchNorm2d(96) + + self.conv_457 = tnn.Conv2d(96, 576, 1, padding=0, bias=False) + self.bn_458 = tnn.BatchNorm2d(576) + self.dconv_460 = tnn.Conv2d(576, 576, 3, padding=1, groups=576, bias=False) + self.bn_461 = tnn.BatchNorm2d(576) + self.conv_463 = tnn.Conv2d(576, 96, 1, padding=0, bias=False) + self.bn_464 = tnn.BatchNorm2d(96) + + self.conv_466 = tnn.Conv2d(96, 576, 1, padding=0, bias=False) + self.bn_467 = tnn.BatchNorm2d(576) + self.dconv_469 = tnn.Conv2d(576, 576, 3, padding=1, groups=576, bias=False) + self.bn_470 = tnn.BatchNorm2d(576) + self.conv_472 = tnn.Conv2d(576, 96, 1, padding=0, bias=False) + self.bn_473 = tnn.BatchNorm2d(96) + + self.conv_475 = tnn.Conv2d(96, 576, 1, padding=0, bias=False) + self.bn_476 = tnn.BatchNorm2d(576) + self.dconv_478 = tnn.Conv2d(576, 576, 3, stride=2, padding=1, groups=576, bias=False) + self.bn_479 = tnn.BatchNorm2d(576) + self.conv_481 = tnn.Conv2d(576, 160, 1, padding=0, bias=False) + self.bn_482 = tnn.BatchNorm2d(160) + + self.conv_483 = tnn.Conv2d(160, 960, 1, padding=0, bias=False) + self.bn_484 = tnn.BatchNorm2d(960) + self.dconv_486 = tnn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False) + self.bn_487 = tnn.BatchNorm2d(960) + self.conv_489 = tnn.Conv2d(960, 160, 1, padding=0, bias=False) + self.bn_490 = tnn.BatchNorm2d(160) + + self.conv_492 = tnn.Conv2d(160, 960, 1, padding=0, bias=False) + self.bn_493 = tnn.BatchNorm2d(960) + self.dconv_495 = tnn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False) + self.bn_496 = tnn.BatchNorm2d(960) + self.conv_498 = tnn.Conv2d(960, 160, 1, padding=0, bias=False) + self.bn_499 = tnn.BatchNorm2d(160) + + self.conv_501 = tnn.Conv2d(160, 960, 1, padding=0, bias=False) + self.bn_502 = tnn.BatchNorm2d(960) + self.dconv_504 = tnn.Conv2d(960, 960, 3, padding=1, groups=960, bias=False) + self.bn_505 = tnn.BatchNorm2d(960) + self.conv_507 = tnn.Conv2d(960, 320, 1, padding=0, bias=False) + self.bn_508 = tnn.BatchNorm2d(320) + + self.conv_509 = tnn.Conv2d(320, 24, 1, padding=0, bias=False) + self.bn_510 = tnn.BatchNorm2d(24) + + self.conv_512 = tnn.ConvTranspose2d(24, 24, 2, stride=2, padding=0, bias=False) + self.bn_513 = tnn.BatchNorm2d(24) + + self.conv_515 = tnn.Conv2d(96, 24, 1, padding=0, bias=False) + self.bn_516 = tnn.BatchNorm2d(24) + + self.conv_519 = tnn.ConvTranspose2d(24,24, 2, stride=2, padding=0, bias=False) + self.bn_520 = tnn.BatchNorm2d(24) + + self.conv_522 = tnn.Conv2d(32, 24, 1, padding=0, bias=False) + self.bn_523 = tnn.BatchNorm2d(24) + + self.conv_526 = tnn.ConvTranspose2d(24,24, 2, stride=2, padding=0, bias=False) + self.bn_527 = tnn.BatchNorm2d(24) + + self.conv_529 = tnn.Conv2d(24, 24, 1, padding=0, bias=False) + self.bn_530 = tnn.BatchNorm2d(24) + + self.conv_533 = tnn.Conv2d(24, 24, 3, padding=1, bias=False) + self.bn_534 = tnn.BatchNorm2d(24) + + self.conv_536 = tnn.Conv2d(24, 1, 1) + self.conv_538 = tnn.Conv2d(24, 2, 1) + self.conv_539 = tnn.Conv2d(24, 2, 1) + self.conv_540 = tnn.Conv2d(24, 10, 1) + + + def forward(self, x): + x = self.conv_363(x) + x = self.bn_364(x) + x = F.relu(x) + + x = self.dconv_366(x) + x = self.bn_367(x) + x = F.relu(x) + x = self.conv_369(x) + x = self.bn_370(x) + + x = self.conv_371(x) + x = self.bn_372(x) + x = F.relu(x) + x = self.dconv_374(x) + x = self.bn_375(x) + x = F.relu(x) + x = self.conv_377(x) + x = x378 = self.bn_378(x) + x = self.conv_379(x) + x = self.bn_380(x) + x = F.relu(x) + x = self.dconv_382(x) + x = self.bn_383(x) + x = F.relu(x) + x = self.conv_385(x) + x = self.bn_386(x) + x = x387 = x + x378 + x = self.conv_388(x) + x = self.bn_389(x) + x = F.relu(x) + x = self.dconv_391(x) + x = self.bn_392(x) + x = F.relu(x) + x = self.conv_394(x) + x = x395 = self.bn_395(x) + x = self.conv_396(x) + x = self.bn_397(x) + x = F.relu(x) + x = self.dconv_399(x) + x = self.bn_400(x) + x = F.relu(x) + x = self.conv_402(x) + x = self.bn_403(x) + x = x404 = x + x395 + x = self.conv_405(x) + x = self.bn_406(x) + x = F.relu(x) + x = self.dconv_408(x) + x = self.bn_409(x) + x = F.relu(x) + x = self.conv_411(x) + x = self.bn_412(x) + x = x413 = x + x404 + x = self.conv_414(x) + x = self.bn_415(x) + x = F.relu(x) + x = self.dconv_417(x) + x = self.bn_418(x) + x = F.relu(x) + x = self.conv_420(x) + x = x421 = self.bn_421(x) + x = self.conv_422(x) + x = self.bn_423(x) + x = F.relu(x) + x = self.dconv_425(x) + x = self.bn_426(x) + x = F.relu(x) + x = self.conv_428(x) + x = self.bn_429(x) + x = x430 = x + x421 + x = self.conv_431(x) + x = self.bn_432(x) + x = F.relu(x) + x = self.dconv_434(x) + x = self.bn_435(x) + x = F.relu(x) + x = self.conv_437(x) + x = self.bn_438(x) + x = x439 = x + x430 + + x = self.conv_440(x) + x = self.bn_441(x) + x = F.relu(x) + x = self.dconv_443(x) + x = self.bn_444(x) + x = F.relu(x) + x = self.conv_446(x) + x = self.bn_447(x) + x = x + x439 + + x = self.conv_449(x) + x = self.bn_450(x) + x = F.relu(x) + x = self.dconv_452(x) + x = self.bn_453(x) + x = F.relu(x) + x = self.conv_455(x) + x = x456 = self.bn_456(x) + + x = self.conv_457(x) + x = self.bn_458(x) + x = F.relu(x) + x = self.dconv_460(x) + x = self.bn_461(x) + x = F.relu(x) + x = self.conv_463(x) + x = self.bn_464(x) + + x = x465 = x + x456 + + x = self.conv_466(x) + x = self.bn_467(x) + x = F.relu(x) + x = self.dconv_469(x) + x = self.bn_470(x) + x = F.relu(x) + x = self.conv_472(x) + x = self.bn_473(x) + + x = x474 = x + x465 + + x = self.conv_475(x) + x = self.bn_476(x) + x = F.relu(x) + x = self.dconv_478(x) + x = self.bn_479(x) + x = F.relu(x) + x = self.conv_481(x) + x = x482 = self.bn_482(x) + + x = self.conv_483(x) + x = self.bn_484(x) + x = F.relu(x) + x = self.dconv_486(x) + x = self.bn_487(x) + x = F.relu(x) + x = self.conv_489(x) + x = self.bn_490(x) + + x = x491 = x + x482 + + x = self.conv_492(x) + x = self.bn_493(x) + x = F.relu(x) + x = self.dconv_495(x) + x = self.bn_496(x) + x = F.relu(x) + x = self.conv_498(x) + x = self.bn_499(x) + + x = x + x491 + + x = self.conv_501(x) + x = self.bn_502(x) + x = F.relu(x) + x = self.dconv_504(x) + x = self.bn_505(x) + x = F.relu(x) + x = self.conv_507(x) + x = self.bn_508(x) + + x = self.conv_509(x) + x = self.bn_510(x) + x = F.relu(x) + + x = self.conv_512(x) + x = self.bn_513(x) + x = x514 = F.relu(x) + + x = self.conv_515(x474) + x = self.bn_516(x) + x = F.relu(x) + + x = x + x514 + + x = self.conv_519(x) + x = self.bn_520(x) + x = x521 = F.relu(x) + + x = self.conv_522(x413) + x = self.bn_523(x) + x = F.relu(x) + + x = x + x521 + + x = self.conv_526(x) + x = self.bn_527(x) + x = x528 = F.relu(x) + + x = self.conv_529(x387) + x = self.bn_530(x) + x = F.relu(x) + + x = x + x528 + + x = self.conv_533(x) + x = self.bn_534(x) + x = F.relu(x) + + heatmap = torch.sigmoid( self.conv_536(x) ) + scale = self.conv_538(x) + offset = self.conv_539(x) + lms = self.conv_540(x) + + return heatmap, scale, offset, lms + + class CenterFaceLNN (lnn.Module): + def __init__(self): + self.conv_363 = lnn.Conv2D(3, 32, 3, 2, use_bias=False) + self.bn_364 = lnn.BatchNorm2D(32) + + self.dconv_366 = lnn.DepthwiseConv2D(32, 3, use_bias=False) + self.bn_367 = lnn.BatchNorm2D(32) + self.conv_369 = lnn.Conv2D(32, 16, 1, use_bias=False) + self.bn_370 = lnn.BatchNorm2D(16) + + self.conv_371 = lnn.Conv2D(16, 96, 1, use_bias=False) + self.bn_372 = lnn.BatchNorm2D(96) + self.dconv_374 = lnn.DepthwiseConv2D(96, 3, 2, use_bias=False) + self.bn_375 = lnn.BatchNorm2D(96) + self.conv_377 = lnn.Conv2D(96, 24, 1, use_bias=False) + self.bn_378 = lnn.BatchNorm2D(24) + + self.conv_379 = lnn.Conv2D(24, 144, 1, use_bias=False) + self.bn_380 = lnn.BatchNorm2D(144) + self.dconv_382 = lnn.DepthwiseConv2D(144, 3, use_bias=False) + self.bn_383 = lnn.BatchNorm2D(144) + self.conv_385 = lnn.Conv2D(144, 24, 1, use_bias=False) + self.bn_386 = lnn.BatchNorm2D(24) + self.conv_388 = lnn.Conv2D(24, 144, 1, use_bias=False) + self.bn_389 = lnn.BatchNorm2D(144) + self.dconv_391 = lnn.DepthwiseConv2D(144, 3, 2, use_bias=False) + self.bn_392 = lnn.BatchNorm2D(144) + self.conv_394 = lnn.Conv2D(144, 32, 1, use_bias=False) + self.bn_395 = lnn.BatchNorm2D(32) + self.conv_396 = lnn.Conv2D(32, 192, 1, use_bias=False) + self.bn_397 = lnn.BatchNorm2D(192) + self.dconv_399 = lnn.DepthwiseConv2D(192, 3, use_bias=False) + self.bn_400 = lnn.BatchNorm2D(192) + self.conv_402 = lnn.Conv2D(192, 32, 1, use_bias=False) + self.bn_403 = lnn.BatchNorm2D(32) + self.conv_405 = lnn.Conv2D(32, 192, 1, use_bias=False) + self.bn_406 = lnn.BatchNorm2D(192) + self.dconv_408 = lnn.DepthwiseConv2D(192, 3, use_bias=False) + self.bn_409 = lnn.BatchNorm2D(192) + self.conv_411 = lnn.Conv2D(192, 32, 1, use_bias=False) + self.bn_412 = lnn.BatchNorm2D(32) + self.conv_414 = lnn.Conv2D(32, 192, 1, use_bias=False) + self.bn_415 = lnn.BatchNorm2D(192) + self.dconv_417 = lnn.DepthwiseConv2D(192, 3, 2, use_bias=False) + self.bn_418 = lnn.BatchNorm2D(192) + self.conv_420 = lnn.Conv2D(192, 64, 1, use_bias=False) + self.bn_421 = lnn.BatchNorm2D(64) + self.conv_422 = lnn.Conv2D(64, 384, 1, use_bias=False) + self.bn_423 = lnn.BatchNorm2D(384) + self.dconv_425 = lnn.DepthwiseConv2D(384, 3, use_bias=False) + self.bn_426 = lnn.BatchNorm2D(384) + self.conv_428 = lnn.Conv2D(384, 64, 1, use_bias=False) + self.bn_429 = lnn.BatchNorm2D(64) + self.conv_431 = lnn.Conv2D(64, 384, 1, use_bias=False) + self.bn_432 = lnn.BatchNorm2D(384) + self.dconv_434 = lnn.DepthwiseConv2D(384, 3, use_bias=False) + self.bn_435 = lnn.BatchNorm2D(384) + self.conv_437 = lnn.Conv2D(384, 64, 1, use_bias=False) + self.bn_438 = lnn.BatchNorm2D(64) + self.conv_440 = lnn.Conv2D(64, 384, 1, use_bias=False) + self.bn_441 = lnn.BatchNorm2D(384) + self.dconv_443 = lnn.DepthwiseConv2D(384, 3, use_bias=False) + self.bn_444 = lnn.BatchNorm2D(384) + self.conv_446 = lnn.Conv2D(384, 64, 1, use_bias=False) + self.bn_447 = lnn.BatchNorm2D(64) + + self.conv_449 = lnn.Conv2D(64, 384, 1, use_bias=False) + self.bn_450 = lnn.BatchNorm2D(384) + self.dconv_452 = lnn.DepthwiseConv2D(384, 3, use_bias=False) + self.bn_453 = lnn.BatchNorm2D(384) + self.conv_455 = lnn.Conv2D(384, 96, 1, use_bias=False) + self.bn_456 = lnn.BatchNorm2D(96) + + self.conv_457 = lnn.Conv2D(96, 576, 1, use_bias=False) + self.bn_458 = lnn.BatchNorm2D(576) + self.dconv_460 = lnn.DepthwiseConv2D(576, 3, use_bias=False) + self.bn_461 = lnn.BatchNorm2D(576) + self.conv_463 = lnn.Conv2D(576, 96, 1, use_bias=False) + self.bn_464 = lnn.BatchNorm2D(96) + + self.conv_466 = lnn.Conv2D(96, 576, 1, use_bias=False) + self.bn_467 = lnn.BatchNorm2D(576) + self.dconv_469 = lnn.DepthwiseConv2D(576, 3, use_bias=False) + self.bn_470 = lnn.BatchNorm2D(576) + self.conv_472 = lnn.Conv2D(576, 96, 1, use_bias=False) + self.bn_473 = lnn.BatchNorm2D(96) + + self.conv_475 = lnn.Conv2D(96, 576, 1, use_bias=False) + self.bn_476 = lnn.BatchNorm2D(576) + self.dconv_478 = lnn.DepthwiseConv2D(576, 3, 2, use_bias=False) + self.bn_479 = lnn.BatchNorm2D(576) + self.conv_481 = lnn.Conv2D(576, 160, 1, use_bias=False) + self.bn_482 = lnn.BatchNorm2D(160) + + self.conv_483 = lnn.Conv2D(160, 960, 1, use_bias=False) + self.bn_484 = lnn.BatchNorm2D(960) + self.dconv_486 = lnn.DepthwiseConv2D(960, 3, use_bias=False) + self.bn_487 = lnn.BatchNorm2D(960) + self.conv_489 = lnn.Conv2D(960, 160, 1, use_bias=False) + self.bn_490 = lnn.BatchNorm2D(160) + + self.conv_492 = lnn.Conv2D(160, 960, 1, use_bias=False) + self.bn_493 = lnn.BatchNorm2D(960) + self.dconv_495 = lnn.DepthwiseConv2D(960, 3, use_bias=False) + self.bn_496 = lnn.BatchNorm2D(960) + self.conv_498 = lnn.Conv2D(960, 160, 1, use_bias=False) + self.bn_499 = lnn.BatchNorm2D(160) + + self.conv_501 = lnn.Conv2D(160, 960, 1, use_bias=False) + self.bn_502 = lnn.BatchNorm2D(960) + self.dconv_504 = lnn.DepthwiseConv2D(960, 3, use_bias=False) + self.bn_505 = lnn.BatchNorm2D(960) + self.conv_507 = lnn.Conv2D(960, 320, 1, use_bias=False) + self.bn_508 = lnn.BatchNorm2D(320) + + self.conv_509 = lnn.Conv2D(320, 24, 1, use_bias=False) + self.bn_510 = lnn.BatchNorm2D(24) + + self.conv_512 = lnn.Conv2DTranspose(24,24, 2, 2, use_bias=False) + self.bn_513 = lnn.BatchNorm2D(24) + + self.conv_515 = lnn.Conv2D(96, 24, 1, use_bias=False) + self.bn_516 = lnn.BatchNorm2D(24) + + self.conv_519 = lnn.Conv2DTranspose(24,24, 2, 2, use_bias=False) + self.bn_520 = lnn.BatchNorm2D(24) + + self.conv_522 = lnn.Conv2D(32, 24, 1, use_bias=False) + self.bn_523 = lnn.BatchNorm2D(24) + + self.conv_526 = lnn.Conv2DTranspose(24,24, 2, 2, use_bias=False) + self.bn_527 = lnn.BatchNorm2D(24) + + self.conv_529 = lnn.Conv2D(24, 24, 1, use_bias=False) + self.bn_530 = lnn.BatchNorm2D(24) + + self.conv_533 = lnn.Conv2D(24, 24, 3, use_bias=False) + self.bn_534 = lnn.BatchNorm2D(24) + + self.conv_536 = lnn.Conv2D(24, 1, 1) + self.conv_538 = lnn.Conv2D(24, 2, 1) + self.conv_539 = lnn.Conv2D(24, 2, 1) + self.conv_540 = lnn.Conv2D(24, 10, 1) + + def forward(self, x): + x = self.conv_363(x) + x = self.bn_364(x) + x = lnn.relu(x) + + x = self.dconv_366(x) + x = self.bn_367(x) + x = lnn.relu(x) + x = self.conv_369(x) + x = self.bn_370(x) + + x = self.conv_371(x) + x = self.bn_372(x) + x = lnn.relu(x) + x = self.dconv_374(x) + x = self.bn_375(x) + x = lnn.relu(x) + x = self.conv_377(x) + x = x378 = self.bn_378(x) + x = self.conv_379(x) + x = self.bn_380(x) + x = lnn.relu(x) + x = self.dconv_382(x) + x = self.bn_383(x) + x = lnn.relu(x) + x = self.conv_385(x) + x = self.bn_386(x) + x = x387 = x + x378 + x = self.conv_388(x) + x = self.bn_389(x) + x = lnn.relu(x) + x = self.dconv_391(x) + x = self.bn_392(x) + x = lnn.relu(x) + x = self.conv_394(x) + x = x395 = self.bn_395(x) + x = self.conv_396(x) + x = self.bn_397(x) + x = lnn.relu(x) + x = self.dconv_399(x) + x = self.bn_400(x) + x = lnn.relu(x) + x = self.conv_402(x) + x = self.bn_403(x) + x = x404 = x + x395 + x = self.conv_405(x) + x = self.bn_406(x) + x = lnn.relu(x) + x = self.dconv_408(x) + x = self.bn_409(x) + x = lnn.relu(x) + x = self.conv_411(x) + x = self.bn_412(x) + x = x413 = x + x404 + x = self.conv_414(x) + x = self.bn_415(x) + x = lnn.relu(x) + x = self.dconv_417(x) + x = self.bn_418(x) + x = lnn.relu(x) + x = self.conv_420(x) + x = x421 = self.bn_421(x) + x = self.conv_422(x) + x = self.bn_423(x) + x = lnn.relu(x) + x = self.dconv_425(x) + x = self.bn_426(x) + x = lnn.relu(x) + x = self.conv_428(x) + x = self.bn_429(x) + x = x430 = x + x421 + x = self.conv_431(x) + x = self.bn_432(x) + x = lnn.relu(x) + x = self.dconv_434(x) + x = self.bn_435(x) + x = lnn.relu(x) + x = self.conv_437(x) + x = self.bn_438(x) + x = x439 = x + x430 + + x = self.conv_440(x) + x = self.bn_441(x) + x = lnn.relu(x) + x = self.dconv_443(x) + x = self.bn_444(x) + x = lnn.relu(x) + x = self.conv_446(x) + x = self.bn_447(x) + x = x + x439 + + x = self.conv_449(x) + x = self.bn_450(x) + x = lnn.relu(x) + x = self.dconv_452(x) + x = self.bn_453(x) + x = lnn.relu(x) + x = self.conv_455(x) + x = x456 = self.bn_456(x) + + x = self.conv_457(x) + x = self.bn_458(x) + x = lnn.relu(x) + x = self.dconv_460(x) + x = self.bn_461(x) + x = lnn.relu(x) + x = self.conv_463(x) + x = self.bn_464(x) + + x = x465 = x + x456 + + x = self.conv_466(x) + x = self.bn_467(x) + x = lnn.relu(x) + x = self.dconv_469(x) + x = self.bn_470(x) + x = lnn.relu(x) + x = self.conv_472(x) + x = self.bn_473(x) + + x = x474 = x + x465 + + x = self.conv_475(x) + x = self.bn_476(x) + x = lnn.relu(x) + x = self.dconv_478(x) + x = self.bn_479(x) + x = lnn.relu(x) + x = self.conv_481(x) + x = x482 = self.bn_482(x) + + x = self.conv_483(x) + x = self.bn_484(x) + x = lnn.relu(x) + x = self.dconv_486(x) + x = self.bn_487(x) + x = lnn.relu(x) + x = self.conv_489(x) + x = self.bn_490(x) + + x = x491 = x + x482 + + x = self.conv_492(x) + x = self.bn_493(x) + x = lnn.relu(x) + x = self.dconv_495(x) + x = self.bn_496(x) + x = lnn.relu(x) + x = self.conv_498(x) + x = self.bn_499(x) + + x = x + x491 + + x = self.conv_501(x) + x = self.bn_502(x) + x = lnn.relu(x) + x = self.dconv_504(x) + x = self.bn_505(x) + x = lnn.relu(x) + x = self.conv_507(x) + x = self.bn_508(x) + + x = self.conv_509(x) + x = self.bn_510(x) + x = lnn.relu(x) + + x = self.conv_512(x) + x = self.bn_513(x) + x = x514 = lnn.relu(x) + + x = self.conv_515(x474) + x = self.bn_516(x) + x = lnn.relu(x) + + x = x + x514 + + x = self.conv_519(x) + x = self.bn_520(x) + x = x521 = lnn.relu(x) + + x = self.conv_522(x413) + x = self.bn_523(x) + x = lnn.relu(x) + + x = x + x521 + + x = self.conv_526(x) + x = self.bn_527(x) + x = x528 = lnn.relu(x) + + x = self.conv_529(x387) + x = self.bn_530(x) + x = lnn.relu(x) + + x = x + x528 + + x = self.conv_533(x) + x = self.bn_534(x) + x = lnn.relu(x) + + heatmap = lnn.sigmoid( self.conv_536(x) ) + scale = self.conv_538(x) + offset = self.conv_539(x) + lms = self.conv_540(x) + + return heatmap, scale, offset, lms + + c_tnn = CenterFaceTNN() + c_tnn.eval() + + c_tnn.conv_363.weight.data = torch.from_numpy ( g_get_weight('backbone.features.0.0.weight')) + c_tnn.bn_364.weight.data = torch.from_numpy ( g_get_weight('backbone.features.0.1.weight') ) + c_tnn.bn_364.bias.data = torch.from_numpy ( g_get_weight('backbone.features.0.1.bias') ) + c_tnn.bn_364.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.0.1.running_mean') ) + c_tnn.bn_364.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.0.1.running_var') ) + c_tnn.dconv_366.weight.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.0.weight') ) + c_tnn.bn_367.weight.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.1.weight') ) + c_tnn.bn_367.bias.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.1.bias') ) + c_tnn.bn_367.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.1.running_mean') ) + c_tnn.bn_367.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.1.running_var') ) + c_tnn.conv_369.weight.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.3.weight') ) + c_tnn.bn_370.weight.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.4.weight') ) + c_tnn.bn_370.bias.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.4.bias') ) + c_tnn.bn_370.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.4.running_mean') ) + c_tnn.bn_370.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.1.conv.4.running_var') ) + c_tnn.conv_371.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.0.weight') ) + c_tnn.bn_372.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.1.weight') ) + c_tnn.bn_372.bias.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.1.bias') ) + c_tnn.bn_372.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.1.running_mean') ) + c_tnn.bn_372.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.1.running_var') ) + c_tnn.dconv_374.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.3.weight') ) + c_tnn.bn_375.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.4.weight') ) + c_tnn.bn_375.bias.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.4.bias') ) + c_tnn.bn_375.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.4.running_mean') ) + c_tnn.bn_375.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.4.running_var') ) + c_tnn.conv_377.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.6.weight') ) + c_tnn.bn_378.weight.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.7.weight') ) + c_tnn.bn_378.bias.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.7.bias') ) + c_tnn.bn_378.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.7.running_mean') ) + c_tnn.bn_378.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.2.conv.7.running_var') ) + c_tnn.conv_379.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.0.weight') ) + c_tnn.bn_380.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.1.weight') ) + c_tnn.bn_380.bias.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.1.bias') ) + c_tnn.bn_380.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.1.running_mean') ) + c_tnn.bn_380.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.1.running_var') ) + c_tnn.dconv_382.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.3.weight') ) + c_tnn.bn_383.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.4.weight') ) + c_tnn.bn_383.bias.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.4.bias') ) + c_tnn.bn_383.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.4.running_mean') ) + c_tnn.bn_383.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.4.running_var') ) + c_tnn.conv_385.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.6.weight') ) + c_tnn.bn_386.weight.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.7.weight') ) + c_tnn.bn_386.bias.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.7.bias') ) + c_tnn.bn_386.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.7.running_mean') ) + c_tnn.bn_386.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.3.conv.7.running_var') ) + c_tnn.conv_388.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.0.weight') ) + c_tnn.bn_389.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.1.weight') ) + c_tnn.bn_389.bias.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.1.bias') ) + c_tnn.bn_389.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.1.running_mean') ) + c_tnn.bn_389.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.1.running_var') ) + c_tnn.dconv_391.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.3.weight') ) + c_tnn.bn_392.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.4.weight') ) + c_tnn.bn_392.bias.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.4.bias') ) + c_tnn.bn_392.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.4.running_mean') ) + c_tnn.bn_392.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.4.running_var') ) + c_tnn.conv_394.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.6.weight') ) + c_tnn.bn_395.weight.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.7.weight') ) + c_tnn.bn_395.bias.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.7.bias') ) + c_tnn.bn_395.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.7.running_mean') ) + c_tnn.bn_395.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.4.conv.7.running_var') ) + c_tnn.conv_396.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.0.weight') ) + c_tnn.bn_397.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.1.weight') ) + c_tnn.bn_397.bias.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.1.bias') ) + c_tnn.bn_397.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.1.running_mean') ) + c_tnn.bn_397.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.1.running_var') ) + c_tnn.dconv_399.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.3.weight') ) + c_tnn.bn_400.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.4.weight') ) + c_tnn.bn_400.bias.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.4.bias') ) + c_tnn.bn_400.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.4.running_mean') ) + c_tnn.bn_400.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.4.running_var') ) + c_tnn.conv_402.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.6.weight') ) + c_tnn.bn_403.weight.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.7.weight') ) + c_tnn.bn_403.bias.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.7.bias') ) + c_tnn.bn_403.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.7.running_mean') ) + c_tnn.bn_403.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.5.conv.7.running_var') ) + c_tnn.conv_405.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.0.weight') ) + c_tnn.bn_406.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.1.weight') ) + c_tnn.bn_406.bias.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.1.bias') ) + c_tnn.bn_406.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.1.running_mean') ) + c_tnn.bn_406.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.1.running_var') ) + c_tnn.dconv_408.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.3.weight') ) + c_tnn.bn_409.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.4.weight') ) + c_tnn.bn_409.bias.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.4.bias') ) + c_tnn.bn_409.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.4.running_mean') ) + c_tnn.bn_409.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.4.running_var') ) + c_tnn.conv_411.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.6.weight') ) + c_tnn.bn_412.weight.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.7.weight') ) + c_tnn.bn_412.bias.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.7.bias') ) + c_tnn.bn_412.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.7.running_mean') ) + c_tnn.bn_412.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.6.conv.7.running_var') ) + c_tnn.conv_414.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.0.weight') ) + c_tnn.bn_415.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.1.weight') ) + c_tnn.bn_415.bias.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.1.bias') ) + c_tnn.bn_415.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.1.running_mean') ) + c_tnn.bn_415.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.1.running_var') ) + c_tnn.dconv_417.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.3.weight') ) + c_tnn.bn_418.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.4.weight') ) + c_tnn.bn_418.bias.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.4.bias') ) + c_tnn.bn_418.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.4.running_mean') ) + c_tnn.bn_418.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.4.running_var') ) + c_tnn.conv_420.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.6.weight') ) + c_tnn.bn_421.weight.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.7.weight') ) + c_tnn.bn_421.bias.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.7.bias') ) + c_tnn.bn_421.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.7.running_mean') ) + c_tnn.bn_421.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.7.conv.7.running_var') ) + c_tnn.conv_422.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.0.weight') ) + c_tnn.bn_423.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.1.weight') ) + c_tnn.bn_423.bias.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.1.bias') ) + c_tnn.bn_423.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.1.running_mean') ) + c_tnn.bn_423.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.1.running_var') ) + c_tnn.dconv_425.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.3.weight') ) + c_tnn.bn_426.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.4.weight') ) + c_tnn.bn_426.bias.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.4.bias') ) + c_tnn.bn_426.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.4.running_mean') ) + c_tnn.bn_426.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.4.running_var') ) + c_tnn.conv_428.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.6.weight') ) + c_tnn.bn_429.weight.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.7.weight') ) + c_tnn.bn_429.bias.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.7.bias') ) + c_tnn.bn_429.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.7.running_mean') ) + c_tnn.bn_429.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.8.conv.7.running_var') ) + c_tnn.conv_431.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.0.weight') ) + c_tnn.bn_432.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.1.weight') ) + c_tnn.bn_432.bias.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.1.bias') ) + c_tnn.bn_432.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.1.running_mean') ) + c_tnn.bn_432.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.1.running_var') ) + c_tnn.dconv_434.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.3.weight') ) + c_tnn.bn_435.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.4.weight') ) + c_tnn.bn_435.bias.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.4.bias') ) + c_tnn.bn_435.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.4.running_mean') ) + c_tnn.bn_435.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.4.running_var') ) + c_tnn.conv_437.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.6.weight') ) + c_tnn.bn_438.weight.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.7.weight') ) + c_tnn.bn_438.bias.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.7.bias') ) + c_tnn.bn_438.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.7.running_mean') ) + c_tnn.bn_438.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.9.conv.7.running_var') ) + c_tnn.conv_440.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.0.weight') ) + c_tnn.bn_441.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.1.weight') ) + c_tnn.bn_441.bias.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.1.bias') ) + c_tnn.bn_441.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.1.running_mean') ) + c_tnn.bn_441.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.1.running_var') ) + c_tnn.dconv_443.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.3.weight') ) + c_tnn.bn_444.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.4.weight') ) + c_tnn.bn_444.bias.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.4.bias') ) + c_tnn.bn_444.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.4.running_mean') ) + c_tnn.bn_444.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.4.running_var') ) + c_tnn.conv_446.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.6.weight') ) + c_tnn.bn_447.weight.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.7.weight') ) + c_tnn.bn_447.bias.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.7.bias') ) + c_tnn.bn_447.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.7.running_mean') ) + c_tnn.bn_447.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.10.conv.7.running_var') ) + + c_tnn.conv_449.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.0.weight') ) + c_tnn.bn_450.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.1.weight') ) + c_tnn.bn_450.bias.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.1.bias') ) + c_tnn.bn_450.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.1.running_mean') ) + c_tnn.bn_450.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.1.running_var') ) + c_tnn.dconv_452.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.3.weight') ) + c_tnn.bn_453.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.4.weight') ) + c_tnn.bn_453.bias.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.4.bias') ) + c_tnn.bn_453.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.4.running_mean') ) + c_tnn.bn_453.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.4.running_var') ) + c_tnn.conv_455.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.6.weight') ) + c_tnn.bn_456.weight.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.7.weight') ) + c_tnn.bn_456.bias.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.7.bias') ) + c_tnn.bn_456.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.7.running_mean') ) + c_tnn.bn_456.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.11.conv.7.running_var') ) + + c_tnn.conv_457.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.0.weight') ) + c_tnn.bn_458.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.1.weight') ) + c_tnn.bn_458.bias.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.1.bias') ) + c_tnn.bn_458.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.1.running_mean') ) + c_tnn.bn_458.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.1.running_var') ) + c_tnn.dconv_460.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.3.weight') ) + c_tnn.bn_461.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.4.weight') ) + c_tnn.bn_461.bias.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.4.bias') ) + c_tnn.bn_461.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.4.running_mean') ) + c_tnn.bn_461.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.4.running_var') ) + c_tnn.conv_463.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.6.weight') ) + c_tnn.bn_464.weight.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.7.weight') ) + c_tnn.bn_464.bias.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.7.bias') ) + c_tnn.bn_464.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.7.running_mean') ) + c_tnn.bn_464.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.12.conv.7.running_var') ) + + c_tnn.conv_466.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.0.weight') ) + c_tnn.bn_467.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.1.weight') ) + c_tnn.bn_467.bias.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.1.bias') ) + c_tnn.bn_467.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.1.running_mean') ) + c_tnn.bn_467.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.1.running_var') ) + c_tnn.dconv_469.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.3.weight') ) + c_tnn.bn_470.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.4.weight') ) + c_tnn.bn_470.bias.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.4.bias') ) + c_tnn.bn_470.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.4.running_mean') ) + c_tnn.bn_470.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.4.running_var') ) + c_tnn.conv_472.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.6.weight') ) + c_tnn.bn_473.weight.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.7.weight') ) + c_tnn.bn_473.bias.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.7.bias') ) + c_tnn.bn_473.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.7.running_mean') ) + c_tnn.bn_473.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.13.conv.7.running_var') ) + + c_tnn.conv_475.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.0.weight') ) + c_tnn.bn_476.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.1.weight') ) + c_tnn.bn_476.bias.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.1.bias') ) + c_tnn.bn_476.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.1.running_mean') ) + c_tnn.bn_476.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.1.running_var') ) + c_tnn.dconv_478.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.3.weight') ) + c_tnn.bn_479.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.4.weight') ) + c_tnn.bn_479.bias.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.4.bias') ) + c_tnn.bn_479.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.4.running_mean') ) + c_tnn.bn_479.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.4.running_var') ) + c_tnn.conv_481.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.6.weight') ) + c_tnn.bn_482.weight.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.7.weight') ) + c_tnn.bn_482.bias.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.7.bias') ) + c_tnn.bn_482.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.7.running_mean') ) + c_tnn.bn_482.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.14.conv.7.running_var') ) + + c_tnn.conv_483.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.0.weight') ) + c_tnn.bn_484.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.1.weight') ) + c_tnn.bn_484.bias.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.1.bias') ) + c_tnn.bn_484.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.1.running_mean') ) + c_tnn.bn_484.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.1.running_var') ) + c_tnn.dconv_486.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.3.weight') ) + c_tnn.bn_487.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.4.weight') ) + c_tnn.bn_487.bias.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.4.bias') ) + c_tnn.bn_487.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.4.running_mean') ) + c_tnn.bn_487.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.4.running_var') ) + c_tnn.conv_489.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.6.weight') ) + c_tnn.bn_490.weight.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.7.weight') ) + c_tnn.bn_490.bias.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.7.bias') ) + c_tnn.bn_490.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.7.running_mean') ) + c_tnn.bn_490.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.15.conv.7.running_var') ) + + c_tnn.conv_492.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.0.weight') ) + c_tnn.bn_493.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.1.weight') ) + c_tnn.bn_493.bias.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.1.bias') ) + c_tnn.bn_493.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.1.running_mean') ) + c_tnn.bn_493.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.1.running_var') ) + c_tnn.dconv_495.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.3.weight') ) + c_tnn.bn_496.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.4.weight') ) + c_tnn.bn_496.bias.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.4.bias') ) + c_tnn.bn_496.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.4.running_mean') ) + c_tnn.bn_496.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.4.running_var') ) + c_tnn.conv_498.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.6.weight') ) + c_tnn.bn_499.weight.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.7.weight') ) + c_tnn.bn_499.bias.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.7.bias') ) + c_tnn.bn_499.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.7.running_mean') ) + c_tnn.bn_499.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.16.conv.7.running_var') ) + + c_tnn.conv_501.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.0.weight') ) + c_tnn.bn_502.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.1.weight') ) + c_tnn.bn_502.bias.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.1.bias') ) + c_tnn.bn_502.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.1.running_mean') ) + c_tnn.bn_502.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.1.running_var') ) + c_tnn.dconv_504.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.3.weight') ) + c_tnn.bn_505.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.4.weight') ) + c_tnn.bn_505.bias.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.4.bias') ) + c_tnn.bn_505.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.4.running_mean') ) + c_tnn.bn_505.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.4.running_var') ) + c_tnn.conv_507.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.6.weight') ) + c_tnn.bn_508.weight.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.7.weight') ) + c_tnn.bn_508.bias.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.7.bias') ) + c_tnn.bn_508.running_mean.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.7.running_mean') ) + c_tnn.bn_508.running_var.data = torch.from_numpy ( g_get_weight('backbone.features.17.conv.7.running_var') ) + + c_tnn.conv_509.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.3.conv.weight') ) + c_tnn.bn_510.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.3.bn.weight') ) + c_tnn.bn_510.bias.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.3.bn.bias') ) + c_tnn.bn_510.running_mean.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.3.bn.running_mean') ) + c_tnn.bn_510.running_var.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.3.bn.running_var') ) + + c_tnn.conv_512.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.2.conv.weight') ) + c_tnn.bn_513.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.2.bn.weight') ) + c_tnn.bn_513.bias.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.2.bn.bias') ) + c_tnn.bn_513.running_mean.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.2.bn.running_mean') ) + c_tnn.bn_513.running_var.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.2.bn.running_var') ) + + c_tnn.conv_515.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.2.conv.weight') ) + c_tnn.bn_516.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.2.bn.weight') ) + c_tnn.bn_516.bias.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.2.bn.bias') ) + c_tnn.bn_516.running_mean.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.2.bn.running_mean') ) + c_tnn.bn_516.running_var.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.2.bn.running_var') ) + + c_tnn.conv_519.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.1.conv.weight') ) + c_tnn.bn_520.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.1.bn.weight') ) + c_tnn.bn_520.bias.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.1.bn.bias') ) + c_tnn.bn_520.running_mean.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.1.bn.running_mean') ) + c_tnn.bn_520.running_var.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.1.bn.running_var') ) + + c_tnn.conv_522.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.1.conv.weight') ) + c_tnn.bn_523.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.1.bn.weight') ) + c_tnn.bn_523.bias.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.1.bn.bias') ) + c_tnn.bn_523.running_mean.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.1.bn.running_mean') ) + c_tnn.bn_523.running_var.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.1.bn.running_var') ) + + c_tnn.conv_526.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.0.conv.weight') ) + c_tnn.bn_527.weight.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.0.bn.weight') ) + c_tnn.bn_527.bias.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.0.bn.bias') ) + c_tnn.bn_527.running_mean.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.0.bn.running_mean') ) + c_tnn.bn_527.running_var.data = torch.from_numpy ( g_get_weight('neck.deconv_layers.0.bn.running_var') ) + + c_tnn.conv_529.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.0.conv.weight') ) + c_tnn.bn_530.weight.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.0.bn.weight') ) + c_tnn.bn_530.bias.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.0.bn.bias') ) + c_tnn.bn_530.running_mean.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.0.bn.running_mean') ) + c_tnn.bn_530.running_var.data = torch.from_numpy ( g_get_weight('neck.lateral_layers.0.bn.running_var') ) + + c_tnn.conv_533.weight.data = torch.from_numpy ( g_get_weight('head.feat.weight') ) + c_tnn.bn_534.weight.data = torch.from_numpy ( g_get_weight('head.feat_bn.weight') ) + c_tnn.bn_534.bias.data = torch.from_numpy ( g_get_weight('head.feat_bn.bias') ) + c_tnn.bn_534.running_mean.data = torch.from_numpy ( g_get_weight('head.feat_bn.running_mean') ) + c_tnn.bn_534.running_var.data = torch.from_numpy ( g_get_weight('head.feat_bn.running_var') ) + + c_tnn.conv_536.weight.data = torch.from_numpy ( g_get_weight('head.pos_conv.weight') ) + c_tnn.conv_536.bias.data = torch.from_numpy ( g_get_weight('head.pos_conv.bias') ) + c_tnn.conv_538.weight.data = torch.from_numpy ( g_get_weight('head.reg_conv.weight') ) + c_tnn.conv_538.bias.data = torch.from_numpy ( g_get_weight('head.reg_conv.bias') ) + c_tnn.conv_539.weight.data = torch.from_numpy ( g_get_weight('head.off_conv.weight') ) + c_tnn.conv_539.bias.data = torch.from_numpy ( g_get_weight('head.off_conv.bias') ) + c_tnn.conv_540.weight.data = torch.from_numpy ( g_get_weight('head.lm_conv.weight') ) + c_tnn.conv_540.bias.data = torch.from_numpy ( g_get_weight('head.lm_conv.bias') ) + + torch.save(c_tnn.state_dict(), r'D:\DevelopPython\test\CenterFace.pth') + return + + c_tnn.eval() + c_tnn.to('cuda:0') + + class CenterFaceTorch(object): + def __init__(self, net): + self.net = net + self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = 0, 0, 0, 0 + + def __call__(self, img, threshold=0.5): + h,w,c = img.shape + self.img_h_new, self.img_w_new, self.scale_h, self.scale_w = self.transform(h, w) + return self.inference_opencv(img, threshold) + + def inference_opencv(self, img, threshold): + #blob = cv2.dnn.blobFromImage(img, scalefactor=1.0, size=(self.img_w_new, self.img_h_new), mean=(0, 0, 0), swapRB=True, crop=False) + + img = cv2.resize(img, (self.img_w_new, self.img_h_new) ) + img = img[...,::-1].transpose(2,0,1).reshape(-1).reshape( (1,3,self.img_h_new,self.img_w_new) ).astype(np.float32) + + inp_t = torch.from_numpy(img) + inp_t = inp_t.to('cuda:0') + + heatmap, scale, offset, lms = self.net(inp_t) + + heatmap, scale, offset, lms = [ x.detach().cpu().numpy() for x in (heatmap, scale, offset, lms) ] + return self.postprocess(heatmap, lms, offset, scale, threshold) + + def transform(self, h, w): + img_h_new, img_w_new = int(np.ceil(h / 32) * 32), int(np.ceil(w / 32) * 32) + scale_h, scale_w = img_h_new / h, img_w_new / w + return img_h_new, img_w_new, scale_h, scale_w + + def postprocess(self, heatmap, lms, offset, scale, threshold): + dets, lms = self.decode(heatmap, scale, offset, lms, (self.img_h_new, self.img_w_new), threshold=threshold) + + if len(dets) > 0: + dets[:, 0:4:2], dets[:, 1:4:2] = dets[:, 0:4:2] / self.scale_w, dets[:, 1:4:2] / self.scale_h + lms[:, 0:10:2], lms[:, 1:10:2] = lms[:, 0:10:2] / self.scale_w, lms[:, 1:10:2] / self.scale_h + else: + dets = np.empty(shape=[0, 5], dtype=np.float32) + lms = np.empty(shape=[0, 10], dtype=np.float32) + return dets, lms + + def decode(self, heatmap, scale, offset, landmark, size, threshold=0.1): + heatmap = np.squeeze(heatmap) + scale0, scale1 = scale[0, 0, :, :], scale[0, 1, :, :] + offset0, offset1 = offset[0, 0, :, :], offset[0, 1, :, :] + c0, c1 = np.where(heatmap > threshold) + + boxes, lms = [], [] + if len(c0) > 0: + for i in range(len(c0)): + s0, s1 = np.exp(scale0[c0[i], c1[i]]) * 4, np.exp(scale1[c0[i], c1[i]]) * 4 + o0, o1 = offset0[c0[i], c1[i]], offset1[c0[i], c1[i]] + s = heatmap[c0[i], c1[i]] + x1, y1 = max(0, (c1[i] + o1 + 0.5) * 4 - s1 / 2), max(0, (c0[i] + o0 + 0.5) * 4 - s0 / 2) + x1, y1 = min(x1, size[1]), min(y1, size[0]) + boxes.append([x1, y1, min(x1 + s1, size[1]), min(y1 + s0, size[0]), s]) + + lm = [] + for j in range(5): + lm.append(landmark[0, j * 2 + 1, c0[i], c1[i]] * s1 + x1) + lm.append(landmark[0, j * 2, c0[i], c1[i]] * s0 + y1) + lms.append(lm) + boxes = np.asarray(boxes, dtype=np.float32) + keep = self.nms(boxes[:, :4], boxes[:, 4], 0.3) + boxes = boxes[keep, :] + + lms = np.asarray(lms, dtype=np.float32) + lms = lms[keep, :] + return boxes, lms + + def nms(self, boxes, scores, nms_thresh): + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = np.argsort(scores)[::-1] + num_detections = boxes.shape[0] + suppressed = np.zeros((num_detections,), dtype=np.bool) + + keep = [] + for _i in range(num_detections): + i = order[_i] + if suppressed[i]: + continue + keep.append(i) + + ix1 = x1[i] + iy1 = y1[i] + ix2 = x2[i] + iy2 = y2[i] + iarea = areas[i] + + for _j in range(_i + 1, num_detections): + j = order[_j] + if suppressed[j]: + continue + + xx1 = max(ix1, x1[j]) + yy1 = max(iy1, y1[j]) + xx2 = min(ix2, x2[j]) + yy2 = min(iy2, y2[j]) + w = max(0, xx2 - xx1 + 1) + h = max(0, yy2 - yy1 + 1) + + inter = w * h + ovr = inter / (iarea + areas[j] - inter) + if ovr >= nms_thresh: + suppressed[j] = True + + return keep + + cface_torch = CenterFaceTorch(c_tnn) + + + c_lnn = CenterFaceLNN() + c_lnn.conv_363.kernel.set ( g_get_weight('backbone.features.0.0.weight') ) + c_lnn.bn_364.gamma.set ( g_get_weight('backbone.features.0.1.weight') ) + c_lnn.bn_364.beta.set ( g_get_weight('backbone.features.0.1.bias') ) + c_lnn.bn_364.running_mean.set ( g_get_weight('backbone.features.0.1.running_mean') ) + c_lnn.bn_364.running_var.set ( g_get_weight('backbone.features.0.1.running_var') ) + c_lnn.dconv_366.kernel.set ( g_get_weight('backbone.features.1.conv.0.weight') ) + c_lnn.bn_367.gamma.set ( g_get_weight('backbone.features.1.conv.1.weight') ) + c_lnn.bn_367.beta.set ( g_get_weight('backbone.features.1.conv.1.bias') ) + c_lnn.bn_367.running_mean.set ( g_get_weight('backbone.features.1.conv.1.running_mean') ) + c_lnn.bn_367.running_var.set ( g_get_weight('backbone.features.1.conv.1.running_var') ) + c_lnn.conv_369.kernel.set ( g_get_weight('backbone.features.1.conv.3.weight') ) + c_lnn.bn_370.gamma.set ( g_get_weight('backbone.features.1.conv.4.weight') ) + c_lnn.bn_370.beta.set ( g_get_weight('backbone.features.1.conv.4.bias') ) + c_lnn.bn_370.running_mean.set ( g_get_weight('backbone.features.1.conv.4.running_mean') ) + c_lnn.bn_370.running_var.set ( g_get_weight('backbone.features.1.conv.4.running_var') ) + c_lnn.conv_371.kernel.set ( g_get_weight('backbone.features.2.conv.0.weight') ) + c_lnn.bn_372.gamma.set ( g_get_weight('backbone.features.2.conv.1.weight') ) + c_lnn.bn_372.beta.set ( g_get_weight('backbone.features.2.conv.1.bias') ) + c_lnn.bn_372.running_mean.set ( g_get_weight('backbone.features.2.conv.1.running_mean') ) + c_lnn.bn_372.running_var.set ( g_get_weight('backbone.features.2.conv.1.running_var') ) + c_lnn.dconv_374.kernel.set ( g_get_weight('backbone.features.2.conv.3.weight') ) + c_lnn.bn_375.gamma.set ( g_get_weight('backbone.features.2.conv.4.weight') ) + c_lnn.bn_375.beta.set ( g_get_weight('backbone.features.2.conv.4.bias') ) + c_lnn.bn_375.running_mean.set ( g_get_weight('backbone.features.2.conv.4.running_mean') ) + c_lnn.bn_375.running_var.set ( g_get_weight('backbone.features.2.conv.4.running_var') ) + c_lnn.conv_377.kernel.set( g_get_weight('backbone.features.2.conv.6.weight') ) + c_lnn.bn_378.gamma.set ( g_get_weight('backbone.features.2.conv.7.weight') ) + c_lnn.bn_378.beta.set ( g_get_weight('backbone.features.2.conv.7.bias') ) + c_lnn.bn_378.running_mean.set ( g_get_weight('backbone.features.2.conv.7.running_mean') ) + c_lnn.bn_378.running_var.set ( g_get_weight('backbone.features.2.conv.7.running_var') ) + c_lnn.conv_379.kernel.set ( g_get_weight('backbone.features.3.conv.0.weight') ) + c_lnn.bn_380.gamma.set ( g_get_weight('backbone.features.3.conv.1.weight') ) + c_lnn.bn_380.beta.set ( g_get_weight('backbone.features.3.conv.1.bias') ) + c_lnn.bn_380.running_mean.set ( g_get_weight('backbone.features.3.conv.1.running_mean') ) + c_lnn.bn_380.running_var.set ( g_get_weight('backbone.features.3.conv.1.running_var') ) + c_lnn.dconv_382.kernel.set ( g_get_weight('backbone.features.3.conv.3.weight') ) + c_lnn.bn_383.gamma.set ( g_get_weight('backbone.features.3.conv.4.weight') ) + c_lnn.bn_383.beta.set ( g_get_weight('backbone.features.3.conv.4.bias') ) + c_lnn.bn_383.running_mean.set ( g_get_weight('backbone.features.3.conv.4.running_mean') ) + c_lnn.bn_383.running_var.set ( g_get_weight('backbone.features.3.conv.4.running_var') ) + c_lnn.conv_385.kernel.set ( g_get_weight('backbone.features.3.conv.6.weight') ) + c_lnn.bn_386.gamma.set ( g_get_weight('backbone.features.3.conv.7.weight') ) + c_lnn.bn_386.beta.set ( g_get_weight('backbone.features.3.conv.7.bias') ) + c_lnn.bn_386.running_mean.set ( g_get_weight('backbone.features.3.conv.7.running_mean') ) + c_lnn.bn_386.running_var.set ( g_get_weight('backbone.features.3.conv.7.running_var') ) + c_lnn.conv_388.kernel.set ( g_get_weight('backbone.features.4.conv.0.weight') ) + c_lnn.bn_389.gamma.set ( g_get_weight('backbone.features.4.conv.1.weight') ) + c_lnn.bn_389.beta.set ( g_get_weight('backbone.features.4.conv.1.bias') ) + c_lnn.bn_389.running_mean.set ( g_get_weight('backbone.features.4.conv.1.running_mean') ) + c_lnn.bn_389.running_var.set ( g_get_weight('backbone.features.4.conv.1.running_var') ) + c_lnn.dconv_391.kernel.set ( g_get_weight('backbone.features.4.conv.3.weight') ) + c_lnn.bn_392.gamma.set ( g_get_weight('backbone.features.4.conv.4.weight') ) + c_lnn.bn_392.beta.set ( g_get_weight('backbone.features.4.conv.4.bias') ) + c_lnn.bn_392.running_mean.set ( g_get_weight('backbone.features.4.conv.4.running_mean') ) + c_lnn.bn_392.running_var.set ( g_get_weight('backbone.features.4.conv.4.running_var') ) + c_lnn.conv_394.kernel.set ( g_get_weight('backbone.features.4.conv.6.weight') ) + c_lnn.bn_395.gamma.set ( g_get_weight('backbone.features.4.conv.7.weight') ) + c_lnn.bn_395.beta.set ( g_get_weight('backbone.features.4.conv.7.bias') ) + c_lnn.bn_395.running_mean.set ( g_get_weight('backbone.features.4.conv.7.running_mean') ) + c_lnn.bn_395.running_var.set ( g_get_weight('backbone.features.4.conv.7.running_var') ) + c_lnn.conv_396.kernel.set ( g_get_weight('backbone.features.5.conv.0.weight') ) + c_lnn.bn_397.gamma.set ( g_get_weight('backbone.features.5.conv.1.weight') ) + c_lnn.bn_397.beta.set ( g_get_weight('backbone.features.5.conv.1.bias') ) + c_lnn.bn_397.running_mean.set ( g_get_weight('backbone.features.5.conv.1.running_mean') ) + c_lnn.bn_397.running_var.set ( g_get_weight('backbone.features.5.conv.1.running_var') ) + c_lnn.dconv_399.kernel.set ( g_get_weight('backbone.features.5.conv.3.weight') ) + c_lnn.bn_400.gamma.set ( g_get_weight('backbone.features.5.conv.4.weight') ) + c_lnn.bn_400.beta.set ( g_get_weight('backbone.features.5.conv.4.bias') ) + c_lnn.bn_400.running_mean.set ( g_get_weight('backbone.features.5.conv.4.running_mean') ) + c_lnn.bn_400.running_var.set ( g_get_weight('backbone.features.5.conv.4.running_var') ) + c_lnn.conv_402.kernel.set ( g_get_weight('backbone.features.5.conv.6.weight') ) + c_lnn.bn_403.gamma.set ( g_get_weight('backbone.features.5.conv.7.weight') ) + c_lnn.bn_403.beta.set ( g_get_weight('backbone.features.5.conv.7.bias') ) + c_lnn.bn_403.running_mean.set ( g_get_weight('backbone.features.5.conv.7.running_mean') ) + c_lnn.bn_403.running_var.set ( g_get_weight('backbone.features.5.conv.7.running_var') ) + c_lnn.conv_405.kernel.set ( g_get_weight('backbone.features.6.conv.0.weight') ) + c_lnn.bn_406.gamma.set ( g_get_weight('backbone.features.6.conv.1.weight') ) + c_lnn.bn_406.beta.set ( g_get_weight('backbone.features.6.conv.1.bias') ) + c_lnn.bn_406.running_mean.set ( g_get_weight('backbone.features.6.conv.1.running_mean') ) + c_lnn.bn_406.running_var.set ( g_get_weight('backbone.features.6.conv.1.running_var') ) + c_lnn.dconv_408.kernel.set ( g_get_weight('backbone.features.6.conv.3.weight') ) + c_lnn.bn_409.gamma.set ( g_get_weight('backbone.features.6.conv.4.weight') ) + c_lnn.bn_409.beta.set ( g_get_weight('backbone.features.6.conv.4.bias') ) + c_lnn.bn_409.running_mean.set ( g_get_weight('backbone.features.6.conv.4.running_mean') ) + c_lnn.bn_409.running_var.set ( g_get_weight('backbone.features.6.conv.4.running_var') ) + c_lnn.conv_411.kernel.set ( g_get_weight('backbone.features.6.conv.6.weight') ) + c_lnn.bn_412.gamma.set ( g_get_weight('backbone.features.6.conv.7.weight') ) + c_lnn.bn_412.beta.set ( g_get_weight('backbone.features.6.conv.7.bias') ) + c_lnn.bn_412.running_mean.set ( g_get_weight('backbone.features.6.conv.7.running_mean') ) + c_lnn.bn_412.running_var.set ( g_get_weight('backbone.features.6.conv.7.running_var') ) + c_lnn.conv_414.kernel.set ( g_get_weight('backbone.features.7.conv.0.weight') ) + c_lnn.bn_415.gamma.set ( g_get_weight('backbone.features.7.conv.1.weight') ) + c_lnn.bn_415.beta.set ( g_get_weight('backbone.features.7.conv.1.bias') ) + c_lnn.bn_415.running_mean.set ( g_get_weight('backbone.features.7.conv.1.running_mean') ) + c_lnn.bn_415.running_var.set ( g_get_weight('backbone.features.7.conv.1.running_var') ) + c_lnn.dconv_417.kernel.set ( g_get_weight('backbone.features.7.conv.3.weight') ) + c_lnn.bn_418.gamma.set ( g_get_weight('backbone.features.7.conv.4.weight') ) + c_lnn.bn_418.beta.set ( g_get_weight('backbone.features.7.conv.4.bias') ) + c_lnn.bn_418.running_mean.set ( g_get_weight('backbone.features.7.conv.4.running_mean') ) + c_lnn.bn_418.running_var.set ( g_get_weight('backbone.features.7.conv.4.running_var') ) + c_lnn.conv_420.kernel.set ( g_get_weight('backbone.features.7.conv.6.weight') ) + c_lnn.bn_421.gamma.set ( g_get_weight('backbone.features.7.conv.7.weight') ) + c_lnn.bn_421.beta.set ( g_get_weight('backbone.features.7.conv.7.bias') ) + c_lnn.bn_421.running_mean.set ( g_get_weight('backbone.features.7.conv.7.running_mean') ) + c_lnn.bn_421.running_var.set ( g_get_weight('backbone.features.7.conv.7.running_var') ) + c_lnn.conv_422.kernel.set ( g_get_weight('backbone.features.8.conv.0.weight') ) + c_lnn.bn_423.gamma.set ( g_get_weight('backbone.features.8.conv.1.weight') ) + c_lnn.bn_423.beta.set ( g_get_weight('backbone.features.8.conv.1.bias') ) + c_lnn.bn_423.running_mean.set ( g_get_weight('backbone.features.8.conv.1.running_mean') ) + c_lnn.bn_423.running_var.set ( g_get_weight('backbone.features.8.conv.1.running_var') ) + c_lnn.dconv_425.kernel.set ( g_get_weight('backbone.features.8.conv.3.weight') ) + c_lnn.bn_426.gamma.set ( g_get_weight('backbone.features.8.conv.4.weight') ) + c_lnn.bn_426.beta.set ( g_get_weight('backbone.features.8.conv.4.bias') ) + c_lnn.bn_426.running_mean.set ( g_get_weight('backbone.features.8.conv.4.running_mean') ) + c_lnn.bn_426.running_var.set ( g_get_weight('backbone.features.8.conv.4.running_var') ) + c_lnn.conv_428.kernel.set ( g_get_weight('backbone.features.8.conv.6.weight') ) + c_lnn.bn_429.gamma.set ( g_get_weight('backbone.features.8.conv.7.weight') ) + c_lnn.bn_429.beta.set ( g_get_weight('backbone.features.8.conv.7.bias') ) + c_lnn.bn_429.running_mean.set ( g_get_weight('backbone.features.8.conv.7.running_mean') ) + c_lnn.bn_429.running_var.set ( g_get_weight('backbone.features.8.conv.7.running_var') ) + c_lnn.conv_431.kernel.set ( g_get_weight('backbone.features.9.conv.0.weight') ) + c_lnn.bn_432.gamma.set ( g_get_weight('backbone.features.9.conv.1.weight') ) + c_lnn.bn_432.beta.set ( g_get_weight('backbone.features.9.conv.1.bias') ) + c_lnn.bn_432.running_mean.set ( g_get_weight('backbone.features.9.conv.1.running_mean') ) + c_lnn.bn_432.running_var.set ( g_get_weight('backbone.features.9.conv.1.running_var') ) + c_lnn.dconv_434.kernel.set ( g_get_weight('backbone.features.9.conv.3.weight') ) + c_lnn.bn_435.gamma.set ( g_get_weight('backbone.features.9.conv.4.weight') ) + c_lnn.bn_435.beta.set ( g_get_weight('backbone.features.9.conv.4.bias') ) + c_lnn.bn_435.running_mean.set ( g_get_weight('backbone.features.9.conv.4.running_mean') ) + c_lnn.bn_435.running_var.set ( g_get_weight('backbone.features.9.conv.4.running_var') ) + c_lnn.conv_437.kernel.set ( g_get_weight('backbone.features.9.conv.6.weight') ) + c_lnn.bn_438.gamma.set ( g_get_weight('backbone.features.9.conv.7.weight') ) + c_lnn.bn_438.beta.set ( g_get_weight('backbone.features.9.conv.7.bias') ) + c_lnn.bn_438.running_mean.set ( g_get_weight('backbone.features.9.conv.7.running_mean') ) + c_lnn.bn_438.running_var.set ( g_get_weight('backbone.features.9.conv.7.running_var') ) + c_lnn.conv_440.kernel.set ( g_get_weight('backbone.features.10.conv.0.weight') ) + c_lnn.bn_441.gamma.set ( g_get_weight('backbone.features.10.conv.1.weight') ) + c_lnn.bn_441.beta.set ( g_get_weight('backbone.features.10.conv.1.bias') ) + c_lnn.bn_441.running_mean.set ( g_get_weight('backbone.features.10.conv.1.running_mean') ) + c_lnn.bn_441.running_var.set ( g_get_weight('backbone.features.10.conv.1.running_var') ) + c_lnn.dconv_443.kernel.set ( g_get_weight('backbone.features.10.conv.3.weight') ) + c_lnn.bn_444.gamma.set ( g_get_weight('backbone.features.10.conv.4.weight') ) + c_lnn.bn_444.beta.set ( g_get_weight('backbone.features.10.conv.4.bias') ) + c_lnn.bn_444.running_mean.set ( g_get_weight('backbone.features.10.conv.4.running_mean') ) + c_lnn.bn_444.running_var.set ( g_get_weight('backbone.features.10.conv.4.running_var') ) + c_lnn.conv_446.kernel.set ( g_get_weight('backbone.features.10.conv.6.weight') ) + c_lnn.bn_447.gamma.set ( g_get_weight('backbone.features.10.conv.7.weight') ) + c_lnn.bn_447.beta.set ( g_get_weight('backbone.features.10.conv.7.bias') ) + c_lnn.bn_447.running_mean.set ( g_get_weight('backbone.features.10.conv.7.running_mean') ) + c_lnn.bn_447.running_var.set ( g_get_weight('backbone.features.10.conv.7.running_var') ) + + c_lnn.conv_449.kernel.set ( g_get_weight('backbone.features.11.conv.0.weight') ) + c_lnn.bn_450.gamma.set ( g_get_weight('backbone.features.11.conv.1.weight') ) + c_lnn.bn_450.beta.set ( g_get_weight('backbone.features.11.conv.1.bias') ) + c_lnn.bn_450.running_mean.set ( g_get_weight('backbone.features.11.conv.1.running_mean') ) + c_lnn.bn_450.running_var.set ( g_get_weight('backbone.features.11.conv.1.running_var') ) + c_lnn.dconv_452.kernel.set ( g_get_weight('backbone.features.11.conv.3.weight') ) + c_lnn.bn_453.gamma.set ( g_get_weight('backbone.features.11.conv.4.weight') ) + c_lnn.bn_453.beta.set ( g_get_weight('backbone.features.11.conv.4.bias') ) + c_lnn.bn_453.running_mean.set ( g_get_weight('backbone.features.11.conv.4.running_mean') ) + c_lnn.bn_453.running_var.set ( g_get_weight('backbone.features.11.conv.4.running_var') ) + c_lnn.conv_455.kernel.set ( g_get_weight('backbone.features.11.conv.6.weight') ) + c_lnn.bn_456.gamma.set ( g_get_weight('backbone.features.11.conv.7.weight') ) + c_lnn.bn_456.beta.set ( g_get_weight('backbone.features.11.conv.7.bias') ) + c_lnn.bn_456.running_mean.set ( g_get_weight('backbone.features.11.conv.7.running_mean') ) + c_lnn.bn_456.running_var.set ( g_get_weight('backbone.features.11.conv.7.running_var') ) + + c_lnn.conv_457.kernel.set ( g_get_weight('backbone.features.12.conv.0.weight') ) + c_lnn.bn_458.gamma.set ( g_get_weight('backbone.features.12.conv.1.weight') ) + c_lnn.bn_458.beta.set ( g_get_weight('backbone.features.12.conv.1.bias') ) + c_lnn.bn_458.running_mean.set ( g_get_weight('backbone.features.12.conv.1.running_mean') ) + c_lnn.bn_458.running_var.set ( g_get_weight('backbone.features.12.conv.1.running_var') ) + c_lnn.dconv_460.kernel.set ( g_get_weight('backbone.features.12.conv.3.weight') ) + c_lnn.bn_461.gamma.set ( g_get_weight('backbone.features.12.conv.4.weight') ) + c_lnn.bn_461.beta.set ( g_get_weight('backbone.features.12.conv.4.bias') ) + c_lnn.bn_461.running_mean.set ( g_get_weight('backbone.features.12.conv.4.running_mean') ) + c_lnn.bn_461.running_var.set ( g_get_weight('backbone.features.12.conv.4.running_var') ) + c_lnn.conv_463.kernel.set ( g_get_weight('backbone.features.12.conv.6.weight') ) + c_lnn.bn_464.gamma.set ( g_get_weight('backbone.features.12.conv.7.weight') ) + c_lnn.bn_464.beta.set ( g_get_weight('backbone.features.12.conv.7.bias') ) + c_lnn.bn_464.running_mean.set ( g_get_weight('backbone.features.12.conv.7.running_mean') ) + c_lnn.bn_464.running_var.set ( g_get_weight('backbone.features.12.conv.7.running_var') ) + + c_lnn.conv_466.kernel.set ( g_get_weight('backbone.features.13.conv.0.weight') ) + c_lnn.bn_467.gamma.set ( g_get_weight('backbone.features.13.conv.1.weight') ) + c_lnn.bn_467.beta.set ( g_get_weight('backbone.features.13.conv.1.bias') ) + c_lnn.bn_467.running_mean.set ( g_get_weight('backbone.features.13.conv.1.running_mean') ) + c_lnn.bn_467.running_var.set ( g_get_weight('backbone.features.13.conv.1.running_var') ) + c_lnn.dconv_469.kernel.set ( g_get_weight('backbone.features.13.conv.3.weight') ) + c_lnn.bn_470.gamma.set ( g_get_weight('backbone.features.13.conv.4.weight') ) + c_lnn.bn_470.beta.set ( g_get_weight('backbone.features.13.conv.4.bias') ) + c_lnn.bn_470.running_mean.set ( g_get_weight('backbone.features.13.conv.4.running_mean') ) + c_lnn.bn_470.running_var.set ( g_get_weight('backbone.features.13.conv.4.running_var') ) + c_lnn.conv_472.kernel.set ( g_get_weight('backbone.features.13.conv.6.weight') ) + c_lnn.bn_473.gamma.set ( g_get_weight('backbone.features.13.conv.7.weight') ) + c_lnn.bn_473.beta.set ( g_get_weight('backbone.features.13.conv.7.bias') ) + c_lnn.bn_473.running_mean.set ( g_get_weight('backbone.features.13.conv.7.running_mean') ) + c_lnn.bn_473.running_var.set ( g_get_weight('backbone.features.13.conv.7.running_var') ) + + c_lnn.conv_475.kernel.set ( g_get_weight('backbone.features.14.conv.0.weight') ) + c_lnn.bn_476.gamma.set ( g_get_weight('backbone.features.14.conv.1.weight') ) + c_lnn.bn_476.beta.set ( g_get_weight('backbone.features.14.conv.1.bias') ) + c_lnn.bn_476.running_mean.set ( g_get_weight('backbone.features.14.conv.1.running_mean') ) + c_lnn.bn_476.running_var.set ( g_get_weight('backbone.features.14.conv.1.running_var') ) + c_lnn.dconv_478.kernel.set ( g_get_weight('backbone.features.14.conv.3.weight') ) + c_lnn.bn_479.gamma.set ( g_get_weight('backbone.features.14.conv.4.weight') ) + c_lnn.bn_479.beta.set ( g_get_weight('backbone.features.14.conv.4.bias') ) + c_lnn.bn_479.running_mean.set ( g_get_weight('backbone.features.14.conv.4.running_mean') ) + c_lnn.bn_479.running_var.set ( g_get_weight('backbone.features.14.conv.4.running_var') ) + c_lnn.conv_481.kernel.set ( g_get_weight('backbone.features.14.conv.6.weight') ) + c_lnn.bn_482.gamma.set ( g_get_weight('backbone.features.14.conv.7.weight') ) + c_lnn.bn_482.beta.set ( g_get_weight('backbone.features.14.conv.7.bias') ) + c_lnn.bn_482.running_mean.set ( g_get_weight('backbone.features.14.conv.7.running_mean') ) + c_lnn.bn_482.running_var.set ( g_get_weight('backbone.features.14.conv.7.running_var') ) + + + c_lnn.conv_483.kernel.set ( g_get_weight('backbone.features.15.conv.0.weight') ) + c_lnn.bn_484.gamma.set ( g_get_weight('backbone.features.15.conv.1.weight') ) + c_lnn.bn_484.beta.set ( g_get_weight('backbone.features.15.conv.1.bias') ) + c_lnn.bn_484.running_mean.set ( g_get_weight('backbone.features.15.conv.1.running_mean') ) + c_lnn.bn_484.running_var.set ( g_get_weight('backbone.features.15.conv.1.running_var') ) + c_lnn.dconv_486.kernel.set ( g_get_weight('backbone.features.15.conv.3.weight') ) + c_lnn.bn_487.gamma.set ( g_get_weight('backbone.features.15.conv.4.weight') ) + c_lnn.bn_487.beta.set ( g_get_weight('backbone.features.15.conv.4.bias') ) + c_lnn.bn_487.running_mean.set ( g_get_weight('backbone.features.15.conv.4.running_mean') ) + c_lnn.bn_487.running_var.set ( g_get_weight('backbone.features.15.conv.4.running_var') ) + c_lnn.conv_489.kernel.set ( g_get_weight('backbone.features.15.conv.6.weight') ) + c_lnn.bn_490.gamma.set ( g_get_weight('backbone.features.15.conv.7.weight') ) + c_lnn.bn_490.beta.set ( g_get_weight('backbone.features.15.conv.7.bias') ) + c_lnn.bn_490.running_mean.set ( g_get_weight('backbone.features.15.conv.7.running_mean') ) + c_lnn.bn_490.running_var.set ( g_get_weight('backbone.features.15.conv.7.running_var') ) + + c_lnn.conv_492.kernel.set ( g_get_weight('backbone.features.16.conv.0.weight') ) + c_lnn.bn_493.gamma.set ( g_get_weight('backbone.features.16.conv.1.weight') ) + c_lnn.bn_493.beta.set ( g_get_weight('backbone.features.16.conv.1.bias') ) + c_lnn.bn_493.running_mean.set ( g_get_weight('backbone.features.16.conv.1.running_mean') ) + c_lnn.bn_493.running_var.set ( g_get_weight('backbone.features.16.conv.1.running_var') ) + c_lnn.dconv_495.kernel.set ( g_get_weight('backbone.features.16.conv.3.weight') ) + c_lnn.bn_496.gamma.set ( g_get_weight('backbone.features.16.conv.4.weight') ) + c_lnn.bn_496.beta.set ( g_get_weight('backbone.features.16.conv.4.bias') ) + c_lnn.bn_496.running_mean.set ( g_get_weight('backbone.features.16.conv.4.running_mean') ) + c_lnn.bn_496.running_var.set ( g_get_weight('backbone.features.16.conv.4.running_var') ) + c_lnn.conv_498.kernel.set ( g_get_weight('backbone.features.16.conv.6.weight') ) + c_lnn.bn_499.gamma.set ( g_get_weight('backbone.features.16.conv.7.weight') ) + c_lnn.bn_499.beta.set ( g_get_weight('backbone.features.16.conv.7.bias') ) + c_lnn.bn_499.running_mean.set ( g_get_weight('backbone.features.16.conv.7.running_mean') ) + c_lnn.bn_499.running_var.set ( g_get_weight('backbone.features.16.conv.7.running_var') ) + + c_lnn.conv_501.kernel.set ( g_get_weight('backbone.features.17.conv.0.weight') ) + c_lnn.bn_502.gamma.set ( g_get_weight('backbone.features.17.conv.1.weight') ) + c_lnn.bn_502.beta.set ( g_get_weight('backbone.features.17.conv.1.bias') ) + c_lnn.bn_502.running_mean.set ( g_get_weight('backbone.features.17.conv.1.running_mean') ) + c_lnn.bn_502.running_var.set ( g_get_weight('backbone.features.17.conv.1.running_var') ) + c_lnn.dconv_504.kernel.set ( g_get_weight('backbone.features.17.conv.3.weight') ) + c_lnn.bn_505.gamma.set ( g_get_weight('backbone.features.17.conv.4.weight') ) + c_lnn.bn_505.beta.set ( g_get_weight('backbone.features.17.conv.4.bias') ) + c_lnn.bn_505.running_mean.set ( g_get_weight('backbone.features.17.conv.4.running_mean') ) + c_lnn.bn_505.running_var.set ( g_get_weight('backbone.features.17.conv.4.running_var') ) + c_lnn.conv_507.kernel.set ( g_get_weight('backbone.features.17.conv.6.weight') ) + c_lnn.bn_508.gamma.set ( g_get_weight('backbone.features.17.conv.7.weight') ) + c_lnn.bn_508.beta.set ( g_get_weight('backbone.features.17.conv.7.bias') ) + c_lnn.bn_508.running_mean.set ( g_get_weight('backbone.features.17.conv.7.running_mean') ) + c_lnn.bn_508.running_var.set ( g_get_weight('backbone.features.17.conv.7.running_var') ) + + c_lnn.conv_509.kernel.set ( g_get_weight('neck.lateral_layers.3.conv.weight') ) + c_lnn.bn_510.gamma.set ( g_get_weight('neck.lateral_layers.3.bn.weight') ) + c_lnn.bn_510.beta.set ( g_get_weight('neck.lateral_layers.3.bn.bias') ) + c_lnn.bn_510.running_mean.set ( g_get_weight('neck.lateral_layers.3.bn.running_mean') ) + c_lnn.bn_510.running_var.set ( g_get_weight('neck.lateral_layers.3.bn.running_var') ) + + c_lnn.conv_512.kernel.set ( g_get_weight('neck.deconv_layers.2.conv.weight').transpose( (1,0,2,3) ) ) + c_lnn.bn_513.gamma.set ( g_get_weight('neck.deconv_layers.2.bn.weight') ) + c_lnn.bn_513.beta.set ( g_get_weight('neck.deconv_layers.2.bn.bias') ) + c_lnn.bn_513.running_mean.set ( g_get_weight('neck.deconv_layers.2.bn.running_mean') ) + c_lnn.bn_513.running_var.set ( g_get_weight('neck.deconv_layers.2.bn.running_var') ) + + c_lnn.conv_515.kernel.set ( g_get_weight('neck.lateral_layers.2.conv.weight') ) + c_lnn.bn_516.gamma.set ( g_get_weight('neck.lateral_layers.2.bn.weight') ) + c_lnn.bn_516.beta.set ( g_get_weight('neck.lateral_layers.2.bn.bias') ) + c_lnn.bn_516.running_mean.set ( g_get_weight('neck.lateral_layers.2.bn.running_mean') ) + c_lnn.bn_516.running_var.set ( g_get_weight('neck.lateral_layers.2.bn.running_var') ) + + c_lnn.conv_519.kernel.set ( g_get_weight('neck.deconv_layers.1.conv.weight').transpose( (1,0,2,3) ) ) + c_lnn.bn_520.gamma.set ( g_get_weight('neck.deconv_layers.1.bn.weight') ) + c_lnn.bn_520.beta.set ( g_get_weight('neck.deconv_layers.1.bn.bias') ) + c_lnn.bn_520.running_mean.set ( g_get_weight('neck.deconv_layers.1.bn.running_mean') ) + c_lnn.bn_520.running_var.set ( g_get_weight('neck.deconv_layers.1.bn.running_var') ) + + c_lnn.conv_522.kernel.set ( g_get_weight('neck.lateral_layers.1.conv.weight') ) + c_lnn.bn_523.gamma.set ( g_get_weight('neck.lateral_layers.1.bn.weight') ) + c_lnn.bn_523.beta.set ( g_get_weight('neck.lateral_layers.1.bn.bias') ) + c_lnn.bn_523.running_mean.set ( g_get_weight('neck.lateral_layers.1.bn.running_mean') ) + c_lnn.bn_523.running_var.set ( g_get_weight('neck.lateral_layers.1.bn.running_var') ) + + c_lnn.conv_526.kernel.set ( g_get_weight('neck.deconv_layers.0.conv.weight').transpose( (1,0,2,3) ) ) + c_lnn.bn_527.gamma.set ( g_get_weight('neck.deconv_layers.0.bn.weight') ) + c_lnn.bn_527.beta.set ( g_get_weight('neck.deconv_layers.0.bn.bias') ) + c_lnn.bn_527.running_mean.set ( g_get_weight('neck.deconv_layers.0.bn.running_mean') ) + c_lnn.bn_527.running_var.set ( g_get_weight('neck.deconv_layers.0.bn.running_var') ) + + c_lnn.conv_529.kernel.set ( g_get_weight('neck.lateral_layers.0.conv.weight') ) + c_lnn.bn_530.gamma.set ( g_get_weight('neck.lateral_layers.0.bn.weight') ) + c_lnn.bn_530.beta.set ( g_get_weight('neck.lateral_layers.0.bn.bias') ) + c_lnn.bn_530.running_mean.set ( g_get_weight('neck.lateral_layers.0.bn.running_mean') ) + c_lnn.bn_530.running_var.set ( g_get_weight('neck.lateral_layers.0.bn.running_var') ) + + c_lnn.conv_533.kernel.set ( g_get_weight('head.feat.weight') ) + c_lnn.bn_534.gamma.set ( g_get_weight('head.feat_bn.weight') ) + c_lnn.bn_534.beta.set ( g_get_weight('head.feat_bn.bias') ) + c_lnn.bn_534.running_mean.set ( g_get_weight('head.feat_bn.running_mean') ) + c_lnn.bn_534.running_var.set ( g_get_weight('head.feat_bn.running_var') ) + + c_lnn.conv_536.kernel.set ( g_get_weight('head.pos_conv.weight') ) + c_lnn.conv_536.bias.set ( g_get_weight('head.pos_conv.bias') ) + c_lnn.conv_538.kernel.set ( g_get_weight('head.reg_conv.weight') ) + c_lnn.conv_538.bias.set ( g_get_weight('head.reg_conv.bias') ) + c_lnn.conv_539.kernel.set ( g_get_weight('head.off_conv.weight') ) + c_lnn.conv_539.bias.set ( g_get_weight('head.off_conv.bias') ) + c_lnn.conv_540.kernel.set ( g_get_weight('head.lm_conv.weight') ) + c_lnn.conv_540.bias.set ( g_get_weight('head.lm_conv.bias') ) + + img_t = lnn.Tensor_from_value( img[...,::-1].transpose(2,0,1).reshape( (1,3,512,512) ) ) + + c_lnn.set_training(False) + + #================================== + img = img_u8 = cv2_imread(r"D:\DevelopPython\test\00004.jpg") + img = img.astype(np.float32) / 255.0 + + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + canny = cv2.Canny(img_u8, 200,200) + + sobelx = cv2.Sobel(img,cv2.CV_32F,1,0,ksize=3) + sobely = cv2.Sobel(img,cv2.CV_32F,0,1,ksize=3) + x = np.abs(sobelx)+np.abs(sobely) + + #x = cv2.blur(x, (5,5) ) + x[x<0.8] = 0.0 + #x = np.clip(x, 0, 1) + #x = cv2.blur(x, (5,5) ) + x = x / x.max() + + #x[x>=0.1] = 1.0 + #x = np.clip(x*5.0, 0.0, 1.0) + + cv2.imshow("", x) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + sys.path.insert(0, str(r'D:\DevelopPython\Projects\SBR\lib')) + + import torch + import torch.nn as nn + import torch.nn.functional as F + import torchvision as tv + + #from datasets import GeneralDataset as Dataset + #from xvision import transforms, draw_image_by_points + #from models import obtain_model, remove_module_dict + #from utils import get_model_infos + #from config_utils import load_configure + + + snapshot = Path(r'D:\DevelopPython\test\SBR.pth') + + snapshot = torch.load(snapshot) + + + def get_parameters(model, bias): + for m in model.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + if bias: + yield m.bias + else: + yield m.weight + elif isinstance(m, nn.BatchNorm2d): + if bias: + yield m.bias + else: + yield m.weight + def remove_module_dict(state_dict, is_print=False): + new_state_dict = dict() + for k, v in state_dict.items(): + if k[:7] == 'module.': + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + if is_print: print(new_state_dict.keys()) + return new_state_dict + def find_tensor_peak_batch(heatmap, radius, downsample, threshold = 0.000001): + assert heatmap.dim() == 3, 'The dimension of the heatmap is wrong : {}'.format(heatmap.size()) + #assert radius > 0 and isinstance(radius, numbers.Number), 'The radius is not ok : {}'.format(radius) + num_pts, H, W = heatmap.size(0), heatmap.size(1), heatmap.size(2) + assert W > 1 and H > 1, 'To avoid the normalization function divide zero' + # find the approximate location: + score, index = torch.max(heatmap.view(num_pts, -1), 1) + index_w = (index % W).float() + index_h = (index / W).float() + + def normalize(x, L): + return -1. + 2. * x.data / (L-1) + boxes = [index_w - radius, index_h - radius, index_w + radius, index_h + radius] + boxes[0] = normalize(boxes[0], W) + boxes[1] = normalize(boxes[1], H) + boxes[2] = normalize(boxes[2], W) + boxes[3] = normalize(boxes[3], H) + + affine_parameter = torch.zeros((num_pts, 2, 3)) + affine_parameter[:,0,0] = (boxes[2]-boxes[0])/2 + affine_parameter[:,0,2] = (boxes[2]+boxes[0])/2 + affine_parameter[:,1,1] = (boxes[3]-boxes[1])/2 + affine_parameter[:,1,2] = (boxes[3]+boxes[1])/2 + # extract the sub-region heatmap + theta = affine_parameter.to(heatmap.device) + grid_size = torch.Size([num_pts, 1, radius*2+1, radius*2+1]) + grid = F.affine_grid(theta, grid_size) + sub_feature = F.grid_sample(heatmap.unsqueeze(1), grid).squeeze(1) + + sub_feature = F.threshold(sub_feature, threshold, np.finfo(float).eps) + + X = torch.arange(-radius, radius+1).to(heatmap).view(1, 1, radius*2+1) + Y = torch.arange(-radius, radius+1).to(heatmap).view(1, radius*2+1, 1) + + sum_region = torch.sum(sub_feature.view(num_pts,-1),1) + x = torch.sum((sub_feature*X).view(num_pts,-1),1) / sum_region + index_w + y = torch.sum((sub_feature*Y).view(num_pts,-1),1) / sum_region + index_h + #import code + #code.interact(local=dict(globals(), **locals())) + + x = x * downsample + downsample / 2.0 - 0.5 + y = y * downsample + downsample / 2.0 - 0.5 + return torch.stack([x, y],1), score + + class VGG16_base(nn.Module): + def __init__(self): + super(VGG16_base, self).__init__() + + self.downsample = 8 + pts_num = 69 + + self.features = nn.Sequential( + nn.Conv2d( 3, 64, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d( 64, 64, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d( 64, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(128, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(256, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(512, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(512, 512, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True)) + + + self.CPM_feature = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), #CPM_1 + nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True)) #CPM_2 + + stage1 = nn.Sequential( + nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 512, kernel_size=1, padding=0), nn.ReLU(inplace=True), + nn.Conv2d(512, pts_num, kernel_size=1, padding=0)) + stages = [stage1] + for i in range(1, 3): + stagex = nn.Sequential( + nn.Conv2d(128+pts_num, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=7, dilation=1, padding=3), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=3, dilation=1, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(128, 128, kernel_size=1, padding=0), nn.ReLU(inplace=True), + nn.Conv2d(128, pts_num, kernel_size=1, padding=0)) + stages.append( stagex ) + self.stages = nn.ModuleList(stages) + + def specify_parameter(self, base_lr, base_weight_decay): + params_dict = [ {'params': get_parameters(self.features, bias=False), 'lr': base_lr , 'weight_decay': base_weight_decay}, + {'params': get_parameters(self.features, bias=True ), 'lr': base_lr*2, 'weight_decay': 0}, + {'params': get_parameters(self.CPM_feature, bias=False), 'lr': base_lr , 'weight_decay': base_weight_decay}, + {'params': get_parameters(self.CPM_feature, bias=True ), 'lr': base_lr*2, 'weight_decay': 0}, + ] + for stage in self.stages: + params_dict.append( {'params': get_parameters(stage, bias=False), 'lr': base_lr*4, 'weight_decay': base_weight_decay} ) + params_dict.append( {'params': get_parameters(stage, bias=True ), 'lr': base_lr*8, 'weight_decay': 0} ) + return params_dict + + # return : cpm-stages, locations + def forward(self, inputs): + assert inputs.dim() == 4, 'This model accepts 4 dimension input tensor: {}'.format(inputs.size()) + batch_size, feature_dim = inputs.size(0), inputs.size(1) + batch_cpms, batch_locs, batch_scos = [], [], [] + + feature = self.features(inputs) + xfeature = self.CPM_feature(feature) + for i in range(3): + if i == 0: + cpm = self.stages[i]( xfeature ) + else: + cpm = self.stages[i]( torch.cat([xfeature, batch_cpms[i-1]], 1) ) + batch_cpms.append( cpm ) + + # The location of the current batch + for ibatch in range(batch_size): + batch_location, batch_score = find_tensor_peak_batch(batch_cpms[-1][ibatch], 4, self.downsample) + batch_locs.append( batch_location ) + batch_scos.append( batch_score ) + batch_locs, batch_scos = torch.stack(batch_locs), torch.stack(batch_scos) + + return batch_cpms, batch_locs, batch_scos + + def transform( point, center, scale, resolution): + pt = np.array ( [point[0], point[1], 1.0] ) + h = 200.0 * scale + m = np.eye(3) + m[0,0] = resolution / h + m[1,1] = resolution / h + m[0,2] = resolution * ( -center[0] / h + 0.5 ) + m[1,2] = resolution * ( -center[1] / h + 0.5 ) + m = np.linalg.inv(m) + return np.matmul (m, pt)[0:2] + + def crop(image, center, scale, resolution=224.0): + ul = transform([1, 1], center, scale, resolution).astype( np.int ) + br = transform([resolution, resolution], center, scale, resolution).astype( np.int ) + + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array([max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array([max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR) + return newImg + + param = snapshot['args'] + #eval_transform = transforms.Compose([transforms.PreCrop(param.pre_crop_expand), transforms.TrainScale2WH((param.crop_width, param.crop_height)), transforms.ToTensor(), normalize]) + #model_config = load_configure(param.model_config, None) + + net = VGG16_base() + d=remove_module_dict(snapshot['state_dict']) + net.load_state_dict(d) + + img = cv2_imread(r"D:\DevelopPython\test\00003.jpg") + #img = cv2_imread(r"D:\DevelopPython\test\ct_00003.jpg") + img = img.astype(np.float32) / 255.0 + img = cv2.resize( img, (224,224) ) + + mat = cv2.getRotationMatrix2D( (112,112), 0, 0.75 ) + #mat = np.array([ [1,0,0], [0,1,0] ], np.float32) + #mat *= 0.5 + #img = cv2.warpAffine(img, mat, (224, 224), cv2.INTER_LANCZOS4) + + img_t = img.copy() + img_t = cv2.resize( img_t, (224,224) ) + + img_t -= [0.485, 0.456, 0.406] + img_t /= [0.229, 0.224, 0.225] + img_t = img_t[None,:,:,::-1].copy().transpose ( (0,3,1,2) ) + img_t = torch.from_numpy(img_t) + pred = net(img_t) + + lmrks = pred[1][0].detach().cpu().numpy() + + for pt in lmrks: + x,y = pt.astype(np.int) + cv2.circle(img, (x, y), 1, (0,0,1) ) + + cv2.imshow("", img) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + + + + Y, Cr, Cb = bgr2dct(img) + + img2 = dct2bgr(Y, Cr, Cb) + + import code + code.interact(local=dict(globals(), **locals())) + #cv2.imshow("", img) + #cv2.waitKey(0) + cv2.imshow("", img2) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + #============================================= + + + #heatmap_c, scale_c, offset_c, lms_c = net.forward(['537','538','539','540']) + + # while True: + # with timeit(): + # inp_t = torch.from_numpy(img[...,::-1].transpose(2,0,1).reshape(-1).reshape( (1,3,512,512) ).astype(np.float32)) + # inp_t = inp_t.to('cuda:0') + # heatmap_t, scale_t, offset_t, lms_t = c_tnn(inp_t) + + # diff = np.sum( np.abs( np.ndarray.flatten( heatmap_c - heatmap_t.detach().cpu().numpy() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(heatmap_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( scale_c - scale_t.detach().cpu().numpy() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(scale_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( offset_c - offset_t.detach().cpu().numpy() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(offset_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( lms_c - lms_t.detach().cpu().numpy() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(lms_c.shape) }') + + #import code + #code.interact(local=dict(globals(), **locals())) + + # #while True: + # # with timeit(): + # heatmap_l, scale_l, offset_l, lms_l = c_lnn(img_t) + + # #while True: + # # with timeit(): + # heatmap_c, scale_c, offset_c, lms_c = net.forward(['537','538','539','540']) + + # diff = np.sum( np.abs( np.ndarray.flatten( heatmap_c - heatmap_l.np() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(heatmap_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( scale_c - scale_l.np() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(scale_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( offset_c - offset_l.np() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(offset_c.shape) }') + # diff = np.sum( np.abs( np.ndarray.flatten( lms_c - lms_l.np() ) ) ) + # print(f'diff = {diff}. per param = { diff / np.prod(lms_c.shape) }') + + class FaceAligner(tnn.Module): + def __init__(self, resolution): + super().__init__() + self.conv1 = tnn.Conv2d(3, 32, 5, stride=1, padding=2) + self.conv2 = tnn.Conv2d(32, 64, 5, stride=1, padding=2) + self.conv3 = tnn.Conv2d(64, 128, 5, stride=1, padding=2) + self.conv4 = tnn.Conv2d(128, 256, 5, stride=1, padding=2) + self.conv5 = tnn.Conv2d(256, 512, 5, stride=1, padding=2) + + low_res = resolution // (2**5) + self.fc1 = tnn.Linear(512*low_res*low_res, 6) + + def forward(self, x): + + x = F.leaky_relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.leaky_relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = F.leaky_relu(self.conv3(x)) + x = F.max_pool2d(x, 2, 2) + x = F.leaky_relu(self.conv4(x)) + x = F.max_pool2d(x, 2, 2) + x = F.leaky_relu(self.conv5(x)) + x = F.max_pool2d(x, 2, 2) + x = x.reshape( x.shape[0], -1 ) + x = self.fc1(x) + return x + + #net = FaceAligner(256) + #net = tv.models.mobilenet_v2 ( num_classes = 6) + + #import code + #code.interact(local=dict(globals(), **locals())) + + net = tv.models.resnet.ResNet ( tv.models.resnet.BasicBlock, [2,2,2,2], 6) + model_ckp_path = Path(r'E:\face_aligner.pt') + if model_ckp_path.exists(): + checkpoint = torch.load(model_ckp_path) + net.load_state_dict(checkpoint['model_state_dict']) + net.cuda() + net.eval() + + + landmarks_5pt = np.array([ + [ 0.0, 0.0 ], + [ 1.0, 0.0 ], + [ 0.5, 0.5 ], + [ 0.0, 1.0 ], + [ 1.0, 1.0 ]]) + + def transform_points(points, mat, invert=False): + if invert: + mat = cv2.invertAffineTransform (mat) + points = np.expand_dims(points, axis=1) + points = cv2.transform(points, mat, points.shape) + points = np.squeeze(points) + return points + + def get_transform_mat (image_landmarks, output_size, scale=1.0): + if not isinstance(image_landmarks, np.ndarray): + image_landmarks = np.array (image_landmarks) + + # estimate landmarks transform from global space to local aligned space with bounds [0..1] + mat = umeyama(image_landmarks, landmarks_5pt, True)[0:2] + + # get corner points in global space + g_p = transform_points ( np.float32([(0,0),(1,0),(1,1),(0,1),(0.5,0.5) ]) , mat, True) + g_c = g_p[4] + + # calc diagonal vectors between corners in global space + tb_diag_vec = (g_p[2]-g_p[0]).astype(np.float32) + tb_diag_vec /= npla.norm(tb_diag_vec) + bt_diag_vec = (g_p[1]-g_p[3]).astype(np.float32) + bt_diag_vec /= npla.norm(bt_diag_vec) + + # calc modifier of diagonal vectors for scale and padding value + padding, remove_align = 1.0, False #FaceType_to_padding_remove_align.get(face_type, 0.0) + mod = (1.0 / scale)* ( npla.norm(g_p[0]-g_p[2])*(padding*np.sqrt(2.0) + 0.5) ) + + + # calc 3 points in global space to estimate 2d affine transform + if not remove_align: + l_t = np.array( [ g_c - tb_diag_vec*mod, + g_c + bt_diag_vec*mod, + g_c + tb_diag_vec*mod ] ) + else: + # remove_align - face will be centered in the frame but not aligned + l_t = np.array( [ g_c - tb_diag_vec*mod, + g_c + bt_diag_vec*mod, + g_c + tb_diag_vec*mod, + g_c - bt_diag_vec*mod, + ] ) + + # get area of face square in global space + area = mathlib.polygon_area(l_t[:,0], l_t[:,1] ) + + # calc side of square + side = np.float32(math.sqrt(area) / 2) + + # calc 3 points with unrotated square + l_t = np.array( [ g_c + [-side,-side], + g_c + [ side,-side], + g_c + [ side, side] ] ) + + # calc affine transform from 3 global space points to 3 local space points size of 'output_size' + pts2 = np.float32(( (0,0),(output_size,0),(output_size,output_size) )) + mat = cv2.getAffineTransform(l_t,pts2) + return mat + + output_path = Path(r'F:\DeepFaceLabCUDA9.2SSE\workspace шиа\data_dst\test_out') + output_path.mkdir(exist_ok=True, parents=True) + + for filename in pathex.get_image_paths(output_path): + Path(filename).unlink() + + #import code + #code.interact(local=dict(globals(), **locals())) + resolution = 224 + + # img = cv2_imread(r'D:\DevelopPython\test\00006.jpg') + # img = img.transpose( (2,0,1) )[None,...].astype(np.float32) + # img /= 255.0 + # pred_mats_t = net( torch.from_numpy(img).cuda() ) + # #pred_mats_t = pred_mats_t.detach().cpu().numpy().reshape( (2,3) ) + # #pred_mats_t[:,2] *= resolution + # print( pred_mats_t ) + + # import code + # code.interact(local=dict(globals(), **locals())) + + for filename in tqdm(pathex.get_image_paths(r'F:\DeepFaceLabCUDA9.2SSE\workspace шиа\data_dst') , desc="test", ascii=True): + filepath = Path(filename) + img = cv2_imread(filepath) + + #while True: + # with timeit(): + boxes, lms = cface_torch(img) + + # for (l,t,r,b,c), lm in zip(boxes, lms): + # lm = lm.reshape( (5,2) ) + + # imagelib.draw_rect(img, (l,t,r,b), (0,0,255), 2 ) + # for (x,y) in lm: + # cv2.circle (img, ( int(x), int(y) ), 2, (0,0,255), thickness=2) + if len(lms) == 0: + continue + + lm = lms[0] + lm = lm.reshape( (5,2) ) + + mat = get_transform_mat(lm, resolution) + face_img = cv2.warpAffine(img, mat, (resolution, resolution), cv2.INTER_LANCZOS4) + + #cv2_imwrite(output_path / filepath.name , face_img) + #continue + + with torch.no_grad(): + face_img_t = (face_img[None,...].astype(np.float32) / 255.0) + face_img_t -= np.array([0.406, 0.456, 0.485], np.float32) + face_img_t /= np.array([0.225, 0.224, 0.229], np.float32) + face_img_t = face_img_t.transpose( (0,3,1,2)) + face_img_t = torch.from_numpy(face_img_t).cuda() + + + face_mat_t = net(face_img_t) + + #import code + #code.interact(local=dict(globals(), **locals())) + + face_mat = face_mat_t.detach().cpu().numpy() + face_mat = face_mat.reshape( (2,3) ) + face_mat[:,2] *= resolution + + face_mat = cv2.invertAffineTransform (face_mat) + #print(face_mat) + + #cv2.imshow("", cv2.warpAffine(face_img, face_mat, (256, 256), cv2.INTER_LANCZOS4) ) + #cv2.waitKey(0) + + + img = cv2.warpAffine(face_img, face_mat, (resolution, resolution), cv2.INTER_LANCZOS4) + + cv2_imwrite(output_path / filepath.name , img) + + import code + code.interact(local=dict(globals(), **locals())) + + + c = CenterFace() + + + boxes, lms = c(img) + + for (l,t,r,b,c), lm in zip(boxes, lms): + lm = lm.reshape( (5,2) ) + + imagelib.draw_rect (img, [l,t,r,b], (0,255,0), 2) + for x,y in lm: + cv2.circle (img, ( int(x), int(y) ), 2, (0,0,255), thickness=2) + + #img = cv2.resize(img, (1920, 1080) ) + cv2.imshow("", img) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() )# device_config=nn.DeviceConfig.GPUIndexes([1]) ) + tf = nn.tf + + + + + + def load_pb(path_to_pb): + with tf.gfile.GFile(path_to_pb, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='') + return graph + + graph = load_pb (r"D:\DevelopPython\test\tlv3_std_2x.pb") + + + # input0 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame0:0') + # input1 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame1:0') + # input2 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame2:0') + # input3 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame3:0') + # input4 = graph.get_tensor_by_name('VideoSR_Unet/inputFrame4:0') + # output = graph.get_tensor_by_name('VideoSR_Unet/Out4X/output/add:0') + + #filepath = r'D:\DevelopPython\test\00000.jpg' + #img = cv2.imread(filepath).astype(np.float32) / 255.0 + #inp_img = img *2 - 1 + #inp_img = cv2.resize (inp_img, (192,192) ) + + + + sess = tf.Session(graph=graph, config=nn.tf_sess_config) + + writer = tf.summary.FileWriter(r'D:\logs', nn.tf_sess.graph) + + def get_op_value(op_name, n_output=0): + return sess.run ([ graph.get_operation_by_name(op_name).outputs[n_output] ])[0].astype(K.floatx()) + + # + + import code + code.interact(local=dict(globals(), **locals())) + + + # ==================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.BestGPU() )# device_config=nn.DeviceConfig.GPUIndexes([1]) ) + tf = nn.tf + + import torch + import torch.nn as tnn + import torch.nn.functional as F + + + + def sf3d_keras(input_shape, sf3d_torch): + + inp = Input ( (None, None,3), dtype=K.floatx() ) + x = inp + x = Lambda ( lambda x: x - K.constant([104,117,123]), output_shape=(None,None,3) ) (x) + + x = Conv2D(64, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv1_1), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(64, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv1_2), activation='relu') (ZeroPadding2D(1)(x)) + x = MaxPooling2D()(x) + + x = Conv2D(128, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv2_1), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(128, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv2_2), activation='relu') (ZeroPadding2D(1)(x)) + x = MaxPooling2D()(x) + + x = Conv2D(256, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv3_1), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(256, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv3_2), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(256, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv3_3), activation='relu') (ZeroPadding2D(1)(x)) + f3_3 = x + x = MaxPooling2D()(x) + + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv4_1), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv4_2), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv4_3), activation='relu') (ZeroPadding2D(1)(x)) + f4_3 = x + x = MaxPooling2D()(x) + + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv5_1), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv5_2), activation='relu') (ZeroPadding2D(1)(x)) + x = Conv2D(512, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv5_3), activation='relu') (ZeroPadding2D(1)(x)) + f5_3 = x + x = MaxPooling2D()(x) + + x = ZeroPadding2D(padding=(3,3))(x) + x = Conv2D(1024, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.fc6), activation='relu') (x) + x = Conv2D(1024, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.fc7), activation='relu') (x) + ffc7 = x + + x = Conv2D(256, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv6_1), activation='relu') (x) + x = ZeroPadding2D(padding=(1,1))(x) + x = Conv2D(512, kernel_size=3, strides=2, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv6_2), activation='relu') (x) + f6_2 = x + + x = Conv2D(128, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv7_1), activation='relu') (x) + x = ZeroPadding2D(padding=(1,1))(x) + x = Conv2D(256, kernel_size=3, strides=2, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv7_2), activation='relu') (x) + f7_2 = x + + class L2Norm(KL.Layer): + def __init__(self, n_channels, scale=1.0, weights=None, **kwargs): + self.n_channels = n_channels + self.scale = scale + + self.weights_ = weights + super(L2Norm, self).__init__(**kwargs) + + def build(self, input_shape): + self.input_spec = None + + self.w = self.add_weight( shape=(1, 1, self.n_channels), initializer='ones', name='w' ) + + if self.weights_ is not None: + self.set_weights( [self.weights_.reshape ( (1,1,-1) )] ) + + self.built = True + + def call(self, inputs, training=None): + x = inputs + x = x / (K.sqrt( K.sum( K.pow(x, 2), axis=-1, keepdims=True ) ) + 1e-10) * self.w + return x + + def get_config(self): + config = {'n_channels': self.n_channels, 'scale': self.scale } + + base_config = super(L2Norm, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def compute_output_shape(self, input_shape): + return input_shape + + f3_3 = L2Norm(256, scale=10, weights=sf3d_torch.conv3_3_norm.weight.data.cpu().numpy())(f3_3) + f4_3 = L2Norm(512, scale=8, weights=sf3d_torch.conv4_3_norm.weight.data.cpu().numpy())(f4_3) + f5_3 = L2Norm(512, scale=5, weights=sf3d_torch.conv5_3_norm.weight.data.cpu().numpy())(f5_3) + + cls1 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv3_3_norm_mbox_conf), activation='softmax')(ZeroPadding2D(1)(f3_3)) + reg1 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv3_3_norm_mbox_loc)) (ZeroPadding2D(1)(f3_3)) + + cls2 = Conv2D(2, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv4_3_norm_mbox_conf), activation='softmax')(ZeroPadding2D(1)(f4_3)) + reg2 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv4_3_norm_mbox_loc)) (ZeroPadding2D(1)(f4_3)) + + cls3 = Conv2D(2, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv5_3_norm_mbox_conf), activation='softmax')(ZeroPadding2D(1)(f5_3)) + reg3 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv5_3_norm_mbox_loc)) (ZeroPadding2D(1)(f5_3)) + + cls4 = Conv2D(2, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.fc7_mbox_conf), activation='softmax')(ZeroPadding2D(1)(ffc7)) + reg4 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.fc7_mbox_loc)) (ZeroPadding2D(1)(ffc7)) + + cls5 = Conv2D(2, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv6_2_mbox_conf), activation='softmax')(ZeroPadding2D(1)(f6_2)) + reg5 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv6_2_mbox_loc)) (ZeroPadding2D(1)(f6_2)) + + cls6 = Conv2D(2, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv7_2_mbox_conf), activation='softmax')(ZeroPadding2D(1)(f7_2)) + reg6 = Conv2D(4, kernel_size=3, strides=1, padding='valid', weights=t2kw_conv2d(sf3d_torch.conv7_2_mbox_loc)) (ZeroPadding2D(1)(f7_2)) + + L = Lambda ( lambda x: x[:,:,:,-1], output_shape=(None,None,1) ) + cls1, cls2, cls3, cls4, cls5, cls6 = [ L(x) for x in [cls1, cls2, cls3, cls4, cls5, cls6] ] + + return Model(inp, [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]) + + class TL2Norm(tnn.Module): + def __init__(self, n_channels, scale=1.0): + super(TL2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = tnn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.view(1, -1, 1, 1) + return x + class L2Norm(nn.LayerBase): + def __init__(self, n_channels, **kwargs): + self.n_channels = n_channels + super().__init__(**kwargs) + + def init_weights(self): + self.weight = tf.get_variable ("weight", (1, 1, 1, self.n_channels), dtype=nn.floatx, initializer=tf.initializers.ones ) + + def get_weights(self): + return [self.weight] + + def __call__(self, inputs): + x = inputs + x = x / (tf.sqrt( tf.reduce_sum( tf.pow(x, 2), axis=-1, keepdims=True ) ) + 1e-10) * self.weight + return x + + class S3FD(nn.ModelBase): + def __init__(self): + super().__init__(name='S3FD') + + def on_build(self): + self.minus = tf.constant([104,117,123], dtype=nn.floatx ) + self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, strides=1, padding='SAME') + self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, strides=1, padding='SAME') + + self.conv2_1 = nn.Conv2D(64, 128, kernel_size=3, strides=1, padding='SAME') + self.conv2_2 = nn.Conv2D(128, 128, kernel_size=3, strides=1, padding='SAME') + + self.conv3_1 = nn.Conv2D(128, 256, kernel_size=3, strides=1, padding='SAME') + self.conv3_2 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME') + self.conv3_3 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME') + + self.conv4_1 = nn.Conv2D(256, 512, kernel_size=3, strides=1, padding='SAME') + self.conv4_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv4_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + + self.conv5_1 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv5_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + self.conv5_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME') + + self.fc6 = nn.Conv2D(512, 1024, kernel_size=3, strides=1, padding=3) + self.fc7 = nn.Conv2D(1024, 1024, kernel_size=1, strides=1, padding='SAME') + + self.conv6_1 = nn.Conv2D(1024, 256, kernel_size=1, strides=1, padding='SAME') + self.conv6_2 = nn.Conv2D(256, 512, kernel_size=3, strides=2, padding='SAME') + + self.conv7_1 = nn.Conv2D(512, 128, kernel_size=1, strides=1, padding='SAME') + self.conv7_2 = nn.Conv2D(128, 256, kernel_size=3, strides=2, padding='SAME') + + self.conv3_3_norm = L2Norm(256) + self.conv4_3_norm = L2Norm(512) + self.conv5_3_norm = L2Norm(512) + + + self.conv3_3_norm_mbox_conf = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + self.conv3_3_norm_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv4_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv4_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv5_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv5_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.fc7_mbox_conf = nn.Conv2D(1024, 2, kernel_size=3, strides=1, padding='SAME') + self.fc7_mbox_loc = nn.Conv2D(1024, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv6_2_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME') + self.conv6_2_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME') + + self.conv7_2_mbox_conf = nn.Conv2D(256, 2, kernel_size=3, strides=1, padding='SAME') + self.conv7_2_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME') + + def call(self, x): + x = x - self.minus + x = tf.nn.relu(self.conv1_1(x)) + x = tf.nn.relu(self.conv1_2(x)) + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv2_1(x)) + x = tf.nn.relu(self.conv2_2(x)) + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv3_1(x)) + x = tf.nn.relu(self.conv3_2(x)) + x = tf.nn.relu(self.conv3_3(x)) + f3_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv4_1(x)) + x = tf.nn.relu(self.conv4_2(x)) + x = tf.nn.relu(self.conv4_3(x)) + f4_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.conv5_1(x)) + x = tf.nn.relu(self.conv5_2(x)) + x = tf.nn.relu(self.conv5_3(x)) + f5_3 = x + x = tf.nn.max_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + + x = tf.nn.relu(self.fc6(x)) + x = tf.nn.relu(self.fc7(x)) + ffc7 = x + + x = tf.nn.relu(self.conv6_1(x)) + x = tf.nn.relu(self.conv6_2(x)) + f6_2 = x + + x = tf.nn.relu(self.conv7_1(x)) + x = tf.nn.relu(self.conv7_2(x)) + f7_2 = x + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + + cls2 = tf.nn.softmax(self.conv4_3_norm_mbox_conf(f4_3)) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + + cls3 = tf.nn.softmax(self.conv5_3_norm_mbox_conf(f5_3)) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + + cls4 = tf.nn.softmax(self.fc7_mbox_conf(ffc7)) + reg4 = self.fc7_mbox_loc(ffc7) + + cls5 = tf.nn.softmax(self.conv6_2_mbox_conf(f6_2)) + reg5 = self.conv6_2_mbox_loc(f6_2) + + cls6 = tf.nn.softmax(self.conv7_2_mbox_conf(f7_2)) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + bmax = tf.maximum(tf.maximum(cls1[:,:,:,0:1], cls1[:,:,:,1:2]), cls1[:,:,:,2:3]) + + cls1 = tf.concat ([bmax, cls1[:,:,:,3:4] ], axis=-1) + cls1 = tf.nn.softmax(cls1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] + + + + class s3fd_torch(tnn.Module): + def __init__(self): + super(s3fd_torch, self).__init__() + self.conv1_1 = tnn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = tnn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = tnn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = tnn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = tnn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = tnn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = tnn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = tnn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = tnn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = tnn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = tnn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = tnn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = tnn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = tnn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = tnn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = tnn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = tnn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = tnn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = tnn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = TL2Norm(256, scale=10) + self.conv4_3_norm = TL2Norm(512, scale=8) + self.conv5_3_norm = TL2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = tnn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = tnn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = tnn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = tnn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = tnn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = tnn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = tnn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = tnn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = tnn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_loc = tnn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = tnn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = tnn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + + + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + + f6_2 = h + + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + + + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1, 4, 1) + bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) + cls1 = torch.cat ([bmax,chunk[3]], dim=1) + cls1, cls2, cls3, cls4, cls5, cls6 = [ F.softmax(x, dim=1) for x in [cls1, cls2, cls3, cls4, cls5, cls6] ] + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] + + + + + + def decode(loc, priors, variances): + boxes = np.concatenate((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])), + 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + def softmax(x, axis=-1): + y = np.exp(x - np.max(x, axis, keepdims=True)) + return y / np.sum(y, axis, keepdims=True) + + def nms(dets, thresh): + """ Perform Non-Maximum Suppression """ + keep = list() + if len(dets) == 0: + return keep + + x_1, y_1, x_2, y_2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4] + areas = (x_2 - x_1 + 1) * (y_2 - y_1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx_1, yy_1 = np.maximum(x_1[i], x_1[order[1:]]), np.maximum(y_1[i], y_1[order[1:]]) + xx_2, yy_2 = np.minimum(x_2[i], x_2[order[1:]]), np.minimum(y_2[i], y_2[order[1:]]) + + width, height = np.maximum(0.0, xx_2 - xx_1 + 1), np.maximum(0.0, yy_2 - yy_1 + 1) + ovr = width * height / (areas[i] + areas[order[1:]] - width * height) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep + + def detect_torch(olist): + + bboxlist = [] + + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + stride = 2**(i + 2) # 4,8,16,32,64,128 + poss = [*zip(*np.where(ocls[:, 1, :, :] > 0.05))] + + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = np.ascontiguousarray(oreg[0, :, hindex, windex]).reshape((1, 4)) + priors = np.array([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + + def detect_keras(olist): + bboxlist = [] + for i, ((ocls,), (oreg,)) in enumerate ( zip ( olist[::2], olist[1::2] ) ): + stride = 2**(i + 2) # 4,8,16,32,64,128 + s_d2 = stride / 2 + s_m4 = stride * 4 + + for hindex, windex in zip(*np.where(ocls[...,1] > 0.05)): + score = ocls[hindex, windex, 1] + loc = oreg[hindex, windex, :] + priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4]) + priors_2p = priors[2:] + box = np.concatenate((priors[:2] + loc[:2] * 0.1 * priors_2p, + priors_2p * np.exp(loc[2:] * 0.2)) ) + box[:2] -= box[2:] / 2 + box[2:] += box[:2] + + bboxlist.append([*box, score]) + + bboxlist = np.array(bboxlist) + if len(bboxlist) == 0: + bboxlist = np.zeros((1, 5)) + + bboxlist = bboxlist[nms(bboxlist, 0.3), :] + bboxlist = [ x[:-1] for x in bboxlist if x[-1] >= 0.5] + return bboxlist + + #transfer weights + def convd2d_torch_to_lnn(litenn_layer, torch_layer): + litenn_layer.kernel.set ( torch_layer.weight.data.numpy() ) + if torch_layer.bias is not None: + litenn_layer.bias.set (torch_layer.bias.data.numpy() ) + + def convd2d_from_torch(torch_layer): + result = [ torch_layer.weight.data.numpy().transpose(2,3,1,0) ] + if torch_layer.bias is not None: + result += [ torch_layer.bias.data.numpy().reshape( (1,1,1,-1) ) ] + return result + + def l2norm_from_torch(torch_layer): + result = [ torch_layer.weight.data.numpy().reshape( (1,1,1,-1) ) ] + return result + + + + class L2Norm_lnn(lnn.Module): + def __init__(self, n_channels, scale=1.0): + self.n_channels = n_channels + self.weight = lnn.Tensor( (n_channels,), init=lnn.initializer.Scalar(scale) ) + super().__init__(saveables=['weight']) + + def forward(self, x): + x = x / (lnn.sqrt( lnn.reduce_sum( lnn.square(x), axes=1, keepdims=True ) ) + 1e-10) * self.weight.reshape( (1,-1,1,1) ) + return x + + class s3fd_litenn(lnn.Module): + def __init__(self): + self.conv1_1 = lnn.Conv2D(3, 64, 3) + self.conv1_2 = lnn.Conv2D(64, 64, 3) + + self.conv2_1 = lnn.Conv2D(64, 128, 3) + self.conv2_2 = lnn.Conv2D(128, 128, 3) + + self.conv3_1 = lnn.Conv2D(128, 256, 3) + self.conv3_2 = lnn.Conv2D(256, 256, 3) + self.conv3_3 = lnn.Conv2D(256, 256, 3) + + self.conv4_1 = lnn.Conv2D(256, 512, 3) + self.conv4_2 = lnn.Conv2D(512, 512, 3) + self.conv4_3 = lnn.Conv2D(512, 512, 3) + + self.conv5_1 = lnn.Conv2D(512, 512, 3) + self.conv5_2 = lnn.Conv2D(512, 512, 3) + self.conv5_3 = lnn.Conv2D(512, 512, 3) + + self.fc6 = lnn.Conv2D(512, 1024, 3, padding=3) + self.fc7 = lnn.Conv2D(1024, 1024, 1) + + self.conv6_1 = lnn.Conv2D(1024, 256, 1) + self.conv6_2 = lnn.Conv2D(256, 512, 3, stride=2) + + self.conv7_1 = lnn.Conv2D(512, 128, 1) + self.conv7_2 = lnn.Conv2D(128, 256, 3, stride=2) + + self.conv3_3_norm = L2Norm_lnn(256, scale=10) + self.conv4_3_norm = L2Norm_lnn(512, scale=8) + self.conv5_3_norm = L2Norm_lnn(512, scale=5) + + self.conv3_3_norm_mbox_conf = lnn.Conv2D(256, 4, 3) + self.conv3_3_norm_mbox_loc = lnn.Conv2D(256, 4, 3) + self.conv4_3_norm_mbox_conf = lnn.Conv2D(512, 2, 3) + self.conv4_3_norm_mbox_loc = lnn.Conv2D(512, 4, 3) + self.conv5_3_norm_mbox_conf = lnn.Conv2D(512, 2, 3) + self.conv5_3_norm_mbox_loc = lnn.Conv2D(512, 4, 3) + + self.fc7_mbox_conf = lnn.Conv2D(1024, 2, 3) + self.fc7_mbox_loc = lnn.Conv2D(1024, 4, 3) + self.conv6_2_mbox_conf = lnn.Conv2D(512, 2, 3) + self.conv6_2_mbox_loc = lnn.Conv2D(512, 4, 3) + self.conv7_2_mbox_conf = lnn.Conv2D(256, 2, 3) + self.conv7_2_mbox_loc = lnn.Conv2D(256, 4, 3) + + def forward(self, x): + + h = lnn.relu(self.conv1_1(x)) + h = lnn.relu(self.conv1_2(h)) + h = lnn.max_pool2D(h) + + + h = lnn.relu(self.conv2_1(h)) + h = lnn.relu(self.conv2_2(h)) + h = lnn.max_pool2D(h) + + + h = lnn.relu(self.conv3_1(h)) + h = lnn.relu(self.conv3_2(h)) + h = lnn.relu(self.conv3_3(h)) + f3_3 = h + h = lnn.max_pool2D(h) + + h = lnn.relu(self.conv4_1(h)) + h = lnn.relu(self.conv4_2(h)) + h = lnn.relu(self.conv4_3(h)) + f4_3 = h + h = lnn.max_pool2D(h) + + h = lnn.relu(self.conv5_1(h)) + h = lnn.relu(self.conv5_2(h)) + h = lnn.relu(self.conv5_3(h)) + f5_3 = h + h = lnn.max_pool2D(h) + + h = lnn.relu(self.fc6(h)) + h = lnn.relu(self.fc7(h)) + ffc7 = h + + h = lnn.relu(self.conv6_1(h)) + h = lnn.relu(self.conv6_2(h)) + + f6_2 = h + + h = lnn.relu(self.conv7_1(h)) + h = lnn.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = lnn.softmax(self.conv4_3_norm_mbox_conf(f4_3), 1) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = lnn.softmax(self.conv5_3_norm_mbox_conf(f5_3), 1) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = lnn.softmax(self.fc7_mbox_conf(ffc7), 1) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = lnn.softmax(self.conv6_2_mbox_conf(f6_2), 1) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = lnn.softmax(self.conv7_2_mbox_conf(f7_2), 1) + reg6 = self.conv7_2_mbox_loc(f7_2) + + cls1 = lnn.concat( [ lnn.reduce_max( cls1[:,0:3,:,:], 1, keepdims=True ),cls1[:,3:4,:,:] ], 1 ) + cls1 = lnn.softmax(cls1, 1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] + + model_path = r"D:\DevelopPython\test\s3fd.pth" + model_weights = torch.load(str(model_path)) + + device = 'cpu' + fd_torch = s3fd_torch() + + fd_torch.load_state_dict(model_weights) + fd_torch.eval() + + #img = np.random.uniform ( size=(480,270,3) ) * 255 + img = cv2.imread ( r"D:\DevelopPython\test\00000.png" ) + + torch_img = torch.from_numpy( np.expand_dims((img - np.array([104, 117, 123])).transpose(2, 0, 1),0) ).float() + + + + #t = time.time() + with torch.no_grad(): + olist_torch = [x.data.cpu().numpy() for x in fd_torch( torch.autograd.Variable( torch_img) )] + #print ("torch took:", time.time() - t) + + + fd_litenn = s3fd_litenn() + convd2d_torch_to_lnn( fd_litenn.conv1_1, fd_torch.conv1_1 ) + convd2d_torch_to_lnn( fd_litenn.conv1_2, fd_torch.conv1_2 ) + convd2d_torch_to_lnn( fd_litenn.conv2_1, fd_torch.conv2_1 ) + convd2d_torch_to_lnn( fd_litenn.conv2_2, fd_torch.conv2_2 ) + convd2d_torch_to_lnn( fd_litenn.conv3_1, fd_torch.conv3_1 ) + convd2d_torch_to_lnn( fd_litenn.conv3_2, fd_torch.conv3_2 ) + convd2d_torch_to_lnn( fd_litenn.conv3_3, fd_torch.conv3_3 ) + convd2d_torch_to_lnn( fd_litenn.conv4_1, fd_torch.conv4_1 ) + convd2d_torch_to_lnn( fd_litenn.conv4_2, fd_torch.conv4_2 ) + convd2d_torch_to_lnn( fd_litenn.conv4_3, fd_torch.conv4_3 ) + convd2d_torch_to_lnn( fd_litenn.conv5_1, fd_torch.conv5_1 ) + convd2d_torch_to_lnn( fd_litenn.conv5_2, fd_torch.conv5_2 ) + convd2d_torch_to_lnn( fd_litenn.conv5_3, fd_torch.conv5_3 ) + convd2d_torch_to_lnn( fd_litenn.fc6, fd_torch.fc6 ) + convd2d_torch_to_lnn( fd_litenn.fc7, fd_torch.fc7 ) + convd2d_torch_to_lnn( fd_litenn.conv6_1, fd_torch.conv6_1 ) + convd2d_torch_to_lnn( fd_litenn.conv6_2, fd_torch.conv6_2 ) + convd2d_torch_to_lnn( fd_litenn.conv7_1, fd_torch.conv7_1 ) + convd2d_torch_to_lnn( fd_litenn.conv7_2, fd_torch.conv7_2 ) + + fd_litenn.conv3_3_norm.weight.set( fd_torch.conv3_3_norm.weight.detach().numpy() ) + fd_litenn.conv4_3_norm.weight.set( fd_torch.conv4_3_norm.weight.detach().numpy() ) + fd_litenn.conv5_3_norm.weight.set( fd_torch.conv5_3_norm.weight.detach().numpy() ) + + + convd2d_torch_to_lnn( fd_litenn.conv3_3_norm_mbox_conf , fd_torch.conv3_3_norm_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.conv3_3_norm_mbox_loc , fd_torch.conv3_3_norm_mbox_loc ) + convd2d_torch_to_lnn( fd_litenn.conv4_3_norm_mbox_conf , fd_torch.conv4_3_norm_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.conv4_3_norm_mbox_loc , fd_torch.conv4_3_norm_mbox_loc ) + convd2d_torch_to_lnn( fd_litenn.conv5_3_norm_mbox_conf , fd_torch.conv5_3_norm_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.conv5_3_norm_mbox_loc , fd_torch.conv5_3_norm_mbox_loc ) + convd2d_torch_to_lnn( fd_litenn.fc7_mbox_conf , fd_torch.fc7_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.fc7_mbox_loc , fd_torch.fc7_mbox_loc ) + convd2d_torch_to_lnn( fd_litenn.conv6_2_mbox_conf , fd_torch.conv6_2_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.conv6_2_mbox_loc , fd_torch.conv6_2_mbox_loc ) + convd2d_torch_to_lnn( fd_litenn.conv7_2_mbox_conf , fd_torch.conv7_2_mbox_conf ) + convd2d_torch_to_lnn( fd_litenn.conv7_2_mbox_loc , fd_torch.conv7_2_mbox_loc ) + + + + fd_litenn_input = lnn.Tensor_from_value ( torch_img.numpy() ) + olist_litenn = fd_litenn( fd_litenn_input ) + + abs_diff = 0 + for i in range(len(olist_torch)): + td = olist_torch[i] + ld = olist_litenn[i].np() + td = td[...,-1] + ld = ld[...,-1] + p = np.ndarray.flatten(td-ld) + diff = np.sum ( np.abs(p)) + print ("nparams=", len(p), " diff=",diff, "diff_per_param=", diff / len(p) ) + abs_diff += diff + print ("Total absolute diff = ", abs_diff) + + fd_litenn.save(r'D:\DevelopPython\Projects\litenn-apps\releases_litenn_apps\S3FD.npy') + + import code + code.interact(local=dict(globals(), **locals())) + + fd_keras_path = r"D:\DevelopPython\test\S3FD.npy" + #if Path(fd_keras_path).exists(): + # fd_keras = keras.models.load_model (fd_keras_path) + #else: + + #fd_keras = sf3d_keras(img.shape, fd_torch) + #fd_keras.save_weights (fd_keras_path) + + fd_keras = S3FD() + fd_keras.build() + + + + fd_keras.conv1_1.set_weights ( convd2d_from_torch(fd_torch.conv1_1) ) + fd_keras.conv1_2.set_weights ( convd2d_from_torch(fd_torch.conv1_2) ) + + fd_keras.conv2_1.set_weights ( convd2d_from_torch(fd_torch.conv2_1) ) + fd_keras.conv2_2.set_weights ( convd2d_from_torch(fd_torch.conv2_2) ) + + fd_keras.conv3_1.set_weights ( convd2d_from_torch(fd_torch.conv3_1) ) + fd_keras.conv3_2.set_weights ( convd2d_from_torch(fd_torch.conv3_2) ) + fd_keras.conv3_3.set_weights ( convd2d_from_torch(fd_torch.conv3_3) ) + + fd_keras.conv4_1.set_weights ( convd2d_from_torch(fd_torch.conv4_1) ) + fd_keras.conv4_2.set_weights ( convd2d_from_torch(fd_torch.conv4_2) ) + fd_keras.conv4_3.set_weights ( convd2d_from_torch(fd_torch.conv4_3) ) + + fd_keras.conv5_1.set_weights ( convd2d_from_torch(fd_torch.conv5_1) ) + fd_keras.conv5_2.set_weights ( convd2d_from_torch(fd_torch.conv5_2) ) + fd_keras.conv5_3.set_weights ( convd2d_from_torch(fd_torch.conv5_3) ) + + fd_keras.fc6.set_weights ( convd2d_from_torch(fd_torch.fc6) ) + fd_keras.fc7.set_weights ( convd2d_from_torch(fd_torch.fc7) ) + + fd_keras.conv6_1.set_weights ( convd2d_from_torch(fd_torch.conv6_1) ) + fd_keras.conv6_2.set_weights ( convd2d_from_torch(fd_torch.conv6_2) ) + + fd_keras.conv7_1.set_weights ( convd2d_from_torch(fd_torch.conv7_1) ) + fd_keras.conv7_2.set_weights ( convd2d_from_torch(fd_torch.conv7_2) ) + + fd_keras.conv3_3_norm.set_weights ( l2norm_from_torch(fd_torch.conv3_3_norm)) + fd_keras.conv4_3_norm.set_weights ( l2norm_from_torch(fd_torch.conv4_3_norm)) + fd_keras.conv5_3_norm.set_weights ( l2norm_from_torch(fd_torch.conv5_3_norm)) + + fd_keras.conv3_3_norm_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.conv3_3_norm_mbox_conf) ) + fd_keras.conv3_3_norm_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.conv3_3_norm_mbox_loc) ) + + fd_keras.conv4_3_norm_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.conv4_3_norm_mbox_conf) ) + fd_keras.conv4_3_norm_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.conv4_3_norm_mbox_loc) ) + + fd_keras.conv5_3_norm_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.conv5_3_norm_mbox_conf) ) + fd_keras.conv5_3_norm_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.conv5_3_norm_mbox_loc) ) + + fd_keras.fc7_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.fc7_mbox_conf) ) + fd_keras.fc7_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.fc7_mbox_loc) ) + + fd_keras.conv6_2_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.conv6_2_mbox_conf) ) + fd_keras.conv6_2_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.conv6_2_mbox_loc) ) + + fd_keras.conv7_2_mbox_conf.set_weights ( convd2d_from_torch(fd_torch.conv7_2_mbox_conf) ) + fd_keras.conv7_2_mbox_loc .set_weights ( convd2d_from_torch(fd_torch.conv7_2_mbox_loc) ) + + fd_keras.save_weights ( fd_keras_path ) + + import code + code.interact(local=dict(globals(), **locals())) + + + + inp = tf.placeholder(tf.float32, (None,None,None,3) ) + outp = fd_keras(inp) + + t = time.time() + olist_keras = nn.tf_sess.run (outp, feed_dict={inp: np.expand_dims(img,0)}) + print ("keras took:", time.time() - t) + + abs_diff = 0 + for i in range(len(olist_torch)): + td = np.transpose( olist_torch[i], (0,2,3,1) ) + kd = olist_keras[i] + td = td[...,-1] + kd = kd[...,-1] + p = np.ndarray.flatten(td-kd) + diff = np.sum ( np.abs(p)) + print ("nparams=", len(p), " diff=",diff, "diff_per_param=", diff / len(p) ) + abs_diff += diff + print ("Total absolute diff = ", abs_diff) + + import code + code.interact(local=dict(globals(), **locals())) + + t = time.time() + bbox_torch = detect_torch(olist_torch) + bbox_torch = bbox_torch[ nms(bbox_torch, 0.3) , :] + bbox_torch = [x for x in bbox_torch if x[-1] >= 0.5] + print ("torch took:", time.time() - t) + + t = time.time() + bbox_keras = detect_keras(olist_keras) + print ("keras took:", time.time() - t) + + #bbox_keras = bbox_keras[ nms(bbox_keras, 0.3) , :] + #bbox_keras = [x for x in bbox_keras if x[-1] >= 0.5] + + print (bbox_torch) + print (bbox_keras) + + import code + code.interact(local=dict(globals(), **locals())) + #=============================================================================== + + + + + """ + inp_t = nn.Tensor( (1,3,4,4), init=nn.InitConst(1.0) ) + kernel_t = nn.Tensor( (1,3,3,3), init=nn.InitConst(1.0) ) + + x = nn.depthwise_conv2D(inp_t, kernel_t) + x.backward(grad_for_non_trainables=True) + + print(kernel_t.get_grad().np() ) + import code + code.interact(local=dict(globals(), **locals())) + """ + #nn.devices.set_current( nn.devices.get() ) + + res = 64 + batch_size=1 + lowest_dense_res = res // (2**4) + + class Downscale(nn.Module): + def __init__(self, in_ch, out_ch, kernel_size=5 ): + self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, stride=2, padding='same') + + def forward(self, x): + x = self.conv1(x) + #x = self.bn1(x) + x = nn.leaky_relu(x, 0.1) + + return x + + class DownscaleBlock(nn.Module): + def __init__(self, in_ch, ch, n_downscales, kernel_size): + downs = [] + + last_ch = in_ch + for i in range(n_downscales): + cur_ch = ch*( min(2**i, 8) ) + downs.append ( Downscale(last_ch, cur_ch, kernel_size=kernel_size) ) + last_ch = cur_ch + self.downs = downs + + self.out_ch = last_ch + + def forward(self, x): + for down in self.downs: + x = down(x) + return x + + class Upscale(nn.Module): + def __init__(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='same') + #self.conv1 = nn.Conv2DTranspose ( in_ch, out_ch, kernel_size=kernel_size, padding='same') + + def forward(self, x): + x = nn.leaky_relu(self.conv1(x), 0.1) + x = nn.depth_to_space(x, 2) + return x + + class ResidualBlock(nn.Module): + def __init__(self, ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='same') + self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='same') + + def forward(self, inp): + x = self.conv1(inp) + x = nn.leaky_relu(x, 0.2) + x = self.conv2(x) + x = nn.leaky_relu(inp + x, 0.2) + return x + + class Encoder(nn.Module): + def __init__(self, in_ch, e_ch): + self.down1 = DownscaleBlock (in_ch, e_ch, 4, 5) + + def forward(self, x): + x = self.down1(x) + x = nn.flatten(x) + return x + + class Inter(nn.Module): + def __init__(self, in_ch, ae_ch, ae_out_ch): + self.dense1 = nn.Dense( in_ch, ae_ch )#, weight_initializer=nn.initializer.RandomUniform(-0.00000001, 0.00000001) ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res*lowest_dense_res*ae_out_ch ) + + self.upscale1 = Upscale(ae_out_ch, ae_out_ch) + self.ae_out_ch = ae_out_ch + + def forward(self, x): + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape(x, (x.shape[0], self.ae_out_ch, lowest_dense_res, lowest_dense_res) ) + x = self.upscale1(x) + return x + + class Decoder(nn.Module): + def __init__(self, in_ch, d_ch, d_mask_ch ): + self.upscale0 = Upscale(in_ch, d_ch*8, kernel_size=3) + self.upscale1 = Upscale(d_ch*8, d_ch*4, kernel_size=3) + self.upscale2 = Upscale(d_ch*4, d_ch*2, kernel_size=3) + self.out_conv = nn.Conv2D( d_ch*2, 3, kernel_size=1, padding='same') + + self.res0 = ResidualBlock(d_ch*8, kernel_size=3) + self.res1 = ResidualBlock(d_ch*4, kernel_size=3) + self.res2 = ResidualBlock(d_ch*2, kernel_size=3) + + + + self.upscalem0 = Upscale(in_ch, d_mask_ch*8, kernel_size=3) + self.upscalem1 = Upscale(d_mask_ch*8, d_mask_ch*4, kernel_size=3) + self.upscalem2 = Upscale(d_mask_ch*4, d_mask_ch*2, kernel_size=3) + self.out_convm = nn.Conv2D( d_mask_ch*2, 1, kernel_size=1, padding='same') + + def forward(self, z): + x = self.upscale0(z) + x = self.res0(x) + x = self.upscale1(x) + x = self.res1(x) + x = self.upscale2(x) + x = self.res2(x) + x = nn.sigmoid(self.out_conv(x)) + + #m = self.upscalem0(z) + #m = self.upscalem1(m) + #m = self.upscalem2(m) + #m = nn.sigmoid(self.out_convm(m)) + + return x#, m + + class DecoderSRC(Decoder): pass + class DecoderDST(Decoder): pass + + e_dims=64 + d_dims=64 + d_mask_dims=22 + + ae_dims = 256 + + enc = Encoder(3, e_dims) + + inter = Inter( enc.shallow_forward( nn.Tensor( (1,3,res,res) )).shape[1], ae_dims, ae_dims ) + + + dec_src = DecoderSRC(256, d_dims, d_mask_dims) + dec_dst = DecoderDST(256, d_dims, d_mask_dims) + + + + """ + enc.save(r'D:\enc.npy') + inter.save(r'D:\inter.npy') + dec_src.save(r'D:\dec_src.npy') + dec_dst.save(r'D:\dec_dst.npy') + """ + + + opt = nn.optimizer.Adam(enc.trainables()+inter.trainables()+dec_src.trainables()+dec_dst.trainables(), lr=5e-5, lr_dropout=0.7) + + face_type = FaceType.HALF + gen1 = SampleGeneratorFace( Path(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned'), + batch_size=batch_size, + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':'NCHW', 'resolution': res}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':'NCHW', 'resolution': res}, + ], + generators_count=2 ) + gen2 = SampleGeneratorFace( Path(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_dst\aligned'), + batch_size=batch_size, + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':'NCHW', 'resolution': res}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':face_type, 'data_format':'NCHW', 'resolution': res}, + ], + generators_count=2 ) + + iter = 0 + def train_func(): + nonlocal iter + iter += 1 + + tim = time.time() + opt.zero_grad() + warped_src_n, target_src_n = next(gen1) + warped_dst_n, target_dst_n = next(gen2) + + + + nn.devices.wait() + tim = time.time() + + warped_src_t = nn.Tensor_sliced_from_value (warped_src_n) + target_src_t = nn.Tensor_sliced_from_value (target_src_n) + warped_dst_t = nn.Tensor_sliced_from_value (warped_dst_n) + target_dst_t = nn.Tensor_sliced_from_value (target_dst_n) + + #srcm_t = nn.Tensor( (batch_size, 1, res,res) )#, init=nn.InitRandomUniform() ) + #srcm_t.fill(1.0) + #dstm_t = nn.Tensor( (batch_size, 1, res,res) )#, init=nn.InitRandomUniform() ) + #dstm_t.fill(1.0) + + + src_code_t = inter(enc(warped_src_t)) + dst_code_t = inter(enc(warped_dst_t)) + + rec_src_t = dec_src(src_code_t) + rec_dst_t = dec_dst(dst_code_t) + + loss1 = 10*nn.dssim(rec_src_t, target_src_t, max_val=1.0, filter_size=int(res/11.6)) + loss2 = nn.reduce_mean(10*nn.square(rec_src_t-target_src_t)) + + loss3 = 10*nn.dssim(rec_dst_t, target_dst_t, max_val=1.0, filter_size=int(res/11.6)) + loss4 = nn.reduce_mean( 10*nn.square(rec_dst_t-target_dst_t) ) + #loss += nn.reduce_mean( 10*nn.square(rec_dstm_t-dstm_t) ) + + #nn.Tensor.backward([loss2, loss4]) + nn.Tensor.backward([loss1, loss2, loss3, loss4]) + + #loss.backward()#grad_for_non_trainables=True) + #import code + #code.interact(local=dict(globals(), **locals())) + + #print(f'time of backward {time.time()-tim}') + + #print(dec.to_rgb.kernel.get_grad().np()) + #print(dec.to_rgb.kernel.get_grad().np()) + #print(rec_t.np()) + + #tim = time.time() + #import code + #code.interact(local=dict(globals(), **locals())) + + opt.step(multi_gpu_step= (iter % 10 == 0) ) + + nn.devices.wait() + + print(f'loss {loss2.np()}') + print(f'time {time.time()-tim}, obj count {nn.Tensor._object_count}') + + + rec_src_t = dec_src(inter(enc(target_src_t))) + rec_dst_t = dec_dst(inter(enc(target_dst_t))) + + rec_srcdst_t = dec_src(inter(enc(target_dst_t))) + + screen0 = np.transpose(warped_src_t.np()[0], (1,2,0) ) + screen1 = np.transpose(target_src_t.np()[0], (1,2,0) ) + screen2 = np.transpose(rec_src_t.np()[0], (1,2,0) ) + screen3 = np.transpose(target_dst_t.np()[0], (1,2,0) ) + screen4 = np.transpose(rec_dst_t.np()[0], (1,2,0) ) + screen5 = np.transpose(rec_srcdst_t.np()[0], (1,2,0) ) + screen = np.concatenate( [screen0, screen1, screen2, screen3, screen4, screen5], axis=1)# + screen = (screen*255).astype(np.uint8) + + cv2.imshow('', screen ) + cv2.waitKey(5) + + while True: + train_func() + + import code + code.interact(local=dict(globals(), **locals())) + + + + os.environ['PLAIDML_EXPERIMENTAL'] = 'false' + os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" + os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_rtx_2080_ti.0" + + import keras + import plaidml + import plaidml.tile + PML = plaidml + PMLK = plaidml.keras.backend + PMLTile = plaidml.tile + + t1 = PMLK.placeholder( (512,32768) ) + t2 = PMLK.placeholder( (32768,4608) ) + d = PMLK.dot(t1,t2) + f = PMLK.function( [t1,t2], [d ]) + + for i in range(100): + n1 = np.random.randint( 2**8, size=(512,32768) ).astype(np.float32) + n2 = np.random.randint( 2**8, size=(32768,4608) ).astype(np.float32) + t = time.time() + x = f([n1,n2]) + print(f'time {time.time()-t} ') + + import code + code.interact(local=dict(globals(), **locals())) + + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.BestGPU(), data_format="NCHW" ) + tf = nn.tf + tf_sess = nn.tf_sess + + class Decoder(nn.ModelBase): + def on_build(self): + self.c = nn.Conv2D( 512, 512, kernel_size=3, strides=1, padding='SAME', use_bias=False) + def forward(self, inp): + return self.c(inp) + + dec = Decoder(name='decoder') + dec.init_weights() + + #t1 = tf.get(tf.float32, (None,128,512,512) ) + t1 = tf.get_variable ("t1", (8,512,64,64), dtype=tf.float32) + loss = dec(t1) + + grads = nn.gradients (loss, dec.get_weights() ) + + + + for i in range(100): + nn.batch_set_value([( t1, np.random.randint( 2**8, size=(8,512,64,64) ).astype(np.float32))] ) + t = time.time() + #q = nn.tf_sess.run ( [ loss, ], feed_dict={t1 : np.random.randint( 2**8, size=(4,128,512,512) ).astype(np.float32)} ) + q = nn.tf_sess.run ( [ loss, grads ] ) + # + print(f'time {time.time()-t} ') + #print(q[1][0][0]) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + + #inp = tf.placeholder(tf.float32, (1,) ) + + inp = tf.get_variable ("inp", (1,), dtype=nn.floatx) + + nn.batch_set_value ( [(inp, [2.0])] ) + + #cond = tf.greater_equal(inp, 1 ) + #sel = tf.where (cond, [1], [0] ) + sel = tf.square(inp) + #gv = nn.gradients( sel , [inp]) + + #x = nn.tf_sess.run([sel], feed_dict={ inp : [-2] } ) + + #gv = nn.gradients( tf.square(x-target_t) , [triangles_t]) + + #while True: + #r, rxg= nn.tf_sess.run([x,xg*-1], feed_dict={ rd_t : rd_np, + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.BestGPU() ) + + tf = nn.tf + + + def load_pb(path_to_pb): + with tf.gfile.GFile(path_to_pb, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def, name='') + return graph + + graph = load_pb (r"D:\DevelopPython\test\giga.pb") + + sess = tf.Session(graph=graph, config=nn.tf_sess_config) + + def get_op_value(op_name, n_output=0): + return sess.run ([ graph.get_operation_by_name(op_name).outputs[n_output] ])[0].astype(np.float32) + import code + code.interact(local=dict(globals(), **locals())) + + class FaceEnhancer (nn.ModelBase): + def __init__(self, name='FaceEnhancer'): + super().__init__(name=name) + + def on_build(self): + self.conv1 = nn.Conv2D (3, 64, kernel_size=3, strides=1, padding='SAME') + + self.dense1 = nn.Dense (1, 64, use_bias=False) + self.dense2 = nn.Dense (1, 64, use_bias=False) + + self.e0_conv0 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME') + self.e0_conv1 = nn.Conv2D (64, 64, kernel_size=3, strides=1, padding='SAME') + + self.e1_conv0 = nn.Conv2D (64, 112, kernel_size=3, strides=1, padding='SAME') + self.e1_conv1 = nn.Conv2D (112, 112, kernel_size=3, strides=1, padding='SAME') + + self.e2_conv0 = nn.Conv2D (112, 192, kernel_size=3, strides=1, padding='SAME') + self.e2_conv1 = nn.Conv2D (192, 192, kernel_size=3, strides=1, padding='SAME') + + self.e3_conv0 = nn.Conv2D (192, 336, kernel_size=3, strides=1, padding='SAME') + self.e3_conv1 = nn.Conv2D (336, 336, kernel_size=3, strides=1, padding='SAME') + + self.e4_conv0 = nn.Conv2D (336, 512, kernel_size=3, strides=1, padding='SAME') + self.e4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.center_conv0 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv2 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + self.center_conv3 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d4_conv0 = nn.Conv2D (1024, 512, kernel_size=3, strides=1, padding='SAME') + self.d4_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d3_conv0 = nn.Conv2D (848, 512, kernel_size=3, strides=1, padding='SAME') + self.d3_conv1 = nn.Conv2D (512, 512, kernel_size=3, strides=1, padding='SAME') + + self.d2_conv0 = nn.Conv2D (704, 288, kernel_size=3, strides=1, padding='SAME') + self.d2_conv1 = nn.Conv2D (288, 288, kernel_size=3, strides=1, padding='SAME') + + self.d1_conv0 = nn.Conv2D (400, 160, kernel_size=3, strides=1, padding='SAME') + self.d1_conv1 = nn.Conv2D (160, 160, kernel_size=3, strides=1, padding='SAME') + + self.d0_conv0 = nn.Conv2D (224, 96, kernel_size=3, strides=1, padding='SAME') + self.d0_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + + self.out1x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME') + self.out1x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME') + + self.dec2x_conv0 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + self.dec2x_conv1 = nn.Conv2D (96, 96, kernel_size=3, strides=1, padding='SAME') + + self.out2x_conv0 = nn.Conv2D (96, 48, kernel_size=3, strides=1, padding='SAME') + self.out2x_conv1 = nn.Conv2D (48, 3, kernel_size=3, strides=1, padding='SAME') + + self.dec4x_conv0 = nn.Conv2D (96, 72, kernel_size=3, strides=1, padding='SAME') + self.dec4x_conv1 = nn.Conv2D (72, 72, kernel_size=3, strides=1, padding='SAME') + + self.out4x_conv0 = nn.Conv2D (72, 36, kernel_size=3, strides=1, padding='SAME') + self.out4x_conv1 = nn.Conv2D (36, 3 , kernel_size=3, strides=1, padding='SAME') + + def forward(self, inp): + bgr, param, param1 = inp + + x = self.conv1(bgr) + a = self.dense1(param) + a = tf.reshape(a, (-1,1,1,64) ) + + b = self.dense2(param1) + b = tf.reshape(b, (-1,1,1,64) ) + + x = tf.nn.leaky_relu(x+a+b, 0.1) + + x = tf.nn.leaky_relu(self.e0_conv0(x), 0.1) + x = e0 = tf.nn.leaky_relu(self.e0_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e1_conv0(x), 0.1) + x = e1 = tf.nn.leaky_relu(self.e1_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e2_conv0(x), 0.1) + x = e2 = tf.nn.leaky_relu(self.e2_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e3_conv0(x), 0.1) + x = e3 = tf.nn.leaky_relu(self.e3_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.e4_conv0(x), 0.1) + x = e4 = tf.nn.leaky_relu(self.e4_conv1(x), 0.1) + + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], "VALID") + x = tf.nn.leaky_relu(self.center_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv1(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv2(x), 0.1) + x = tf.nn.leaky_relu(self.center_conv3(x), 0.1) + + x = tf.concat( [nn.tf_upsample2d_bilinear(x), e4], -1 ) + x = tf.nn.leaky_relu(self.d4_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d4_conv1(x), 0.1) + + x = tf.concat( [nn.tf_upsample2d_bilinear(x), e3], -1 ) + x = tf.nn.leaky_relu(self.d3_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d3_conv1(x), 0.1) + + x = tf.concat( [nn.tf_upsample2d_bilinear(x), e2], -1 ) + x = tf.nn.leaky_relu(self.d2_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d2_conv1(x), 0.1) + + x = tf.concat( [nn.tf_upsample2d_bilinear(x), e1], -1 ) + x = tf.nn.leaky_relu(self.d1_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.d1_conv1(x), 0.1) + + x = tf.concat( [nn.tf_upsample2d_bilinear(x), e0], -1 ) + x = tf.nn.leaky_relu(self.d0_conv0(x), 0.1) + x = d0 = tf.nn.leaky_relu(self.d0_conv1(x), 0.1) + + x = tf.nn.leaky_relu(self.out1x_conv0(x), 0.1) + x = self.out1x_conv1(x) + out1x = bgr + tf.nn.tanh(x) + + x = d0 + x = tf.nn.leaky_relu(self.dec2x_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.dec2x_conv1(x), 0.1) + x = d2x = nn.tf_upsample2d_bilinear(x) + + x = tf.nn.leaky_relu(self.out2x_conv0(x), 0.1) + x = self.out2x_conv1(x) + + out2x = nn.tf_upsample2d_bilinear(out1x) + tf.nn.tanh(x) + + x = d2x + x = tf.nn.leaky_relu(self.dec4x_conv0(x), 0.1) + x = tf.nn.leaky_relu(self.dec4x_conv1(x), 0.1) + x = d4x = nn.tf_upsample2d_bilinear(x) + + x = tf.nn.leaky_relu(self.out4x_conv0(x), 0.1) + x = self.out4x_conv1(x) + + out4x = nn.tf_upsample2d_bilinear(out2x) + tf.nn.tanh(x) + + return out4x + + + with tf.device ("/CPU:0"): + face_enhancer = FaceEnhancer() + + if True: + face_enhancer.load_weights (r"D:\DevelopPython\test\FaceEnhancer.npy") + + face_enhancer.save_weights (r"D:\DevelopPython\test\FaceEnhancer.npy", np.float16) + face_enhancer.load_weights (r"D:\DevelopPython\test\FaceEnhancer.npy") + else: + face_enhancer.build() + face_enhancer.conv1.set_weights( [get_op_value('tl_unet1x2x4x/paramW'), get_op_value('tl_unet1x2x4x/paramB')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.dense1.set_weights( [get_op_value('tl_unet1x2x4x/paramInW')] ) + face_enhancer.dense2.set_weights( [get_op_value('tl_unet1x2x4x/paramInW1')] ) + + face_enhancer.e0_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_0/w0'), get_op_value('tl_unet1x2x4x/Encoder_0/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e0_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_0/w1'), get_op_value('tl_unet1x2x4x/Encoder_0/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e1_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_1/w0'), get_op_value('tl_unet1x2x4x/Encoder_1/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e1_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_1/w1'), get_op_value('tl_unet1x2x4x/Encoder_1/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e2_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_2/w0'), get_op_value('tl_unet1x2x4x/Encoder_2/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e2_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_2/w1'), get_op_value('tl_unet1x2x4x/Encoder_2/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e3_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_3/w0'), get_op_value('tl_unet1x2x4x/Encoder_3/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e3_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_3/w1'), get_op_value('tl_unet1x2x4x/Encoder_3/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e4_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_4/w0'), get_op_value('tl_unet1x2x4x/Encoder_4/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.e4_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Encoder_4/w1'), get_op_value('tl_unet1x2x4x/Encoder_4/b1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.center_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Center/w0'), get_op_value('tl_unet1x2x4x/Center/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.center_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Center/w1'), get_op_value('tl_unet1x2x4x/Center/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.center_conv2.set_weights( [get_op_value('tl_unet1x2x4x/Center/w2'), get_op_value('tl_unet1x2x4x/Center/b2')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.center_conv3.set_weights( [get_op_value('tl_unet1x2x4x/Center/w3'), get_op_value('tl_unet1x2x4x/Center/b3')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.d4_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_4/w0'), get_op_value('tl_unet1x2x4x/Decoder_4/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d4_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_4/w1'), get_op_value('tl_unet1x2x4x/Decoder_4/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d3_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_3/w0'), get_op_value('tl_unet1x2x4x/Decoder_3/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d3_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_3/w1'), get_op_value('tl_unet1x2x4x/Decoder_3/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d2_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_2/w0'), get_op_value('tl_unet1x2x4x/Decoder_2/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d2_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_2/w1'), get_op_value('tl_unet1x2x4x/Decoder_2/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d1_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_1/w0'), get_op_value('tl_unet1x2x4x/Decoder_1/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d1_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_1/w1'), get_op_value('tl_unet1x2x4x/Decoder_1/b1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d0_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_0/w0'), get_op_value('tl_unet1x2x4x/Decoder_0/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.d0_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_0/w1'), get_op_value('tl_unet1x2x4x/Decoder_0/b1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.out1x_conv0.set_weights( [get_op_value('tl_unet1x2x4x/out1x/W0'), get_op_value('tl_unet1x2x4x/out1x/B0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.out1x_conv1.set_weights( [get_op_value('tl_unet1x2x4x/out1x/W1'), get_op_value('tl_unet1x2x4x/out1x/B1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.dec2x_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_2x/w0'), get_op_value('tl_unet1x2x4x/Decoder_2x/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.dec2x_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_2x/w1'), get_op_value('tl_unet1x2x4x/Decoder_2x/b1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.out2x_conv0.set_weights( [get_op_value('tl_unet1x2x4x/out2x/W0'), get_op_value('tl_unet1x2x4x/out2x/B0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.out2x_conv1.set_weights( [get_op_value('tl_unet1x2x4x/out2x/W1'), get_op_value('tl_unet1x2x4x/out2x/B1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.dec4x_conv0.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_4x/w0'), get_op_value('tl_unet1x2x4x/Decoder_4x/b0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.dec4x_conv1.set_weights( [get_op_value('tl_unet1x2x4x/Decoder_4x/w1'), get_op_value('tl_unet1x2x4x/Decoder_4x/b1')[0].reshape( (1,1,1,-1)) ] ) + + face_enhancer.out4x_conv0.set_weights( [get_op_value('tl_unet1x2x4x/out4x/W0'), get_op_value('tl_unet1x2x4x/out4x/B0')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.out4x_conv1.set_weights( [get_op_value('tl_unet1x2x4x/out4x/W1'), get_op_value('tl_unet1x2x4x/out4x/B1')[0].reshape( (1,1,1,-1)) ] ) + face_enhancer.save_weights (r"D:\DevelopPython\test\FaceEnhancer.npy") + + + #import code + #code.interact(local=dict(globals(), **locals())) + + + """ + x = Conv2D (64, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/paramW'), get_op_value('tl_unet1x2x4x/paramB')[0] ] )(bgr_inp) + + a = Dense (64, use_bias=False, weights=[get_op_value('tl_unet1x2x4x/paramInW')] ) ( t_param_inp ) + a = Reshape( (1,1,64) )(a) + b = Dense (64, use_bias=False, weights=[get_op_value('tl_unet1x2x4x/paramInW1')] ) ( t_param1_inp ) + b = Reshape( (1,1,64) )(b) + x = Add()([x,a,b]) + + x = LeakyReLU(0.1)(x) + + x = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_0/w0'), get_op_value('tl_unet1x2x4x/Encoder_0/b0')[0] ] )(x)) + x = e0 = LeakyReLU(0.1)(Conv2D (64, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_0/w1'), get_op_value('tl_unet1x2x4x/Encoder_0/b1')[0] ] )(x)) + + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_1/w0'), get_op_value('tl_unet1x2x4x/Encoder_1/b0')[0] ] )(x)) + x = e1 = LeakyReLU(0.1)(Conv2D (112, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_1/w1'), get_op_value('tl_unet1x2x4x/Encoder_1/b1')[0] ] )(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_2/w0'), get_op_value('tl_unet1x2x4x/Encoder_2/b0')[0] ] )(x)) + x = e2 = LeakyReLU(0.1)(Conv2D (192, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_2/w1'), get_op_value('tl_unet1x2x4x/Encoder_2/b1')[0] ] )(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_3/w0'), get_op_value('tl_unet1x2x4x/Encoder_3/b0')[0] ] )(x)) + x = e3 = LeakyReLU(0.1)(Conv2D (336, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_3/w1'), get_op_value('tl_unet1x2x4x/Encoder_3/b1')[0] ] )(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_4/w0'), get_op_value('tl_unet1x2x4x/Encoder_4/b0')[0] ] )(x)) + x = e4 = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Encoder_4/w1'), get_op_value('tl_unet1x2x4x/Encoder_4/b1')[0] ] )(x)) + + x = AveragePooling2D()(x) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Center/w0'), get_op_value('tl_unet1x2x4x/Center/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Center/w1'), get_op_value('tl_unet1x2x4x/Center/b1')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Center/w2'), get_op_value('tl_unet1x2x4x/Center/b2')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Center/w3'), get_op_value('tl_unet1x2x4x/Center/b3')[0] ] )(x)) + + + x = Concatenate()([ BilinearInterpolation()(x), e4 ]) + + + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_4/w0'), get_op_value('tl_unet1x2x4x/Decoder_4/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_4/w1'), get_op_value('tl_unet1x2x4x/Decoder_4/b1')[0] ] )(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e3 ]) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_3/w0'), get_op_value('tl_unet1x2x4x/Decoder_3/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (512, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_3/w1'), get_op_value('tl_unet1x2x4x/Decoder_3/b1')[0] ] )(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e2 ]) + x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_2/w0'), get_op_value('tl_unet1x2x4x/Decoder_2/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (288, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_2/w1'), get_op_value('tl_unet1x2x4x/Decoder_2/b1')[0] ] )(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e1 ]) + x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_1/w0'), get_op_value('tl_unet1x2x4x/Decoder_1/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (160, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_1/w1'), get_op_value('tl_unet1x2x4x/Decoder_1/b1')[0] ] )(x)) + + x = Concatenate()([ BilinearInterpolation()(x), e0 ]) + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_0/w0'), get_op_value('tl_unet1x2x4x/Decoder_0/b0')[0] ] )(x)) + x = d0 = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_0/w1'), get_op_value('tl_unet1x2x4x/Decoder_0/b1')[0] ] )(x)) + + x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/out1x/W0'), get_op_value('tl_unet1x2x4x/out1x/B0')[0] ] )(x)) + + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh', weights=[get_op_value('tl_unet1x2x4x/out1x/W1'), get_op_value('tl_unet1x2x4x/out1x/B1')[0] ] )(x) + out1x = Add()([bgr_inp, x]) + + x = d0 + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_2x/w0'), get_op_value('tl_unet1x2x4x/Decoder_2x/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (96, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_2x/w1'), get_op_value('tl_unet1x2x4x/Decoder_2x/b1')[0] ] )(x)) + x = d2x = BilinearInterpolation()(x) + + x = LeakyReLU(0.1)(Conv2D (48, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/out2x/W0'), get_op_value('tl_unet1x2x4x/out2x/B0')[0] ] )(x)) + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh', weights=[get_op_value('tl_unet1x2x4x/out2x/W1'), get_op_value('tl_unet1x2x4x/out2x/B1')[0] ] )(x) + + out2x = Add()([BilinearInterpolation()(out1x), x]) + + x = d2x + x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_4x/w0'), get_op_value('tl_unet1x2x4x/Decoder_4x/b0')[0] ] )(x)) + x = LeakyReLU(0.1)(Conv2D (72, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/Decoder_4x/w1'), get_op_value('tl_unet1x2x4x/Decoder_4x/b1')[0] ] )(x)) + x = d4x = BilinearInterpolation()(x) + + x = LeakyReLU(0.1)(Conv2D (36, 3, strides=1, padding='same', weights=[get_op_value('tl_unet1x2x4x/out4x/W0'), get_op_value('tl_unet1x2x4x/out4x/B0')[0] ] )(x)) + x = Conv2D (3, 3, strides=1, padding='same', activation='tanh', weights=[get_op_value('tl_unet1x2x4x/out4x/W1'), get_op_value('tl_unet1x2x4x/out4x/B1')[0] ] )(x) + out4x = Add()([BilinearInterpolation()(out2x), x ]) + """ + + #model = keras.models.Model ( [bgr_inp,t_param_inp,t_param1_inp], [out4x] ) + #model.load_weights (r"D:\DevelopPython\test\Jiva.h5") + + #weights_filepath = Path(r"D:\DevelopPython\test\FaceEnhancer.npy") + #model.save_weights (str(weights_filepath)) + + #weights_filepath + #import code + #code.interact(local=dict(globals(), **locals())) + + + """ + + + param = np.array([0.2]) + param1 = np.array([1.0]) + + up_res = 4 + patch_size = 192 + patch_size_half = patch_size // 2 + + #inp_img = border_pad(inp_img, patch_size_half) + h,w,c = inp_img.shape + + i_max = w-patch_size+1 + j_max = h-patch_size+1 + + final_img = np.zeros ( (h*up_res,w*up_res,c), dtype=np.float32 ) + final_img_div = np.zeros ( (h*up_res,w*up_res,1), dtype=np.float32 ) + + + x = np.concatenate ( [ np.linspace (0,1,patch_size_half*up_res), np.linspace (1,0,patch_size_half*up_res) ] ) + x,y = np.meshgrid(x,x) + patch_mask = (x*y)[...,None] + + j=0 + while j < j_max: + i = 0 + while i < i_max: + is_last = i == i_max-1 + + patch_img = inp_img[j:j+patch_size, i:i+patch_size,:] + + x = model.predict( [ patch_img[None,...], param, param1 ] )[0] + + final_img [j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += x*patch_mask + final_img_div[j*up_res:(j+patch_size)*up_res, i*up_res:(i+patch_size)*up_res,:] += patch_mask + + if is_last: + break + + i = min( i+patch_size_half, i_max-1) + + if j == j_max-1: + break + j = min( j+patch_size_half, j_max-1) + + final_img_div[final_img_div==0] = 1.0 + final_img /= final_img_div + + cv2.imshow("", ( np.clip( (final_img/2+0.5)*255, 0, 255) ).astype(np.uint8) ) + cv2.waitKey(0) + """ + + + + """ + def border_pad(x, pad): + + x = np.concatenate ([ np.tile(x[:,0:1,:], (1,pad,1) ), + x, + np.tile(x[:,-2:-1,:], (1,pad,1) ) ], axis=1 ) + + x = np.concatenate ([ np.tile(x[0:1,:,:], (pad,1,1) ), + x, + np.tile(x[-2:-1,:,:], (pad,1,1) ) ], axis=0 ) + return x + def reflect_pad(x, pad): + x = np.concatenate ([ x[:,pad:0:-1,:], + x, + x[:,-2:-pad-2:-1,:] ], axis=1 ) + x = np.concatenate ([ x[pad:0:-1,:,:], + x, + x[-2:-pad-2:-1,:,:] ], axis=0 ) + return x + + + psnr1 = K.placeholder( (None,None,None)) + psnr2 = K.placeholder( (None,None,None)) + psnr_func = K.function([psnr1, psnr2], [tf.image.psnr (psnr1, psnr2, max_val=2.0)]) + + j=0 + while j < j_max: + i = 0 + while i < i_max: + is_first = i == 0 + is_last = i == i_max-1 + + pr=[] + psnrs=[] + + mod = 1 if is_first else -1 + + for n in range(n_psnr_patches): + + patch_img = inp_img[j:j+192, i+n*mod:i+n*mod+192,:] + bilinear_patch_img = bilinear_img[j*4:(j+192)*4, (i+n*mod)*4:(i+n*mod+192)*4,:] + + x = model.predict( [ patch_img[None,...], param, param1 ] )[0] + pr += [ x ] + psnrs += [ psnr_func ( [x, bilinear_patch_img ])[0] ] + + final_img[j*4:(j+192)*4, i*4:(i+192)*4,:] = pr[0] + + best_n = np.argmin(np.array(psnrs) ) + if best_n != 0: + final_img[j*4:(j+192)*4, (i+best_n*mod)*4:(i+best_n*mod+192)*4,:] = pr[best_n] + + if is_last: + break + + i = min( best_n+192, i_max-1) + + if j == j_max-1: + break + j = min( j+192, j_max-1) + """ + #patch_img = inp_img[j:j+192, i:i+192,:] + + #final_img[j*4:j*4+192*4, i*4:i*4+192*4,:] = img + + #x = model.predict( [ patch_img[None,...], param, param1 ] ) + + + + #cv2.imshow("", ( np.clip( (img/2+0.5)*255, 0, 255) ).astype(np.uint8) ) + #cv2.waitKey(0) + + #blur = cv2.GaussianBlur(x, (3, 3), 0) + #x = cv2.addWeighted(x, 1.0 + (0.5 * amount), blur, -(0.5 * amount), 0) + #cv2.filter2D(x, -1, kernel) + + #final_img [j*4:j*4+192*4, i*4:i*4+192*4,:] += img#np.clip(x/2+0.5,0, 1) + #final_img_div[j*4:j*4+192*4, i*4:i*4+192*4,:] += 1.0 + + #import code + #code.interact(local=dict(globals(), **locals())) + + + input = graph.get_tensor_by_name('netInput:0') + t_param = graph.get_tensor_by_name('t_param:0') + t_param1 = graph.get_tensor_by_name('t_param1:0') + + + filepath = r'D:\DevelopPython\test\00000.jpg' + img = cv2.imread(filepath).astype(np.float32) / 255.0 + inp_img = img *2 - 1 + inp_img = cv2.resize (inp_img, (192,192) ) + + """ + with tf.device ("/CPU:0"): + face_enhancer = FaceEnhancer(name=f'fe') + face_enhancer.load_weights (r"D:\DevelopPython\test\FaceEnhancer.npy") + face_enhancer.build_for_run ([ (tf.float32, (192,192,3) ), + (tf.float32, (1,) ), + (tf.float32, (1,) ), + ]) + """ + #writer = tf.summary.FileWriter(r'D:\logs', sess.graph) + + + + face_enhancer.build_for_run ([ (tf.float32, (192,192,3) ), + (tf.float32, (1,) ), + (tf.float32, (1,) ), + ]) + param = 0.2 + param1 = 1.0 + inp_x = 0 + while True: + #inp_img = img[-192:,inp_x:inp_x+192,:] + #inp_img = img[inp_x:inp_x+192,-192:,:] + #inp_img = img[-192:,-192:,:] + + #output = graph.get_tensor_by_name('tl_unet1x2x4x/out1x/Tanh:0') + output = graph.get_tensor_by_name('netOutput4X:0') + #output = graph.get_tensor_by_name('tl_unet1x2x4x/Conv2D:0') + + x1 = sess.run (output, feed_dict={input: inp_img[None,...], + t_param: np.array([param]), + t_param1: np.array([param1]) } ) + + #x = face_enhancer_predict( inp_img[None,...], np.array([[param]]), np.array([[param1]]) ) + + + x = face_enhancer.run ([ inp_img[None,...], np.array([[param]]), np.array([[param1]]) ]) + + print (f"diff = {np.sum(np.abs(x1-x))}") + import code + code.interact(local=dict(globals(), **locals())) + + + + + x1 = np.clip( x1/2 + 0.5, 0, 1) + cv2.imshow("", (x1[0]*255).astype(np.uint8) ) + cv2.waitKey(0) + + x = np.clip( x/2 + 0.5, 0, 1) + cv2.imshow("", (x[0]*255).astype(np.uint8) ) + cv2.waitKey(0) + + #param += 0.1 + #inp_x += 1 + + + #[n.name for n in tf.get_default_graph().as_graph_def().node] + + ct_1_filepath = r'D:\DevelopPython\test\00000.jpg'#r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_dst\aligned\00658_0.jpg' + ct_1_img = cv2.imread(ct_1_filepath).astype(np.float32) / 255.0 + ct_1_img_shape = ct_1_img.shape + ct_1_dflimg = DFLJPG.load ( ct_1_filepath) + + ct_1_mask = LandmarksProcessor.get_image_hull_mask (ct_1_img_shape , ct_1_dflimg.get_landmarks() ) + + img_size = 128 + face_mat = LandmarksProcessor.get_transform_mat( ct_1_dflimg.get_landmarks(), img_size, FaceType.FULL, scale=1.0) + wrp = cv2.warpAffine(ct_1_img, face_mat, (img_size, img_size), cv2.INTER_LANCZOS4) + + + cv2.imshow("", (wrp*255).astype(np.uint8) ) + cv2.waitKey(0) + + #===================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( data_format="NCHW", device_config=nn.DeviceConfig.BestGPU() ) + print('starting') + tf = nn.tf + tf_sess = nn.tf_sess + + res = 128 + lowest_dense_res = 128 // (2**4) + + class Encoder(nn.ModelBase): + + def on_build(self, *args, **kwargs ): + self.conv1 = nn.Conv2D(3,64, kernel_size=5, strides=2, padding='SAME') + self.conv2 = nn.Conv2D(64,128, kernel_size=5, strides=2, padding='SAME') + self.conv3 = nn.Conv2D(128,256, kernel_size=5, strides=2, padding='SAME') + self.conv4 = nn.Conv2D(256,512, kernel_size=5, strides=2, padding='SAME') + + self.dense1 = nn.Dense( lowest_dense_res*lowest_dense_res*512, 256 ) + self.dense2 = nn.Dense( 256, lowest_dense_res*lowest_dense_res*512 ) + + self.upconv4 = nn.Conv2DTranspose( 512, 512, 3) + def forward(self, x): + x = self.conv1(x) + return x + x = tf.nn.leaky_relu(x) + + x = self.conv2(x) + + x = tf.nn.leaky_relu(x) + + x = self.conv3(x) + x = tf.nn.leaky_relu(x) + x = self.conv4(x) + x = tf.nn.leaky_relu(x) + + x = nn.flatten(x) + + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, 512) + + x = tf.nn.leaky_relu(self.upconv4(x)) + return x + + + enc = Encoder(1,1, name='asd') + enc.init_weights() + + input_t = tf.placeholder(tf.float32, (None,3,res,res) ) + rec_t = enc(input_t) + + print( nn.tf_get_value( enc.conv1.weight) ) + input_n = np.ones( (1,3,res,res), dtype=np.float32) * 0.5 + q = nn.tf_sess.run ( [rec_t], feed_dict={input_t:input_n} ) + #print(q) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( data_format="NHWC", device_config=nn.DeviceConfig.BestGPU() ) + print('starting') + tf = nn.tf + tf_sess = nn.tf_sess + + + t1 = tf.get_variable ("t1", (1,4,4,1), dtype=tf.float32) + loss = tf.nn.avg_pool(t1, [1,3,3,1], [1,2,2,1], 'SAME', data_format=nn.data_format) + + grads = nn.gradients (loss, [t1] ) + + t1_val = np.ones( t1.shape, dtype=np.float32) + + + nn.batch_set_value([( t1,t1_val)] ) + + q = nn.tf_sess.run ( [ loss, grads ] ) + + print( np.transpose(q[1][0][0], (0,3,1,2)) ) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + t1 = tf.get_variable ("t1", (2,2,4,), dtype=tf.float32) + loss = tf.reduce_sum(t1, (0,2,) ) + + grads = nn.gradients (loss, [t1] ) + + t1_val = np.array( + [ + [ [1,1,1,1], + [1,1,1,2] ], + [ [1,1,2,2], + [1,2,2,2] ], + + ]).astype(np.float32) + + + #w = t1_val.transpose( (1,0,2))#.reshape( (2,8) ) + + nn.batch_set_value([( t1,t1_val)] ) + + q = nn.tf_sess.run ( [ loss, grads ] ) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + + class Decoder(nn.ModelBase): + def on_build(self): + self.d1 = nn.Dense( 65536,16384, use_bias=False) + self.d2 = nn.Dense( 16384,65536, use_bias=False) + def forward(self, inp): + x = inp + x = self.d1(x) + x = self.d2(x) + return x + + dec = Decoder(name='decoder') + dec.init_weights() + + t1 = tf.get_variable ("t1", (4,65536), dtype=tf.float32) + #t1 = tf.placeholder(tf.float32, (None,65536) ) + loss = dec(t1) + + grads = nn.gradients (loss, dec.get_weights() ) + + for i in range(100): + nn.batch_set_value([( t1, np.random.randint( 2**8, size=(4,65536) ).astype(np.float32))] ) + + t = time.time() + q = nn.tf_sess.run ( [ loss, ] ) + + print(f'time {time.time()-t} ') + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + + + + + class Decoder(nn.ModelBase): + def on_build(self): + self.c = nn.Conv2DTranspose( 128, 128, kernel_size=3, strides=2, padding='SAME', use_bias=False) + def forward(self, inp): + return self.c(inp) + + dec = Decoder(name='decoder') + dec.init_weights() + + t1 = tf.placeholder(tf.float32, (None,64,64,128) ) + loss = dec(t1) + + grads = nn.gradients (loss, dec.get_weights() ) + + for i in range(100): + t = time.time() + q = nn.tf_sess.run ( [ loss, ], feed_dict={t1 : np.random.randint( 2**8, size=(4,64,64,128) ).astype(np.float32)} ) + print(f'time {time.time()-t} ') + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + + + + + + + + + find_archi( 448 // 16) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + img_path = Path(r'D:\DevelopPython\test\linus0.jpg') + dflimg = DFLIMG.load(img_path) + img = cv2_imread(img_path) + h,w,c = img.shape + + + mask = dflimg.get_xseg_mask() + mask = cv2.resize(mask, (w,h), interpolation=cv2.INTER_CUBIC ) + mask = np.clip(mask, 0, 1) + + inp = tf.placeholder(tf.float32, (1,h,w,1) ) + + x_blurred = nn.gaussian_blur(inp, max(1, w // 32) ) + + x = nn.tf_sess.run(x_blurred, feed_dict={ inp : mask[None,...,None].astype(np.float32) } ) + x = x[0] + #import code + #code.interact(local=dict(globals(), **locals())) + while True: + #cv2.imshow("",img ) + #cv2.waitKey(0) + cv2.imshow("", (mask*255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow("", np.clip( x*255, 0, 255).astype(np.uint8) ) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + img_path = Path(r'F:\DeepFaceLabCUDA9.2SSE\workspace инауг илон гол\data_dst\aligned\00302_0.jpg') + dflimg = DFLIMG.load(img_path) + img = cv2_imread(img_path) + h,w,c = img.shape + + + mask = dflimg.get_xseg_mask() + mask = cv2.resize(mask, (w,h), cv2.INTER_CUBIC )[...,None] + + cnts = cv2.findContours(mask.astype(np.uint8), cv2.RETR_LIST , cv2.CHAIN_APPROX_TC89_KCOS ) + + # Get the largest found contour + + cnt = sorted(cnts[0], key = cv2.contourArea, reverse = True)[0].squeeze() + #import code + #code.interact(local=dict(globals(), **locals())) + #screen = np.zeros_like( mask, np.uint8 ) + #for x,y in cnt: + # cv2.circle(screen, (x,y), 1, (255,) ) + # + #while True: + # cv2.imshow("", (mask*255).astype(np.uint8) ) + # cv2.waitKey(0) + # cv2.imshow("", screen) + # cv2.waitKey(0) + + #import code + #code.interact(local=dict(globals(), **locals())) + + center = np.mean(cnt,0) + + cnt2 = cnt.copy().astype(np.float32) + + cnt2_c = center - cnt2 + cnt2_len = npla.norm(cnt2_c, axis=1, keepdims=True) + cnt2_vec = cnt2_c / cnt2_len + + l,t = cnt.min(0) + r,b = cnt.max(0) + c = np.mean(cnt,0) + cx, cy = c + + circle_rad = max( cy-t, b-cy, cx-l, r-cx ) + pts_count = 30 + + circle_pts = c + circle_rad*np.array( [ [np.sin(i*2*math.pi/pts_count ),np.cos(i*2*math.pi/pts_count ) ] for i in range(pts_count) ] ) + circle_pts = circle_pts.astype(np.int32) + + circle_pts2 = c + circle_rad*0.9*np.array( [ [np.sin(i*2*math.pi/pts_count ),np.cos(i*2*math.pi/pts_count ) ] for i in range(pts_count) ] ) + circle_pts2 = circle_pts2.astype(np.int32) + + # Anchor perimeter + pts_count = 120 + perim_pts = np.concatenate ( (np.concatenate ( [ np.arange(0,w+w/pts_count, w/pts_count)[...,None], np.array ( [[0]]*(pts_count+1) ) ], axis=-1 ), + np.concatenate ( [ np.arange(0,w+w/pts_count, w/pts_count)[...,None], np.array ( [[h]]*(pts_count+1) ) ], axis=-1 ), + np.concatenate ( [ np.array ( [[0]]*(pts_count+1) ), np.arange(0,h+h/pts_count, h/pts_count)[...,None] ], axis=-1 ), + np.concatenate ( [ np.array ( [[w]]*(pts_count+1) ), np.arange(0,h+h/pts_count, h/pts_count)[...,None] ], axis=-1 ) ), 0 ).astype(np.int32) + + + cnt2 += cnt2_vec * cnt2_len * 0.25 + cnt2 = cnt2.astype(np.int32) + cnt2 = np.concatenate ( (cnt2, perim_pts), 0 ) + cnt = np.concatenate ( (cnt, perim_pts), 0 ) + #for x,y in np.concatenate( [circle_pts, circle_pts2], 0 ): + screen = np.zeros_like( mask, np.uint8 ) + for x,y in np.concatenate( [cnt,cnt2], 0 ): + cv2.circle(screen, (x,y), 1, (255,) ) + + + + cv2.imshow("", screen) + cv2.waitKey(0) + #import code + #code.interact(local=dict(globals(), **locals())) + + + #new_img = mls_rigid_deformation_inv( img, circle_pts, circle_pts2 ) + new_img = mls_rigid_deformation_inv( img, cnt, cnt2, density=0.5 ) + #new_img = mls_similarity_deformation_inv( img, cnt, cnt2 ) + + + while True: + cv2.imshow("", img) + cv2.waitKey(0) + cv2.imshow("", new_img) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + layers = [ + [3,2], + [3,2], + [3,2], + [3,2], + ] + rf = calc_receptive_field_size(layers) + + x = find_archi(28) + print(x) + + + import code + code.interact(local=dict(globals(), **locals())) + + + cv2.imshow("", np.zeros( (256,256,3), dtype=np.uint8 ) ) + + while True: + ord_key = cv2.waitKeyEx(0) + if ord_key > 0: + print(f"ord_key {ord_key}") + + import code + code.interact(local=dict(globals(), **locals())) + + import code + code.interact(local=dict(globals(), **locals())) + + #============================================ + + + + image = cv2.imread(r'D:\DevelopPython\test\inpaint1.jpg') + mask = cv2.imread(r'D:\DevelopPython\test\inpaint1_mask.jpg')[:,:,0] + mask[mask > 0] = 255 + #import code + #code.interact(local=dict(globals(), **locals())) + #a = inpaint(image, mask, 9) + + a = Inpainter(image, mask).inpaint() + + cv2.imshow ("", np.clip(a, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + #======================================= + + + + image_paths = pathex.get_image_paths(r"E:\FakeFaceVideoSources\Datasets\CelebA\aligned_def\aligned") + image_paths_len = len(image_paths) + + while True: + + src1 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + src2 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + src3 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + + dst1 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + dst2 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + dst3 = cv2_imread(image_paths[np.random.randint(image_paths_len)]).astype(np.float32) / 255.0 + + while True: + t = time.time() + sot = imagelib.color_transfer_sot (src1, dst1, batch_size=30 ) + print(f'time took:{time.time()-t}') + + screen = np.concatenate([src1,dst1,sot], axis=0) + + + cv2.imshow ("", np.clip(screen*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + + import code + code.interact(local=dict(globals(), **locals())) + #======================================= + + img_path = Path(r'F:\DeepFaceLabCUDA9.2SSE\workspace ИНАУГ ГОЛ\data_dst\aligned\00001_0.jpg') + dflimg = DFLIMG.load(img_path) + img = cv2_imread(img_path) + h,w,c = img.shape + + + + + mask = dflimg.get_xseg_mask() + mask = cv2.resize(mask, (w,h), cv2.INTER_CUBIC )[...,None] + + cnts = cv2.findContours(mask.astype(np.uint8), cv2.RETR_LIST , cv2.CHAIN_APPROX_TC89_KCOS ) + + # Get the largest found contour + cnt = sorted(cnts[0], key = cv2.contourArea, reverse = True)[0].squeeze() + + center = np.mean(cnt,0) + + cnt2 = cnt.copy().astype(np.float32) + + cnt2_c = center - cnt2 + cnt2_len = npla.norm(cnt2_c, axis=1, keepdims=True) + cnt2_vec = cnt2_c / cnt2_len + + l,t = cnt.min(0) + r,b = cnt.max(0) + c = np.mean(cnt,0) + cx, cy = c + + circle_rad = max( cy-t, b-cy, cx-l, r-cx ) + pts_count = 30 + + circle_pts = c + circle_rad*np.array( [ [np.sin(i*2*math.pi/pts_count ),np.cos(i*2*math.pi/pts_count ) ] for i in range(pts_count) ] ) + circle_pts = circle_pts.astype(np.int32) + + circle_pts2 = c + circle_rad*0.9*np.array( [ [np.sin(i*2*math.pi/pts_count ),np.cos(i*2*math.pi/pts_count ) ] for i in range(pts_count) ] ) + circle_pts2 = circle_pts2.astype(np.int32) + + # Anchor perimeter + pts_count = 120 + perim_pts = np.concatenate ( (np.concatenate ( [ np.arange(0,w+w/pts_count, w/pts_count)[...,None], np.array ( [[0]]*(pts_count+1) ) ], axis=-1 ), + np.concatenate ( [ np.arange(0,w+w/pts_count, w/pts_count)[...,None], np.array ( [[h]]*(pts_count+1) ) ], axis=-1 ), + np.concatenate ( [ np.array ( [[0]]*(pts_count+1) ), np.arange(0,h+h/pts_count, h/pts_count)[...,None] ], axis=-1 ), + np.concatenate ( [ np.array ( [[w]]*(pts_count+1) ), np.arange(0,h+h/pts_count, h/pts_count)[...,None] ], axis=-1 ) ), 0 ).astype(np.int32) + + + cnt2 += cnt2_vec * cnt2_len * 0.05 + cnt2 = cnt2.astype(np.int32) + cnt2 = np.concatenate ( (cnt2, perim_pts), 0 ) + cnt = np.concatenate ( (cnt, perim_pts), 0 ) + #for x,y in np.concatenate( [circle_pts, circle_pts2], 0 ): + screen = np.zeros_like( mask, np.uint8 ) + for x,y in np.concatenate( [cnt,cnt2], 0 ): + cv2.circle(screen, (x,y), 1, (255,) ) + + + + #cv2.imshow("", screen) + #cv2.waitKey(0) + #import code + #code.interact(local=dict(globals(), **locals())) + + + #new_img = mls_rigid_deformation_inv( img, circle_pts, circle_pts2 ) + new_img = mls_rigid_deformation_inv( img, cnt, cnt2, density=0.5 ) + #new_img = mls_similarity_deformation_inv( img, cnt, cnt2 ) + + + while True: + cv2.imshow("", img) + cv2.waitKey(0) + cv2.imshow("", new_img) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + #=================================================== + + + landmarks_2D = np.array([ + [ 0.000213256, 0.106454 ], #17 + [ 0.0752622, 0.038915 ], #18 + [ 0.18113, 0.0187482 ], #19 + [ 0.29077, 0.0344891 ], #20 + [ 0.393397, 0.0773906 ], #21 + [ 0.586856, 0.0773906 ], #22 + [ 0.689483, 0.0344891 ], #23 + [ 0.799124, 0.0187482 ], #24 + [ 0.904991, 0.038915 ], #25 + [ 0.98004, 0.106454 ], #26 + [ 0.490127, 0.203352 ], #27 + [ 0.490127, 0.307009 ], #28 + [ 0.490127, 0.409805 ], #29 + [ 0.490127, 0.515625 ], #30 + [ 0.36688, 0.587326 ], #31 + [ 0.426036, 0.609345 ], #32 + [ 0.490127, 0.628106 ], #33 + [ 0.554217, 0.609345 ], #34 + [ 0.613373, 0.587326 ], #35 + [ 0.121737, 0.216423 ], #36 + [ 0.187122, 0.178758 ], #37 + [ 0.265825, 0.179852 ], #38 + [ 0.334606, 0.231733 ], #39 + [ 0.260918, 0.245099 ], #40 + [ 0.182743, 0.244077 ], #41 + [ 0.645647, 0.231733 ], #42 + [ 0.714428, 0.179852 ], #43 + [ 0.793132, 0.178758 ], #44 + [ 0.858516, 0.216423 ], #45 + [ 0.79751, 0.244077 ], #46 + [ 0.719335, 0.245099 ], #47 + [ 0.254149, 0.780233 ], #48 + [ 0.340985, 0.745405 ], #49 + [ 0.428858, 0.727388 ], #50 + [ 0.490127, 0.742578 ], #51 + [ 0.551395, 0.727388 ], #52 + [ 0.639268, 0.745405 ], #53 + [ 0.726104, 0.780233 ], #54 + [ 0.642159, 0.864805 ], #55 + [ 0.556721, 0.902192 ], #56 + [ 0.490127, 0.909281 ], #57 + [ 0.423532, 0.902192 ], #58 + [ 0.338094, 0.864805 ], #59 + [ 0.290379, 0.784792 ], #60 + [ 0.428096, 0.778746 ], #61 + [ 0.490127, 0.785343 ], #62 + [ 0.552157, 0.778746 ], #63 + [ 0.689874, 0.784792 ], #64 + [ 0.553364, 0.824182 ], #65 + [ 0.490127, 0.831803 ], #66 + [ 0.42689 , 0.824182 ] #67 + ], dtype=np.float32) + + landmarks_2D *= 256 + screen = np.zeros( (256,256,1) , np.uint8 ) + + for x,y in landmarks_2D: + cv2.circle(screen, (x,y), 1, (255,) ) + + cv2.imshow("", screen) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + for image_path in pathex.get_image_paths(r'E:\FakeFaceVideoSources\Putin\Photo'): + + img = cv2.imread(image_path).astype(np.float32) / 255.0 + dflimg = DFLJPG.load ( image_path) + + img_size = 128 + face_mat = LandmarksProcessor.get_transform_mat( dflimg.get_landmarks(), img_size, FaceType.MOUTH, scale=1.0) + wrp = cv2.warpAffine(img, face_mat, (img_size, img_size), cv2.INTER_LANCZOS4) + + cv2.imshow("", (wrp*255).astype(np.uint8) ) + cv2.waitKey(0) + + + + + import code + code.interact(local=dict(globals(), **locals())) + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + w = 64 + h = 64 + + """ + triangles_count = 2 + triangles_np = np.array([ [ [0.0,0.0,-2.1], + [1.0,0.0,-2.1], + [0.0,1.0,-2.1] ], + + [ [0.0,0.0,-2.0], + [1.0,1.0,-2.0], + [0.0,1.0,-2.0] ], + + ], dtype=np.float32) + + triangles_colors_np = np.array([ [0.0,0.0,1.0], + [0.0,1.0,0.0], + ], dtype=np.float32) + + """ + + triangles_count = 1 + + triangles_np = np.array([ [ [-0.01,-0.01,-5.0], + [1.0,-0.01,-5.0], + [-0.01,1.0,-5.0] ], + + ], dtype=np.float32) + + triangles_colors_np = np.array([ [0.0,1.0,0.0], + ], dtype=np.float32) + + + #camera_pos_np = np.array([ [0.0,0.0,0.0] ], dtype=np.float32) + #camera_dir_np = np.array([ [0.0,0.0,-1.0] ], dtype=np.float32) + #camera_pos_t = tf.placeholder(tf.float32, (3,) ) + #camera_dir_t = tf.placeholder(tf.float32, (3,) ) + + # Create ray grid + mh=0.5-np.linspace(0,1,h) + mw=np.linspace(0,1,w)-0.5 + mw, mh = np.meshgrid(mw,mh) + rd_np = np.concatenate( [ mw[...,None], mh[...,None], -np.ones_like(mw)[...,None] ] , -1 ) + rd_np /= np.linalg.norm(rd_np, axis=-1, keepdims=True) + + sun_dir_np = np.array ([0.0,0.0,-1.0], dtype=np.float32) + + ro_t = tf.zeros ( (h,w,3), tf.float32 ) #tf.placeholder(tf.float32, (h,w,3) ) + rd_t = tf.placeholder(tf.float32, (h,w,3) ) + + sun_dir_t = tf.placeholder(tf.float32, (3,) ) + + + target_t = tf.placeholder(tf.float32, (h,w,3) ) + + #triangles_t = tf.placeholder(tf.float32, (triangles_count,3,3) ) + triangles_t = tf.get_variable ("w", (triangles_count,3,3), dtype=nn.floatx)#, initializer=tf.initializers.zeros ) + nn.batch_set_value ( [(triangles_t, [ [ [-0.01,-0.01,-5.0], + [1.0,-0.01,-5.0], + [-0.01,1.0,-5.0] ], + ] )] ) + + triangles_colors_t = tf.placeholder(tf.float32, (triangles_count,3) ) + + tris = tf.tile( triangles_t[None,None,...], (h,w,1,1,1) ) + + + ro_tris = tf.tile( ro_t[...,None,:], (1,1,triangles_count,1) ) + rd_tris = tf.tile( rd_t[...,None,:], (1,1,triangles_count,1) ) + + # Ray triangle intersection + # code borrowed from https://www.iquilezles.org/www/articles/intersectors/intersectors.htm + # result is u,v,t per [h,w,tri] + tris_v1v0 = tris[...,1,:] - tris[...,0,:] + tris_v2v0 = tris[...,2,:] - tris[...,0,:] + tris_rov0 = ro_tris-tris[:,:,:,0] + + tris_n = tf.linalg.cross (tris_v1v0, tris_v2v0) + tris_q = tf.linalg.cross (tris_rov0, rd_tris) + tris_d = 1.0 / tf.reduce_sum ( tf.multiply(rd_tris, tris_n), -1 ) + tris_u = tris_d * tf.reduce_sum ( tf.multiply(-tris_q, tris_v2v0), -1 ) + tris_v = tris_d * tf.reduce_sum ( tf.multiply(tris_q, tris_v1v0), -1 ) + tris_t = tris_d * tf.reduce_sum ( tf.multiply(-tris_n, tris_rov0), -1 ) + + tris_n /= tf.linalg.norm(tris_n, axis=-1, keepdims=True) + + #tris_hit_pos = ro_tris+rd_tris*tris_t[...,None] + + #import code + #code.interact(local=dict(globals(), **locals())) + + @tf.custom_gradient + def z_one_clip(x): + """ + x < 0 -> 0 + x >= 0 -> 1 + x >= 1 -> 0 + """ + #r = tf.clip_by_value ( tf.sign(x)+1, 0, 1 ) + r = tf.clip_by_value ( tf.sign(x)+tf.sign(x-1), -1, 1) + x = 1-tf.abs(r) + + + def grad(dy): + return r#tf.clip_by_value ( tf.sign(x)+tf.sign(x-1), -1, 1) + + return x, grad + + # Invert distances, so the most near get highest value, and far starts from 1 + tris_f_t = tf.reduce_max ( tris_t, axis=-1, keepdims=True) - tris_t + 1 + + + # Apply UV clip : zeros t values which rays outside triangle + #tris_uv_f_t = tf.reduce_min( tf.concat(( tris_u[...,None], + # (1-tris_u)[...,None], + # tris_v[...,None], + # (1-(tris_u+tris_v) )[...,None] + # ), -1), -1 ) + + tris_f_t *= z_one_clip(tris_u) + #tris_f_t *= z_one_clip(1-tris_u) + tris_f_t *= z_one_clip(tris_v) + tris_f_t *= z_one_clip(tris_u+tris_v) + + #tris_f_t *= z_one_clip(tris_uv_f_t) + + # Apply backplane clip : zeros tris_f_t by negative tris_t values + #tris_f_t *= z_one_clip(tris_t) + + + # Apply nearest tri clip + #tris_inv_t *= tf.sign( tris_inv_t - tf.reduce_max ( tris_inv_t, axis=-1, keepdims=True) )+1 + # + + #tris_t = tf.clip_by_value( tf.sign( tris_inv_t - tf.reduce_max ( tris_inv_t, axis=-1, keepdims=True) )+1, 0, 1) + #tris_t = tris_inv_t + + # Compute color + tris_colors = tf.tile( triangles_colors_t[None,None,...], (h,w,1,1) ) + + triangles_sun_dirs_t = tf.tile ( sun_dir_t[None,None,None,...], (h,w,triangles_count,1) ) + + #dif_color * scene.sun_power * max(0.0, dot( normal, -scene.sun_dir ) ) + + sun_dot = tf.reduce_sum ( tf.multiply(tris_n, -triangles_sun_dirs_t), -1, keepdims=True )#, 0, 1 ) + + tris_t = tris_f_t #tf.clip_by_value( tf.sign(tris_f_t), 0, 1) + + x = tris_t[...,None]* tris_colors #* sun_dot + #x = tris_t + # Sum axis of all tris colors + x = tf.reduce_sum(x, axis=-2) + + target_np = pickle.loads( Path(r'D:\tri.dat').read_bytes() ) + + + (xg,xv), = nn.gradients( tf.square(x-target_t) , [triangles_t]) + + while True: + r, rxg= nn.tf_sess.run([x,xg*-1], feed_dict={ rd_t : rd_np, sun_dir_t:sun_dir_np, + #triangles_t:triangles_np, + triangles_colors_t : triangles_colors_np, + target_t : target_np, + } + + ) + + cur_triangles_t = nn.tf_sess.run(triangles_t) + print(cur_triangles_t) + + cv2.imshow("", (r*255).astype(np.uint8) ) + cv2.waitKey(200) + + + + nn.batch_set_value ( [(triangles_t, cur_triangles_t+ rxg/10000.0) ] ) + + + #Path(r'D:\tri.dat').write_bytes( pickle.dumps(r) ) + + #import code + #code.interact(local=dict(globals(), **locals())) + + + #============================= + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + var = tf.get_variable ("w", (1,), dtype=nn.floatx)#, initializer=tf.initializers.zeros ) + + + nn.batch_set_value ( [(var, [-5] )] ) + + x = var + + #x = 1.0 / ( 1 + tf.exp(-100*(x-0.9)) ) + x = tf.abs(tf.nn.tanh( x ) )#*100-90 ) + + #x = tf.abs( x ) + + (xg,xv), = nn.gradients(x, [var]) + + r = nn.tf_sess.run( [xg, x] ) + + print(r) + + #cv2.imshow("", (result*255).astype(np.uint8) ) + #cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + #================== + p = np.float32( [2,1] ) + + pts = np.float32([ [0,0], [1,0], [2,0],[3,0] ]) + a = pts[:-1,:] + b = pts[1:,:] + edges = np.concatenate( ( pts[:-1,None,:], pts[1:,None,:] ), axis=-2) + + pa = p-a + ba = b-a + + h = np.clip( np.einsum('ij,ij->i', pa, ba) / np.einsum('ij,ij->i', ba, ba), 0, 1 ) + + x = npla.norm ( pa - ba*h[...,None], axis=1 ) + np.argmin(x) + + import code + code.interact(local=dict(globals(), **locals())) + + """ + float sdSegment( in vec2 p, in vec2 a, in vec2 b ) + { + vec2 pa = p-a, ba = b-a; + float h = clamp( dot(pa,ba)/dot(ba,ba), 0.0, 1.0 ); + return length( pa - ba*h ); + } + """ + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + img = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + + + + inp_t = tf.placeholder( tf.float32 , (None,None,None,None) ) + + + + + + """ + weight = tf.constant ( + [ + [0.0, 1.0, 0.0], + [1, -4.0, 1.0 ], + [0.0, 1.0, 0.0], + + ], dtype=tf.float32) + + weight = tf.constant ( + [ [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 2.0, -1.0, 0.0], + [0.0, 2.0, -4.0, 2.0, 0.0], + [0.0, -1.0, 2.0, -1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + + ], dtype=tf.float32) + + weight = tf.constant ( + [ [-1.0, 2.0, -2.0, 2.0, -1.0], + [2.0, -6.0, 8.0, -6.0, 2.0], + [-2.0, 8.0, -12.0, 8.0, -2.0], + [2.0, -6.0, 8.0, -6.0, 2.0], + [-1.0, 2.0, -2.0, 2.0, -1.0], + + ], dtype=tf.float32) + """ + weight = tf.constant ( + [ [0.0, 0.0, -1.0, 0.0, 0.0], + [0.0, -1.0, -2.0, -1.0, 0.0], + [-1.0, -2.0, 16.0,- 2.0, -1.0], + [0.0, -1.0, -2.0, -1.0, 0.0], + [0.0, 0.0, -1.0, 0.0, 0.0], + + ], dtype=tf.float32) + weight = weight [...,None,None] + weight = tf.tile(weight, (1,1,3,1) ) + x = tf.nn.depthwise_conv2d(inp_t, weight, [1,1,1,1], 'SAME', data_format="NHWC") + x = tf.reduce_mean( x, nn.conv2d_ch_axis, keepdims=True ) + x = tf.clip_by_value(x, 0, 1) + + + result = tf_sess.run ( tf.reduce_sum(tf.abs(x)), feed_dict={ inp_t:img[None,...] } ) + print(result) + result = tf_sess.run (x, feed_dict={ inp_t:img[None,...] } ) + + #import code + #code.interact(local=dict(globals(), **locals())) + + while True: + cv2.imshow ("", np.clip(img*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + cv2.imshow ("", np.clip(result[0]*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + + import code + code.interact(local=dict(globals(), **locals())) + + #################### + + + + from core.imagelib import sd + resolution = 256 + while True: + circle_mask = sd.random_circle_faded ([resolution,resolution] ) + + cv2.imshow ("", np.clip(circle_mask*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + """ + img = cv2_imread(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned\XSegDataset\obstructions\1.png').astype(np.float32) / 255.0 + + a = img[...,3:4] + a[a>0] = 1.0 + + #a = cv2.dilate (a, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(4,4)), iterations = 1 ) + #a = cv2.erode (a, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(4,4)), iterations = 1 ) + + cv2.imshow ("", np.clip(a*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + """ + #======================================================== + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + + + generator = SampleGeneratorFaceSkinSegDataset(root_path=Path(r'D:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned'), + debug=True, + resolution=256, + face_type=FaceType.WHOLE_FACE, + batch_size=1, + generators_count=1 ) + while True: + img,mask = generator.generate_next() + + + cv2.imshow ("", np.clip(img[0]*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow ("", np.clip(mask[0]*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + #======================================================== + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + tf_sess = nn.tf_sess + + #inp = tf.placehold + resolution = 16 + + def tf_random_1D_subline (len): + + low_bound = tf.random.uniform( (1,), maxval=len, dtype=tf.int32 )[0] + high_bound = low_bound + tf.random.uniform( (1,), maxval=len-low_bound , dtype=tf.int32 )[0] + return tf.range(low_bound, high_bound+1)[0] + + + def tf_random_2D_patches (batch_size, resolution, ch, dtype=None, data_format=None): + if dtype is None: + dtype = tf.float32 + + if data_format is None: + data_format = nn.data_format + + if data_format == "NHWC": + z = tf.zeros( (batch_size,resolution,resolution,ch), dtype=dtype ) + else: + z = tf.zeros( (batch_size,ch,resolution,resolution), dtype=dtype ) + + for i in range(batch_size): + wr = tf_random_1D_subline(resolution) + hr = tf_random_1D_subline(resolution) + + if data_format == "NHWC": + z[i,hr,wr,:] = tf.constant ([1,1,1], dtype=dtype ) + else: + z[i,:,hr,wr] = tf.constant ([1,1,1], dtype=dtype ) + + return z + + x = tf_random_2D_patches (1, 16, 3) + + y = tf_sess.run ( x ) + print (y) + + import code + code.interact(local=dict(globals(), **locals())) + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + + + training_data_src_path = r'F:\DeepFaceLabCUDA9.2SSE\_internal\pretrain_CelebA' + + generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, batch_size=1, + sample_process_options=SampleProcessor.Options(random_flip=True), + output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'ct_mode': 'idt', 'face_type':FaceType.FULL, 'data_format':nn.data_format, 'resolution': 256}, + {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':FaceType.FULL, 'data_format':nn.data_format, 'resolution': 256}, + {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':False , 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE_EYES, 'face_type':FaceType.FULL, 'data_format':nn.data_format, 'resolution': 256}, + ], + generators_count=1 ) + while True: + bgr,bgr_ct,mask = generator.generate_next() + + + cv2.imshow ("", np.clip(bgr[0]*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow ("", np.clip(bgr_ct[0]*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + #================ + + + #in_t = tf.placeholder(tf.float32, (2,3)) + # + # Mask remain positive t values + #tris_inv_t_mask = tf.clip_by_value( tf.sign(tris_inv_t), 0, 1 ) + + # Filter nearest tri + #tris_max_t = tf.reduce_max ( tris_t, axis=-1, keepdims=True) + #tris_inv_t = tris_max_t - tris_t + + # Compute distance clip + # Invert distances, so near get highest value + + #tris_max_t = tf.reduce_max ( tris_t, axis=-1, keepdims=True) + #tris_inv_t = tris_max_t - tris_t + + # Cut distances by uv_cut, so unwanted tris get zero dist + #tris_unwanted_cut = tris_inv_t * tris_uv_clip + + # Highest(near) t becomes 1, otherwise 0 + #tris_dist_clip = tf.sign( tris_unwanted_cut - tf.reduce_max ( tris_unwanted_cut, axis=-1, keepdims=True) )+1 + + # Expand clip dims in order to mult on colors + #x = tris_unwanted_cut[...,None] * tris_dist_clip[...,None] * tris_uv_clip[...,None] * tris_colors + + import code + code.interact(local=dict(globals(), **locals())) + + + #======================= + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + + + training_data_src_path = r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned' + + t = SampleProcessor.Types + generator = SampleGeneratorFace(training_data_src_path, batch_size=1, + sample_process_options=SampleProcessor.Options(), + output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'data_format':nn.data_format, 'resolution': 256 } ], + generators_count=1, rnd_seed=0 ) + + while True: + x = generator.generate_next()[0][0] + + cv2.imshow ("", np.clip(x*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + + import code + code.interact(local=dict(globals(), **locals())) + + #======================== + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + """ + class BilinearInterpolation(KL.Layer): + def __init__(self, size=(2,2), **kwargs): + self.size = size + super(BilinearInterpolation, self).__init__(**kwargs) + + def compute_output_shape(self, input_shape): + return (input_shape[0], input_shape[1]*self.size[1], input_shape[2]*self.size[0], input_shape[3]) + + + def call(self, X): + _,h,w,_ = K.int_shape(X) + + #X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 ) + #X = K.concatenate( [ X, X[:,:,-2:-1,:] ],axis=2 ) + #X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 ) + #X = K.concatenate( [ X, X[:,-2:-1,:,:] ],axis=1 ) + + X_sh = K.shape(X) + batch_size, height, width, num_channels = X_sh[0], X_sh[1], X_sh[2], X_sh[3] + + output_h, output_w = (h*self.size[1], w*self.size[0]) + + x_linspace = np.linspace(-1. , 1., output_w)#- 2/output_w + y_linspace = np.linspace(-1. , 1., output_h)# + + x_coordinates, y_coordinates = np.meshgrid(x_linspace, y_linspace) + x_coordinates = K.constant(x_coordinates, dtype=K.floatx() ) + y_coordinates = K.constant(y_coordinates, dtype=K.floatx() ) + + + + x = x_coordinates + y = y_coordinates + + x = .5 * (x + 1.0) * K.cast(width, dtype='float32') + y = .5 * (y + 1.0) * K.cast(height, dtype='float32') + x0 = K.cast(x, 'int32') + x1 = x0 + 1 + y0 = K.cast(y, 'int32') + y1 = y0 + 1 + max_x = int(K.int_shape(X)[2] -1) + max_y = int(K.int_shape(X)[1] -1) + + x0 = K.clip(x0, 0, max_x) + x1 = K.clip(x1, 0, max_x) + y0 = K.clip(y0, 0, max_y) + y1 = K.clip(y1, 0, max_y) + + + pixels_batch = K.constant ( np.arange(0, batch_size) * (height * width), dtype=K.floatx() ) + + pixels_batch = K.expand_dims(pixels_batch, axis=-1) + + base = K.tile(pixels_batch, (1, output_h * output_w ) ) + base = K.flatten(base) + + # base_y0 = base + (y0 * width) + base_y0 = y0 * width + base_y0 = base + base_y0 + # base_y1 = base + (y1 * width) + base_y1 = y1 * width + base_y1 = base_y1 + base + + indices_a = base_y0 + x0 + indices_b = base_y1 + x0 + indices_c = base_y0 + x1 + indices_d = base_y1 + x1 + + flat_image = K.reshape(X, (-1, num_channels) ) + flat_image = K.cast(flat_image, dtype='float32') + pixel_values_a = K.gather(flat_image, indices_a) + pixel_values_b = K.gather(flat_image, indices_b) + pixel_values_c = K.gather(flat_image, indices_c) + pixel_values_d = K.gather(flat_image, indices_d) + + x0 = K.cast(x0, 'float32') + x1 = K.cast(x1, 'float32') + y0 = K.cast(y0, 'float32') + y1 = K.cast(y1, 'float32') + + area_a = K.expand_dims(((x1 - x) * (y1 - y)), 1) + area_b = K.expand_dims(((x1 - x) * (y - y0)), 1) + area_c = K.expand_dims(((x - x0) * (y1 - y)), 1) + area_d = K.expand_dims(((x - x0) * (y - y0)), 1) + + values_a = area_a * pixel_values_a + values_b = area_b * pixel_values_b + values_c = area_c * pixel_values_c + values_d = area_d * pixel_values_d + interpolated_image = values_a + values_b + values_c + values_d + + new_shape = (batch_size, output_h, output_w, num_channels) + interpolated_image = K.reshape(interpolated_image, new_shape) + + #interpolated_image = interpolated_image[:,:-4,:-4,:] + return interpolated_image + + def get_config(self): + config = {"size": self.size} + base_config = super(BilinearInterpolation, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def batch_dot(x, y, axes=None): + if x.ndim < 2 or y.ndim < 2: + raise ValueError('Batch dot requires inputs of rank 2 or more.') + + if isinstance(axes, int): + axes = [axes, axes] + elif isinstance(axes, tuple): + axes = list(axes) + + if axes is None: + if y.ndim == 2: + axes = [x.ndim - 1, y.ndim - 1] + else: + axes = [x.ndim - 1, y.ndim - 2] + + if any([isinstance(a, (list, tuple)) for a in axes]): + raise ValueError('Multiple target dimensions are not supported. ' + + 'Expected: None, int, (int, int), ' + + 'Provided: ' + str(axes)) + + # Handle negative axes + if axes[0] < 0: + axes[0] += x.ndim + if axes[1] < 0: + axes[1] += y.ndim + + if 0 in axes: + raise ValueError('Can not perform batch dot over axis 0.') + + if x.shape[0] != y.shape[0]: + raise ValueError('Can not perform batch dot on inputs' + ' with different batch sizes.') + + d1 = x.shape[axes[0]] + d2 = y.shape[axes[1]] + if d1 != d2: + raise ValueError('Can not do batch_dot on inputs with shapes ' + + str(x.shape) + ' and ' + str(y.shape) + + ' with axes=' + str(axes) + '. x.shape[%d] != ' + 'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2)) + + result = [] + axes = [axes[0] - 1, axes[1] - 1] # ignore batch dimension + for xi, yi in zip(x, y): + result.append(np.tensordot(xi, yi, axes)) + result = np.array(result) + + if result.ndim == 1: + result = np.expand_dims(result, -1) + + return result + """ + def np_bilinear(X, size ): + batch_size,h,w,num_channels = X.shape + + zero_h_line = np.zeros ( (batch_size,h,1,num_channels) ) + + X = np.concatenate( [ zero_h_line, X ],axis=2 ) + X = np.concatenate( [ zero_h_line, X ],axis=2 ) + X = np.concatenate( [ X, zero_h_line ],axis=2 ) + X = np.concatenate( [ X, zero_h_line ],axis=2 ) + + batch_size,h,w,num_channels = X.shape + zero_w_line = np.zeros ( (batch_size,1,w,num_channels) ) + + X = np.concatenate( [ zero_w_line, X ],axis=1 ) + X = np.concatenate( [ zero_w_line, X ],axis=1 ) + X = np.concatenate( [ X, zero_w_line ],axis=1 ) + X = np.concatenate( [ X, zero_w_line ],axis=1 ) + + #import code + #code.interact(local=dict(globals(), **locals())) + + batch_size,h,w,num_channels = X.shape + + output_w, output_h = size + output_w += 4 + output_h += 4 + + xc = np.linspace(0, w-1, w).astype(X.dtype) + yc = np.linspace(0, h-1, h).astype(X.dtype) + xc,yc = np.meshgrid (xc,yc) + + + #x_linspace = np.linspace(-1., 1., output_w) + #y_linspace = np.linspace(-1. , 1. - 2/output_h, output_h)# + #x_coordinates, y_coordinates = np.meshgrid(x_linspace, y_linspace) + + #x = cv_x = cv2.resize (xc, (output_w,output_h) ) + #y = cv_y = cv2.resize (yc, (output_w,output_h) ) + x = np.linspace(0., w-1, output_w) + y = np.linspace(0., h-1, output_h) + x, y = np.meshgrid(x, y) + + aff = np.array (\ + [ [1,0,0], + [0,1,0], + ]) # 2,3 + + #aff = cv2.getRotationMatrix2D( (0, 0), 60, 1.0) + grids = np.stack ( [x,y,np.ones_like(x)] ).reshape ( (3,output_h*output_w) ) + + sampled_grids = np.dot(aff,grids).reshape ( (2,output_h,output_w) ) + x = sampled_grids[0] + y = sampled_grids[1] + #import code + #code.interact(local=dict(globals(), **locals())) + + + x0 = x.astype(np.int32) + x1 = x0 + 1 + y0 = y.astype(np.int32) + y1 = y0 + 1 + + ind_x0 = np.clip(x0,0,w-1) + ind_x1 = np.clip(x1,0,w-1) + ind_y0 = np.clip(y0,0,h-1) + ind_y1 = np.clip(y1,0,h-1) + + indices_a = ind_y0 * w + ind_x0 + indices_b = ind_y1 * w + ind_x0 + indices_c = ind_y0 * w + ind_x1 + indices_d = ind_y1 * w + ind_x1 + + flat_image = np.reshape(X, (-1, num_channels) ) + + pixel_values_a = np.reshape( flat_image[np.ndarray.flatten (indices_a)], (output_h,output_w,num_channels) ) + pixel_values_b = np.reshape( flat_image[np.ndarray.flatten (indices_b)], (output_h,output_w,num_channels) ) + pixel_values_c = np.reshape( flat_image[np.ndarray.flatten (indices_c)], (output_h,output_w,num_channels) ) + pixel_values_d = np.reshape( flat_image[np.ndarray.flatten (indices_d)], (output_h,output_w,num_channels) ) + + x0 = x0.astype(x.dtype) + x1 = x1.astype(x.dtype) + y0 = y0.astype(y.dtype) + y1 = y1.astype(y.dtype) + + area_a = (x1 - x) * (y1 - y) + area_b = (x1 - x) * (y - y0) + area_c = (x - x0) * (y1 - y) + area_d = (x - x0) * (y - y0) + + values_a = area_a[...,None] * pixel_values_a + values_b = area_b[...,None] * pixel_values_b + values_c = area_c[...,None] * pixel_values_c + values_d = area_d[...,None] * pixel_values_d + + interpolated_image = values_a + values_b + values_c + values_d + + interpolated_image = interpolated_image[2:-2,2:-2,:] + + return interpolated_image + + + #pixel_values_a = K.gather(flat_image, indices_a) + #pixel_values_b = K.gather(flat_image, indices_b) + #pixel_values_c = K.gather(flat_image, indices_c) + #pixel_values_d = K.gather(flat_image, indices_d) + + + + new_shape = (batch_size, output_h, output_w, num_channels) + interpolated_image = K.reshape(interpolated_image, new_shape) + + #interpolated_image = interpolated_image[:,:-4,:-4,:] + return interpolated_image + + filepath = r'D:\DevelopPython\test\00000.png' + img = cv2.imread(filepath).astype(np.float32) / 255.0 + h,w,c = img.shape + + #img = np.random.random ( (4,4,3) ) + + #xc = np.linspace(0, 4-1, 4) + #yc = np.linspace(0, 4-1, 4) + #xc,yc = np.meshgrid (xc,yc) + #img = xc+yc + #img = img[...,None] + + + while True: + random_w = 512#np.random.randint (1,128) + random_h = 512#np.random.randint (1,128) + + np_x = np_bilinear(img[None,...], size=(random_w,random_h)) + cv_x = cv2.resize (img, (random_w,random_h)) + + #import code + #code.interact(local=dict(globals(), **locals())) + + print( np.sum(np.abs(np_x-cv_x)) ) + #import code + #code.interact(local=dict(globals(), **locals())) + + cv2.imshow("", (np_x * 255).astype(np.uint8) ) + cv2.waitKey(0) + + cv2.imshow("", (cv_x * 255).astype(np.uint8) ) + cv2.waitKey(0) + + #import tensorflow as tf + #tf_inp = tf.keras.Input ( (256,256,3) ) + #tf.keras.Input ( ()) + #tf_x = tf.image.resize_images (tf_inp, (512,512) ) + #tf_unc = tf.keras.backend.function([tf_inp],[tf_x]) + #tf_x, = tf_unc ([ img[None,...] ]) + + inp = Input ( (256,256,3) ) + keras_x = BilinearInterpolation() ( inp ) + + func = K.function([inp],[keras_x]) + + + #print (np.sum(np.abs(keras_x-tf_x)) ) + while True: + keras_x, = func ([ img[None,...] ]) + cv2.imshow("", (keras_x[0] * 255).astype(np.uint8) ) + cv2.waitKey(0) + #cv2.imshow("", tf_x[0].astype(np.uint8) ) + #cv2.waitKey(0) + + + #================================ + + src_paths = pathex.get_image_paths(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned') + dst_paths = pathex.get_image_paths(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_dst\aligned') + + dst_all = None + dst_count = 0 + for path in io.progress_bar_generator (dst_paths, "Computing"): + img = cv2_imread(path).astype(np.float32) / 255.0 + if dst_all is None: + dst_all = img + else: + dst_all += img + + dst_count += 1 + + dst_all /= dst_count + dst_all = np.clip(dst_all, 0, 1) + + + for path in io.progress_bar_generator (src_paths, "Computing"): + img = cv2_imread(path).astype(np.float32) / 255.0 + + ct = imagelib.color_transfer_idt(img, dst_all) + + cv2.imshow ("", np.clip(ct*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + cv2.imshow ("", np.clip(dst_all*255, 0,255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + lowests = [] + + for src_path in io.progress_bar_generator (src_paths[0:10], "Computing"): + dst_path = dst_paths[np.random.randint(dst_paths_len)] + + src_uint8 = cv2_imread(src_path) + src = src_uint8.astype(np.float32) / 255.0 + + dst_uint8 = cv2_imread(dst_path) + dst = dst_uint8.astype(np.float32) / 255.0 + + src_rct = imagelib.reinhard_color_transfer(src_uint8, dst_uint8).astype(np.float32) / 255.0 + src_lct = np.clip( imagelib.linear_color_transfer (src, dst), 0.0, 1.0 ) + src_mkl = imagelib.color_transfer_mkl (src, dst) + src_idt = imagelib.color_transfer_idt (src, dst) + src_sot = imagelib.color_transfer_sot (src, dst) + + dst_mean = np.mean(dst, axis=(0,1) ) + src_mean = np.mean(src, axis=(0,1) ) + src_rct_mean = np.mean(src_rct, axis=(0,1) ) + src_lct_mean = np.mean(src_lct, axis=(0,1) ) + src_mkl_mean = np.mean(src_mkl, axis=(0,1) ) + src_idt_mean = np.mean(src_idt, axis=(0,1) ) + src_sot_mean = np.mean(src_sot, axis=(0,1) ) + + dst_std = np.sqrt ( np.var(dst, axis=(0,1) ) + 1e-5 ) + src_std = np.sqrt ( np.var(src, axis=(0,1) ) + 1e-5 ) + src_rct_std = np.sqrt ( np.var(src_rct, axis=(0,1) ) + 1e-5 ) + src_lct_std = np.sqrt ( np.var(src_lct, axis=(0,1) ) + 1e-5 ) + src_mkl_std = np.sqrt ( np.var(src_mkl, axis=(0,1) ) + 1e-5 ) + src_idt_std = np.sqrt ( np.var(src_idt, axis=(0,1) ) + 1e-5 ) + src_sot_std = np.sqrt ( np.var(src_sot, axis=(0,1) ) + 1e-5 ) + + def_mean_sum = np.sum( np.square(src_mean-dst_mean) ) + rct_mean_sum = np.sum( np.square(src_rct_mean-dst_mean) ) + lct_mean_sum = np.sum( np.square(src_lct_mean-dst_mean) ) + mkl_mean_sum = np.sum( np.square(src_mkl_mean-dst_mean) ) + idt_mean_sum = np.sum( np.square(src_idt_mean-dst_mean) ) + sot_mean_sum = np.sum( np.square(src_sot_mean-dst_mean) ) + + def_std_sum = np.sum( np.square(src_std-dst_std) ) + rct_std_sum = np.sum( np.square(src_rct_std-dst_std) ) + lct_std_sum = np.sum( np.square(src_lct_std-dst_std) ) + mkl_std_sum = np.sum( np.square(src_mkl_std-dst_std) ) + idt_std_sum = np.sum( np.square(src_idt_std-dst_std) ) + sot_std_sum = np.sum( np.square(src_sot_std-dst_std) ) + + lowests.append([ def_mean_sum+def_std_sum, + rct_mean_sum+rct_std_sum, + lct_mean_sum+lct_std_sum, + mkl_mean_sum+mkl_std_sum, + idt_mean_sum+idt_std_sum, + sot_mean_sum+sot_std_sum + ]) + + #cv2.imshow("", src_rct ) + #cv2.waitKey(0) + + + np.mean(np.array(lowests), 0) + + + #========================================== + + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + + + class ConvBlock(nn.ModelBase): + + def on_build(self, in_planes, out_planes): + self.in_planes = in_planes + self.out_planes = out_planes + + self.bn1 = nn.BatchNorm2D(in_planes) + self.conv1 = nn.Conv2D (in_planes, out_planes//2, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + self.bn2 = nn.BatchNorm2D(out_planes//2) + self.conv2 = nn.Conv2D (out_planes//2, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + self.bn3 = nn.BatchNorm2D(out_planes//4) + self.conv3 = nn.Conv2D (out_planes//4, out_planes//4, kernel_size=3, strides=1, padding='SAME', use_bias=False ) + + if self.in_planes != self.out_planes: + self.down_bn1 = nn.BatchNorm2D(in_planes) + self.down_conv1 = nn.Conv2D (in_planes, out_planes, kernel_size=1, strides=1, padding='VALID', use_bias=False ) + else: + self.down_bn1 = None + self.down_conv1 = None + + def forward(self, input): + x = input + x = self.bn1(x) + x = tf.nn.relu(x) + x = out1 = self.conv1(x) + + x = self.bn2(x) + x = tf.nn.relu(x) + x = out2 = self.conv2(x) + + x = self.bn3(x) + x = tf.nn.relu(x) + x = out3 = self.conv3(x) + x = tf.concat ([out1, out2, out3], axis=-1) + + if self.in_planes != self.out_planes: + downsample = self.down_bn1(input) + downsample = tf.nn.relu (downsample) + downsample = self.down_conv1 (downsample) + x = x + downsample + else: + x = x + input + + return x + + class HourGlass (nn.ModelBase): + def on_build(self, in_planes, depth): + self.b1 = ConvBlock (in_planes, 256) + self.b2 = ConvBlock (in_planes, 256) + + if depth > 1: + self.b2_plus = HourGlass(256, depth-1) + else: + self.b2_plus = ConvBlock(256, 256) + + self.b3 = ConvBlock(256, 256) + + def forward(self, input): + up1 = self.b1(input) + + low1 = tf.nn.avg_pool(input, [1,2,2,1], [1,2,2,1], 'VALID') + low1 = self.b2 (low1) + + low2 = self.b2_plus(low1) + low3 = self.b3(low2) + + up2 = nn.upsample2d(low3) + + return up1+up2 + + class FAN (nn.ModelBase): + def __init__(self): + super().__init__(name='FAN') + + def on_build(self): + self.conv1 = nn.Conv2D (3, 64, kernel_size=7, strides=2, padding='SAME') + self.bn1 = nn.BatchNorm2D(64) + + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + self.m = [] + self.top_m = [] + self.conv_last = [] + self.bn_end = [] + self.l = [] + self.bl = [] + self.al = [] + for i in range(4): + self.m += [ HourGlass(256, 4) ] + self.top_m += [ ConvBlock(256, 256) ] + + self.conv_last += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ] + self.bn_end += [ nn.BatchNorm2D(256) ] + + self.l += [ nn.Conv2D (256, 68, kernel_size=1, strides=1, padding='VALID') ] + + if i < 4-1: + self.bl += [ nn.Conv2D (256, 256, kernel_size=1, strides=1, padding='VALID') ] + self.al += [ nn.Conv2D (68, 256, kernel_size=1, strides=1, padding='VALID') ] + + def forward(self, x) : + x = self.conv1(x) + x = self.bn1(x) + x = tf.nn.relu(x) + + x = self.conv2(x) + x = tf.nn.avg_pool(x, [1,2,2,1], [1,2,2,1], 'VALID') + x = self.conv3(x) + x = self.conv4(x) + + + outputs = [] + previous = x + for i in range(4): + ll = self.m[i] (previous) + + ll = self.top_m[i] (ll) + + + ll = self.conv_last[i] (ll) + + ll = self.bn_end[i] (ll) + ll = tf.nn.relu(ll) + + tmp_out = self.l[i](ll) + outputs.append(tmp_out) + + if i < 4 - 1: + ll = self.bl[i](ll) + previous = previous + ll + self.al[i](tmp_out) + return outputs[-1] + + rnd_data = np.random.uniform (size=(1,3,256,256)).astype(np.float32) + rnd_data = np.ones ((1,3,256,256)).astype(np.float32) + rnd_data_tf = np.transpose(rnd_data, (0,2,3,1) ) + + rnd_data_tf = cv2.imread ( r"D:\DevelopPython\test\00000.png" ).astype(np.float32) / 255.0 + rnd_data_tf = rnd_data_tf[None,...] + rnd_data = np.transpose(rnd_data_tf, (0,3,1,2) ) + + + + + import torch + import face_alignment + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D,device='cpu').face_alignment_net + fa.eval() + + #transfer weights + def convd2d_from_torch(torch_layer): + result = [ torch_layer.weight.data.numpy().transpose(2,3,1,0) ] + if torch_layer.bias is not None: + result += [ torch_layer.bias.data.numpy() ] + return result + + def bn2d_from_torch(torch_layer): + return [ torch_layer.weight.data.numpy(), + torch_layer.bias.data.numpy(), + torch_layer.running_mean.data.numpy(), + torch_layer.running_var.data.numpy(), + ] + + def transfer_conv_block(dst,src): + dst.bn1.set_weights ( bn2d_from_torch(src.bn1) ) + dst.conv1.set_weights ( convd2d_from_torch(src.conv1) ) + dst.bn2.set_weights ( bn2d_from_torch(src.bn2) ) + dst.conv2.set_weights ( convd2d_from_torch(src.conv2) ) + dst.bn3.set_weights ( bn2d_from_torch(src.bn3) ) + dst.conv3.set_weights ( convd2d_from_torch(src.conv3) ) + + if dst.down_bn1 is not None: + dst.down_bn1.set_weights ( bn2d_from_torch(src.downsample[0]) ) + dst.down_conv1.set_weights ( convd2d_from_torch(src.downsample[2]) ) + + def transfer_hourglass(dst, src, level): + + transfer_conv_block (dst.b1, getattr (src, f'b1_{level}' ) ) + transfer_conv_block (dst.b2, getattr (src, f'b2_{level}' ) ) + + if level > 1: + transfer_hourglass (dst.b2_plus, src, level-1) + else: + transfer_conv_block (dst.b2_plus, getattr (src, f'b2_plus_{level}' ) ) + + transfer_conv_block (dst.b3, getattr (src, f'b3_{level}' ) ) + + + with tf.device("/CPU:0"): + FAN = FAN() + #FAN.load_weights(r"D:\DevelopPython\test\2DFAN-4.npy") + + FAN.build() + FAN.conv1.set_weights ( convd2d_from_torch(fa.conv1) ) + FAN.bn1.set_weights ( bn2d_from_torch(fa.bn1) ) + + transfer_conv_block(FAN.conv2, fa.conv2) + transfer_conv_block(FAN.conv3, fa.conv3) + transfer_conv_block(FAN.conv4, fa.conv4) + + for i in range(4): + transfer_hourglass(FAN.m[i], getattr(fa, f'm{i}'), 4) + transfer_conv_block(FAN.top_m[i], getattr(fa, f'top_m_{i}')) + + FAN.conv_last[i].set_weights ( convd2d_from_torch( getattr(fa, f'conv_last{i}') ) ) + FAN.bn_end[i].set_weights ( bn2d_from_torch( getattr(fa, f'bn_end{i}') ) ) + FAN.l[i].set_weights ( convd2d_from_torch( getattr(fa, f'l{i}') ) ) + + if i < 4-1: + FAN.bl[i].set_weights ( convd2d_from_torch( getattr(fa, f'bl{i}') ) ) + FAN.al[i].set_weights ( convd2d_from_torch( getattr(fa, f'al{i}') ) ) + + FAN.save_weights(r"D:\DevelopPython\test\3DFAN-4.npy") + + #import code + #code.interact(local=dict(globals(), **locals())) + + + + def transform(point, center, scale, resolution): + pt = np.array ( [point[0], point[1], 1.0] ) + h = 200.0 * scale + m = np.eye(3) + m[0,0] = resolution / h + m[1,1] = resolution / h + m[0,2] = resolution * ( -center[0] / h + 0.5 ) + m[1,2] = resolution * ( -center[1] / h + 0.5 ) + m = np.linalg.inv(m) + return np.matmul (m, pt)[0:2] + + def get_pts_from_predict(a, center, scale): + a_ch, a_h, a_w = a.shape + + b = a.reshape ( (a_ch, a_h*a_w) ) + c = b.argmax(1).reshape ( (a_ch, 1) ).repeat(2, axis=1).astype(np.float) + c[:,0] %= a_w + c[:,1] = np.apply_along_axis ( lambda x: np.floor(x / a_w), 0, c[:,1] ) + + for i in range(a_ch): + pX, pY = int(c[i,0]), int(c[i,1]) + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = np.array ( [a[i,pY,pX+1]-a[i,pY,pX-1], a[i,pY+1,pX]-a[i,pY-1,pX]] ) + c[i] += np.sign(diff)*0.25 + + c += 0.5 + + return np.array( [ transform (c[i], center, scale, a_w) for i in range(a_ch) ] ) + + + tf_FAN_in = tf.placeholder(tf.float32, (1,256,256, 3)) + tf_FAN_out = FAN(tf_FAN_in) + + tf_x = nn.tf_sess.run(tf_FAN_out, feed_dict={tf_FAN_in:rnd_data_tf} )[0] + + fa_out_tensor = fa( torch.autograd.Variable( torch.from_numpy(rnd_data), volatile=True) )[-1][0].data.cpu() + torch_x = fa_out_tensor.numpy() + torch_x = np.transpose (torch_x, (1,2,0)) + + diff = np.mean(np.abs(tf_x-torch_x)) + print (f"diff = {diff}") + + tf_p = get_pts_from_predict(tf_x, [127.0,127.0], 1.0) + torchp = get_pts_from_predict(torch_x, [127.0,127.0], 1.0) + + import code + code.interact(local=dict(globals(), **locals())) + + #======================== + + ct_1_filepath = r'E:\FakeFaceVideoSources\Datasets\CelebA\aligned_def\aligned\00001.jpg' + ct_1_img = cv2.imread(ct_1_filepath).astype(np.float32) / 255.0 + ct_1_img_shape = ct_1_img.shape + ct_1_dflimg = DFLJPG.load ( ct_1_filepath) + + + face_mat = LandmarksProcessor.get_transform_mat( ct_1_dflimg.get_landmarks(), 256, FaceType.HEAD) + + import code + code.interact(local=dict(globals(), **locals())) + + + + def channel_hist_match(source, template, hist_match_threshold=255, mask=None): + # Code borrowed from: + # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x + masked_source = source + masked_template = template + + if mask is not None: + masked_source = source * mask + masked_template = template * mask + + oldshape = source.shape + source = source.ravel() + template = template.ravel() + masked_source = masked_source.ravel() + masked_template = masked_template.ravel() + s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, + return_counts=True) + t_values, t_counts = np.unique(template, return_counts=True) + + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1] + t_quantiles = np.cumsum(t_counts).astype(np.float64) + t_quantiles = 255 * t_quantiles / t_quantiles[-1] + interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) + + return interp_t_values[bin_idx].reshape(oldshape) + + img = cv2.imread(r'D:\DevelopPython\test\ct_src.jpg').astype(np.float32) / 255.0 + + while True: + + np_rnd = np.random.rand + + + + inBlack = np.array([np_rnd()*0.25 , np_rnd()*0.25 , np_rnd()*0.25], dtype=np.float32) + inWhite = np.array([1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25, 1.0-np_rnd()*0.25], dtype=np.float32) + inGamma = np.array([0.5+np_rnd(), 0.5+np_rnd(), 0.5+np_rnd()], dtype=np.float32) + outBlack = np.array([0.0, 0.0, 0.0], dtype=np.float32) + outWhite = np.array([1.0, 1.0, 1.0], dtype=np.float32) + + img2 = ( ( (img - inBlack) / (inWhite - inBlack) ) ** (1/inGamma) ) * (outWhite - outBlack) + outBlack + img2 = np.clip(img2, 0, 1) + + + #cv2.imshow("", img) + #cv2.waitKey(0) + cv2.imshow("", (img2*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + """ + inBlack = 23.0 + inWhite = 190.0 + inGamma = 1.61 + outBlack = 0.0 + outWhite = 255.0 + vec3 inPixel = source.rgb; + vec3 outPixel = (pow(((inPixel * 255.0) - vec3(inBlack)) / (inWhite - inBlack), vec3(inGamma)) * (outWhite - outBlack) + outBlack) / 255.0; + + + + lut_in = [0, 127, 255] + lut_out = [50, 127, 255] + lut_8u = np.interp(np.arange(0, 256), lut_in, lut_out).astype(np.uint8) + img2 = cv2.LUT(img, lut_8u) + + s = img.ravel() + s_values, bin_idx, s_counts = np.unique(s, return_inverse=True, return_counts=True) + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = 255 * s_quantiles / s_quantiles[-1] + + interp_t_values = np.interp(s_quantiles, s_quantiles, s_values) + + d = s_quantiles[bin_idx].reshape(s.shape) + + image_histogram, bins = np.histogram(s, 255, density=True) + """ + + import code + code.interact(local=dict(globals(), **locals())) + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.CPU() ) + tf = nn.tf + + img = cv2.imread(r'D:\DevelopPython\test\mask_0.png')[...,0:1].astype(np.float32) / 255.0 + + t = time.time() + + ero_k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)) + + cv_erode = cv2.erode(img, ero_k, iterations = 1 ) + print(f"time {time.time() - t}") + + + inp_t = tf.placeholder( tf.float32, (None,None,None,None) ) + + eroded_t = tf.nn.erosion2d(inp_t, ero_k[...,None].astype(np.float32), strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME") + eroded_t = eroded_t - tf.ones_like(inp_t) + + t = time.time() + tf_erode = nn.tf_sess.run (eroded_t , feed_dict={inp_t: img[None,...] } ) + print(f"time {time.time() - t}") + + while True: + cv2.imshow("", (cv_erode*255).astype(np.uint8)) + cv2.waitKey(0) + cv2.imshow("", (tf_erode[0]*255).astype(np.uint8)) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + + src_paths = pathex.get_image_paths(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_src\aligned') + dst_paths = pathex.get_image_paths(r'F:\DeepFaceLabCUDA9.2SSE\workspace\data_dst\aligned') + + best_ct = CTComputerSubprocessor(src_paths, dst_paths).run() + + print(f"best ct_mode is >> {best_ct} <<") + import code + code.interact(local=dict(globals(), **locals())) + + + + lowests = [] + + for src_path in io.progress_bar_generator (src_paths[0:10], "Computing"): + dst_path = dst_paths[np.random.randint(dst_paths_len)] + + src_uint8 = cv2_imread(src_path) + src = src_uint8.astype(np.float32) / 255.0 + + dst_uint8 = cv2_imread(dst_path) + dst = dst_uint8.astype(np.float32) / 255.0 + + src_rct = imagelib.reinhard_color_transfer(src_uint8, dst_uint8).astype(np.float32) / 255.0 + src_lct = np.clip( imagelib.linear_color_transfer (src, dst), 0.0, 1.0 ) + src_mkl = imagelib.color_transfer_mkl (src, dst) + src_idt = imagelib.color_transfer_idt (src, dst) + src_sot = imagelib.color_transfer_sot (src, dst) + + dst_mean = np.mean(dst, axis=(0,1) ) + src_mean = np.mean(src, axis=(0,1) ) + src_rct_mean = np.mean(src_rct, axis=(0,1) ) + src_lct_mean = np.mean(src_lct, axis=(0,1) ) + src_mkl_mean = np.mean(src_mkl, axis=(0,1) ) + src_idt_mean = np.mean(src_idt, axis=(0,1) ) + src_sot_mean = np.mean(src_sot, axis=(0,1) ) + + dst_std = np.sqrt ( np.var(dst, axis=(0,1) ) + 1e-5 ) + src_std = np.sqrt ( np.var(src, axis=(0,1) ) + 1e-5 ) + src_rct_std = np.sqrt ( np.var(src_rct, axis=(0,1) ) + 1e-5 ) + src_lct_std = np.sqrt ( np.var(src_lct, axis=(0,1) ) + 1e-5 ) + src_mkl_std = np.sqrt ( np.var(src_mkl, axis=(0,1) ) + 1e-5 ) + src_idt_std = np.sqrt ( np.var(src_idt, axis=(0,1) ) + 1e-5 ) + src_sot_std = np.sqrt ( np.var(src_sot, axis=(0,1) ) + 1e-5 ) + + def_mean_sum = np.sum( np.square(src_mean-dst_mean) ) + rct_mean_sum = np.sum( np.square(src_rct_mean-dst_mean) ) + lct_mean_sum = np.sum( np.square(src_lct_mean-dst_mean) ) + mkl_mean_sum = np.sum( np.square(src_mkl_mean-dst_mean) ) + idt_mean_sum = np.sum( np.square(src_idt_mean-dst_mean) ) + sot_mean_sum = np.sum( np.square(src_sot_mean-dst_mean) ) + + def_std_sum = np.sum( np.square(src_std-dst_std) ) + rct_std_sum = np.sum( np.square(src_rct_std-dst_std) ) + lct_std_sum = np.sum( np.square(src_lct_std-dst_std) ) + mkl_std_sum = np.sum( np.square(src_mkl_std-dst_std) ) + idt_std_sum = np.sum( np.square(src_idt_std-dst_std) ) + sot_std_sum = np.sum( np.square(src_sot_std-dst_std) ) + + lowests.append([ def_mean_sum+def_std_sum, + rct_mean_sum+rct_std_sum, + lct_mean_sum+lct_std_sum, + mkl_mean_sum+mkl_std_sum, + idt_mean_sum+idt_std_sum, + sot_mean_sum+sot_std_sum + ]) + + #cv2.imshow("", src_rct ) + #cv2.waitKey(0) + + + np.mean(np.array(lowests), 0) + + import code + code.interact(local=dict(globals(), **locals())) + + img = cv2.imread(r'D:\DevelopPython\test\ct_src.jpg').astype(np.float32) / 255.0 + img2 = cv2.imread(r'D:\DevelopPython\test\ct_trg.jpg').astype(np.float32) / 255.0 + #img = img[...,::-1] + #img2 = img2[...,::-1] + def clr(source,target): + rgb_s = source.reshape ( (-1,3) ) + rgb_t = target.reshape ( (-1,3) ) + + mean_s = np.mean(rgb_s, 0) + mean_t = np.mean(rgb_t, 0) + + cov_s = np.cov( rgb_s.T ) + cov_t = np.cov( rgb_t.T ) + + U_s, A_s, _ = np.linalg.svd(cov_s) + U_t, A_t, _ = np.linalg.svd(cov_t) + + rgbh_s = np.concatenate ( [rgb_s, np.ones( (rgb_s.shape[0], 1), dtype=np.float32)], -1 ) + T_t = np.eye(4) + T_t[0:3,3] = mean_t + T_s = np.eye(4) + T_s[0:3,3] = -mean_s + + R_t = scipy.linalg.block_diag(U_t, 1) + R_s = scipy.linalg.block_diag(np.linalg.inv(U_s), 1) + + S_t = scipy.linalg.block_diag ( np.diag( A_t ** (0.5) ), 1) + S_s = scipy.linalg.block_diag ( np.diag( A_s ** (-0.5) ), 1) + + rgbh_e = np.dot(np.dot(np.dot(np.dot(np.dot(np.dot(T_t, R_t),S_t),S_s),R_s),T_s),rgbh_s.T) + + result = rgbh_e.T[...,0:3].reshape(source.shape ) + result = np.clip(result, 0, 1) + return result + import code + code.interact(local=dict(globals(), **locals())) + + c2 = clr(img,img2) + + from core.imagelib import color_transfer_mkl + c = color_transfer_mkl(img,img2) + + + cv2.imshow("", (img*255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow("", (img2*255).astype(np.uint8) ) + cv2.waitKey(0) + while True: + cv2.imshow("", (c*255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow("", (c2*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.CPU() ) + tf = nn.tf + + """ + def tf_channel_histogram (input, bins, data_range): + range_min, range_max = data_range + + bin_range = (range_max-range_min) / (bins-1) + reduce_axes = [*range(input.shape.ndims)][1:] + ones_mask = tf.ones_like(input) + zero_mask = tf.zeros_like(input) + + x = input + x += bin_range + + output = [] + + for i in range(bins, 0, -1): + cond = tf.greater_equal(x, i*bin_range ) + x_ones = tf.where (cond, ones_mask, zero_mask ) + x_zeros = tf.where (cond, zero_mask, ones_mask ) + x = x * x_zeros + output.append ( tf.expand_dims(tf.reduce_sum (x_ones, axis=reduce_axes ), -1) ) + + return tf.concat(output[::-1],-1) + """ + def channel_hist_match(source, template, hist_match_threshold=255, mask=None): + masked_source = source + masked_template = template + + if mask is not None: + masked_source = source * mask + masked_template = template * mask + + oldshape = source.shape + source = source.ravel() + template = template.ravel() + masked_source = masked_source.ravel() + masked_template = masked_template.ravel() + s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, + return_counts=True) + t_values, t_counts = np.unique(template, return_counts=True) + + import code + code.interact(local=dict(globals(), **locals())) + s_quantiles = np.cumsum(s_counts).astype(np.float64) + s_quantiles = hist_match_threshold * s_quantiles / s_quantiles[-1] + t_quantiles = np.cumsum(t_counts).astype(np.float64) + t_quantiles = 255 * t_quantiles / t_quantiles[-1] + interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) + + return interp_t_values[bin_idx].reshape(oldshape) + + def tf_channel_histogram (input, bins, data_range): + range_min, range_max = data_range + bin_range = (range_max-range_min) / (bins-1) + reduce_axes = [*range(input.shape.ndims)][1:] + x = input + x += bin_range/2 + output = [] + for i in range(bins-1, -1, -1): + y = x - (i*bin_range) + ones_mask = tf.sign( tf.nn.relu(y) ) + x = x * (1.0 - ones_mask) + output.append ( tf.expand_dims(tf.reduce_sum (ones_mask, axis=reduce_axes ), -1) ) + return tf.concat(output[::-1],-1) + + def tf_histogram(input, bins=256, data_range=(0,1.0)): + return tf.concat ( [tf.expand_dims( tf_channel_histogram( input[...,i], bins=bins, data_range=data_range ), -1 ) for i in range(input.shape[-1])], -1 ) + + img = cv2.imread(r'D:\DevelopPython\test\00000.png')#.astype(np.float32) / 255.0 + img2 = cv2.imread(r'D:\DevelopPython\test\00004.jpg')#.astype(np.float32) + + x = channel_hist_match(img,img2) + import code + code.interact(local=dict(globals(), **locals())) + nph = np.histogram(img[...,0], bins=256, range=(0,1.0) ) + + inp_t = tf.placeholder( tf.float32, (None,None,None) ) + hist_t = tf_channel_histogram(inp_t, bins=256, data_range=(0,1.0) ) + #hist_t = tf_histogram(inp_t, bins=256, data_range=(0,1.0) ) + + tfh = nn.tf_sess.run (hist_t , feed_dict={inp_t: img[None,...,0] } ) + + import code + code.interact(local=dict(globals(), **locals())) + + from core.leras import nn + nn.initialize_main_env() + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + + def tf_suppress_half_mean(t, eps=0.00001): + if t.shape.ndims != 1: + raise ValueError("tf_suppress_half_mean: t rank must be 1") + t_mean_eps = tf.reduce_mean(t) - eps + q = tf.clip_by_value(t, t_mean_eps, tf.reduce_max(t) ) + q = tf.clip_by_value(q-t_mean_eps, 0, eps) + q = q * (t/eps) + return q + + inp = tf.placeholder( tf.float32, (None,) ) + res = tf_suppress_half_mean(inp) + + x = nn.tf_sess.run (res , feed_dict={inp: np.array([1,2,3,4]) } ) + print(x) + import code + code.interact(local=dict(globals(), **locals())) + + from core.leras import nn + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + + + img = cv2.imread ( r"D:\DevelopPython\test\images\96x96_0.png" ).astype(np.float32) / 255.0 + from facelib import FaceEnhancer + fe = FaceEnhancer() + img_enh = fe.enhance(img, preserve_size=True) + + while True: + cv2.imshow ("", (img*255).astype(np.uint8) ) + cv2.waitKey(0) + + cv2.imshow ("", (img_enh*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + def np_gen_ca(shape, dtype=np.float32, eps_std=0.05): + """ + Super fast implementation of Convolution Aware Initialization for 4D shapes + Convolution Aware Initialization https://arxiv.org/abs/1702.06295 + """ + if len(shape) != 4: + raise ValueError("only shape with rank 4 supported.") + + row, column, stack_size, filters_size = shape + + fan_in = stack_size * (row * column) + + kernel_shape = (row, column) + + kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape + + basis_size = np.prod(kernel_fft_shape) + if basis_size == 1: + x = np.random.normal( 0.0, eps_std, (filters_size, stack_size, basis_size) ).astype(dtype) + else: + nbb = stack_size // basis_size + 1 + + x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size)).astype(dtype) + + x = x + np.transpose(x, (0,1,3,2) ) * (1-np.eye(basis_size)) + + u, _, v = np.linalg.svd(x) + x = np.transpose(u, (0,1,3,2) ) + + x = np.reshape(x, (filters_size, -1, basis_size) ) + x = x[:,:stack_size,:] + + x = np.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) ) + + x = np.fft.irfft2( x, kernel_shape ) \ + + np.random.normal(0, eps_std, (filters_size,stack_size,)+kernel_shape).astype(dtype) + + x = x * np.sqrt( (2/fan_in) / np.var(x) ) + x = np.transpose( x, (2, 3, 1, 0) ) + return x + + from core.leras import nn + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + from tensorflow.python.ops import init_ops + + class CAInitializer (init_ops.Initializer): + def __init__(self, eps_std=0.05): + self.eps_std = eps_std + + def gen_ca_4d_func(self, dtype=np.float32): + + def func(shape): + """ + Super fast implementation of Convolution Aware Initialization for 4D shapes + Convolution Aware Initialization https://arxiv.org/abs/1702.06295 + """ + if len(shape) != 4: + raise ValueError("only shape with rank 4 supported.") + + row, column, stack_size, filters_size = shape + + fan_in = stack_size * (row * column) + + kernel_shape = (row, column) + + kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape + + basis_size = np.prod(kernel_fft_shape) + if basis_size == 1: + x = np.random.normal( 0.0, self.eps_std, (filters_size, stack_size, basis_size) ) + else: + nbb = stack_size // basis_size + 1 + x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size)) + x = x + np.transpose(x, (0,1,3,2) ) * (1-np.eye(basis_size)) + u, _, v = np.linalg.svd(x) + x = np.transpose(u, (0,1,3,2) ) + x = np.reshape(x, (filters_size, -1, basis_size) ) + x = x[:,:stack_size,:] + + x = np.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) ) + + x = np.fft.irfft2( x, kernel_shape ) \ + + np.random.normal(0, self.eps_std, (filters_size,stack_size,)+kernel_shape) + + x = x * np.sqrt( (2/fan_in) / np.var(x) ) + x = np.transpose( x, (2, 3, 1, 0) ) + return x.astype(dtype) + return func + + def __call__(self, shape, dtype=None, partition_info=None): + return tf.py_func( self.gen_ca_4d_func(dtype.as_numpy_dtype), [shape], dtype ) + + import code + code.interact(local=dict(globals(), **locals())) + + + op = CAInitializer()( (3,3,128,128), tf.float32 ) + tf_ca = nn.tf_sess.run(op) + + import code + code.interact(local=dict(globals(), **locals())) + + shape = (1,1,1024,1024) + + #t = time.time() + #np_ca = CAGenerateWeights(shape, np.float32, 'channels_last', eps_std=0.05) + #print(f"time = {time.time() -t}") + + t = time.time() + np_ca2 = np_gen_ca(shape, np.float32, eps_std=0.05) + print(f"time = {time.time() -t}") + + #input = tf.placeholder(tf.float32, (4,) ) + + + + #y = tf.py_func(my_func, [input], tf.float32) + + + #import code + #code.interact(local=dict(globals(), **locals())) + + #with tf.device("/GPU:0"): + + + + + + from core.leras import nn + nn.initialize( device_config=nn.DeviceConfig.WorstGPU() ) + tf = nn.tf + + shape = (3,3,64,128) + + fan_in = shape[-2] * np.prod( shape[:-2] ) + fan_out = shape[-1] * np.prod( shape[:-2] ) + variance = 2 / fan_in + + row, column, in_ch, out_ch = shape + + transpose_dimensions = (2, 3, 1, 0) + kernel_shape = (row, column) + correct_ifft = np.fft.irfft2 + correct_fft = np.fft.rfft2 + + eps_std = 0.05 + floatx = np.float32 + + + """ + + a = np.array ( [ [ [1,2], [3,4] ], + [ [1,2], [3,4] ], + [ [1,2], [3,4] ], + [ [1,2], [3,4] ] + ]) + import code + code.interact(local=dict(globals(), **locals())) + + x, = nn.tf_sess.run( [ tf.spectral.rfft2d( tf.ones( (3,3) ) ) ] ) + + a = np.array( [ np.complex64(v.real) for v in np.ndarray.flatten(x) ] ).reshape (x.shape) + a_r = np.array( [v.real for v in np.ndarray.flatten(x) ] ).reshape (x.shape) + + a_p = tf.placeholder ( tf.complex64, (3,2) ) + + y, = nn.tf_sess.run( [ tf.spectral.irfft2d(a_p) ], feed_dict={a_p:a} ) + + import code + code.interact(local=dict(globals(), **locals())) + """ + """ + import code + code.interact(local=dict(globals(), **locals())) + + for i in range(nbb): + a = tf.random.normal( (size, size), 0.0, 1.0, dtype=dtype ) + a = a + tf.transpose(a) - tf.linalg.diag(tf.linalg.diag_part(a)) + + + + + + + + s, u, v = tf.linalg.svd(a) + + import code + code.interact(local=dict(globals(), **locals())) + + li.append ( tf.transpose(u) ) + + return tf.concat(li, 0)[:filters, :] + """ + + + from tensorflow.python.ops import init_ops + + class ConvolutionAwareInitializer(init_ops.Initializer): + """ + Tensorflow initializer implementation of Convolution Aware Initialization + https://arxiv.org/pdf/1702.06295.pdf + """ + def __init__(self, eps_std=0.05): + self.eps_std = eps_std + + def __call__(self, shape, dtype=tf.float32): + if len(shape) != 4: + raise ValueError("only shape with rank 4 supported.") + + row, column, stack_size, filters_size = shape + + fan_in = stack_size * (row * column) + + kernel_shape = (row, column) + + kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape + + basis_size = np.prod(kernel_fft_shape) + if basis_size == 1: + x = tf.random.normal( (filters_size, stack_size, basis_size), 0.0, self.eps_std, dtype=dtype ) + else: + nbb = stack_size // basis_size + 1 + + x = tf.random.normal( (filters_size, nbb, basis_size, basis_size), 0.0, 1.0, dtype=dtype ) + x = x + tf.transpose(x, (0,1,3,2) ) - tf.linalg.diag(tf.linalg.diag_part(x)) + s, u, v = tf.linalg.svd(x) + x = tf.transpose(u, (0,1,3,2) ) + + x = tf.reshape(x, (filters_size, -1, basis_size) ) + x = x[:,:stack_size,:] + + x = tf.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) ) + + x = tf.spectral.irfft2d( tf.complex(x, tf.zeros_like(x) ), kernel_shape ) \ + + tf.random.normal( (filters_size,stack_size,)+kernel_shape, 0, self.eps_std) + + x_variance = tf.reduce_mean( tf.square(x - tf.reduce_mean(x, keepdims=True) ) ) + + x = x * tf.sqrt( (2/fan_in) / x_variance ) + x = tf.transpose( x, (2, 3, 1, 0) ) + return x + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + #return array_ops.ones(shape, dtype) + + + np_ca = CAGenerateWeights(shape, np.float32, 'channels_last', eps_std=0.05) + + + #input = tf.placeholder(tf.float32, (4,) ) + + def my_func(): + pass + + #y = tf.py_func(my_func, [input], tf.float32) + + + #import code + #code.interact(local=dict(globals(), **locals())) + + #with tf.device("/GPU:0"): + #tf_op = ConvolutionAwareInitializer()(shape) + + t = time.time() + tf_ca = nn.tf_sess.run ([tf_op,tf_op2,tf_op3,tf_op4]) + print(f"time {time.time() - t}") + + import code + code.interact(local=dict(globals(), **locals())) + + + #===================================== + + + #====================================== + + + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + + + + + + + + + + + print("importing") + """ + from core.leras import nn + nn.import_tf( device_config=nn.device.Config(force_gpu_idx=1) ) + tf = nn.tf + filepath = r'D:\DevelopPython\test\00000.png' + + img = cv2.imread(filepath).astype(np.float32) / 255.0 + h,w,c = img.shape + + inp = tf.placeholder(tf.float32, (1, h,w,c) ) + x = nn.gaussian_blur()(inp) + + q = nn.style_loss()(x, inp) + import code + code.interact(local=dict(globals(), **locals())) + + a = nn.tf_sess.run (x, feed_dict={inp:img[None,...]} ) + + cv2.imshow("", (a[0]*255).astype(np.uint8)) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + """ + + from core.leras import nn + nn.import_tf( device_config=nn.device.Config(force_gpu_idx=1) ) + + tf = nn.tf + tf_sess = nn.tf_sess + + import code + code.interact(local=dict(globals(), **locals())) + + class SubEncoder(nn.ModelBase): + def on_build(self): + self.conv1 = nn.Conv2D( 3, 3, kernel_size=3, padding='SAME', wscale_gain=np.sqrt(2) ) + + def call(self, x): + x = self.conv1(x) + return x + + class Encoder(nn.ModelBase): + def on_build(self): + self.conv1 = SubEncoder() + self.dense1 = nn.Dense( 64*64*3, 1 ) + + def call(self, x): + x = self.conv1(x) + x = nn.flatten()(x) + x = self.dense1(x) + return x + + with tf.device('/CPU:0'): + + inp = tf.placeholder(tf.float32, (None, 64,64,3) ) + real = tf.placeholder(tf.float32, (None, 1) ) + + encoder = Encoder(name='encoder') + encoder.init_weights() + #encoder.save_weights(r"D:\enc.h5") + #encoder.load_weights(r"D:\enc.h5") + + with tf.device('/GPU:0'): + x = encoder(inp) + + loss = tf.reduce_sum(tf.square(x - real)) + + enc_opt = nn.RMSprop(name='enc_opt') + + with tf.device('/GPU:0'): + grads_vars = nn.gradients(loss, encoder.get_weights() ) + + apply_op = enc_opt.get_updates (grads_vars ) + + enc_opt.init_weights() + + inp1 = np.random.uniform (size=(1,64,64,3)) + real1 = np.random.uniform (size=(1,1)) + + l, _ = nn.tf_sess.run ( [loss, apply_op], feed_dict={inp:inp1, real:real1}) + + #print ( tf_get_value(W) ) + + import code + code.interact(local=dict(globals(), **locals())) + + + src_real = np.random.uniform ( size=(1,64,64,3) ).astype(np.float32) + dst_real = np.random.uniform ( size=(1,64,64,3) ).astype(np.float32) + + with tf.device('/CPU:0'): + src_inp = KL.Input( (64,64,3) ) + dst_inp = KL.Input( (64,64,3) ) + + x = src_inp + x = KL.Conv2D(3, kernel_size=3, padding='same')(x) + enc = KM.Model(src_inp, x) + + code_inp = KL.Input( enc.outputs[0].shape[1:] ) + x = code_inp + x = KL.Conv2D(3, kernel_size=3, padding='same')(x) + dec_src = KM.Model(code_inp, x) + + code_inp = KL.Input( enc.outputs[0].shape[1:] ) + x = code_inp + x = KL.Conv2D(3, kernel_size=3, padding='same')(x) + dec_dst = KM.Model(code_inp, x) + + """ + with tf.device('/GPU:0'): + with tf.GradientTape() as src_tape: + t = time.time() + + code = enc(src_inp) + pred_src = dec_src(code) + + print(f"src took: {time.time()-t}") + src_loss = tf.math.reduce_mean ( tf.math.abs(pred_src-src_real) ) + """ + with tf.device('/GPU:0'): + with tf.GradientTape() as dst_tape: + t = time.time() + + code = enc(dst_inp) + pred_dst = dec_dst(code) + + print(f"dst took: {time.time()-t}") + dst_loss = tf.math.reduce_mean ( tf.math.abs(pred_dst-dst_real) ) + + + #grad1 = src_tape.gradient(src_loss, enc.trainable_variables) + grad2 = dst_tape.gradient(dst_loss, enc.trainable_variables) + + """ + grad = [] + for g1,g2 in zip(grad1,grad2): + g = tf.concat( [tf.expand_dims(g1,0), tf.expand_dims(g2,0)], axis=0) + g = tf.reduce_mean(g, 0) + grad.append(g) + """ + + apply_op = optimizer.apply_gradients(zip(grad2, enc.trainable_variables)) + + tf_sess.run ( tf.global_variables_initializer() ) + + dl, _ = nn.tf_sess.run ( [dst_loss, apply_op], feed_dict={dst_inp: src_real}) + print(dl) + + + + + + import code + code.interact(local=dict(globals(), **locals())) + + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(force_gpu_idx=1) ), locals(), globals() )# + + from facelib import FaceEnhancer + + filepath = r'D:\DevelopPython\test\00000.jpg' + img = cv2.imread(filepath).astype(np.float32) / 255.0 + + fe = FaceEnhancer() + final_img = fe.enhance( img ) + + cv2.imshow("", (final_img*255).astype(np.uint8) ) + cv2.waitKey(0) + + + + + + + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(force_gpu_idx=1, use_fp16=False) ), locals(), globals() )# + + filepath = r'D:\DevelopPython\test\ct_00003.jpg' + img = cv2.imread(filepath).astype(np.float32) / 255.0 + + inp = Input( (None,None,3) ) + x = AveragePooling2D(pool_size=2, strides=2, padding='same')(inp) + model = keras.models.Model (inp, x) + x, = K.function([inp],[x]) ( [img[None,...] ]) + + cv2.imshow("", (x[0]*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config() ), locals(), globals() ) + + + batch_size = 1024 + + i_t = Input ( (256,256,3) ) + j_t = Input ( (256,256,3) ) + + outputs = [] + #for i in range(batch_size): + outputs += [ K.sum( K.abs(i_t-j_t), axis=[1,2,3] ) ] + + func = K.function ( [i_t,j_t], outputs) + + + k1 = np.random.random ( size=(batch_size,256,256,3) ) + k2 = np.random.random ( size=(1,256,256,3) ) + + t = time.time() + result = func ([k1,k2]) + print (f"time took: {time.time()-t}") + t = time.time() + result = func ([k1,k2]) + print (f"time took: {time.time()-t}") + import code + code.interact(local=dict(globals(), **locals())) + + + + + a = [] + + for filename in io.progress_bar_generator(image_paths, ""): + a.append ( cv2_imread(filename) ) + + + + import code + code.interact(local=dict(globals(), **locals())) + + cap = cv2.VideoCapture(r'D:\DevelopPython\test\test1.mp4') + + import code + code.interact(local=dict(globals(), **locals())) + + libdll = CDLL(r"D:\DevelopPython\Projects\TestCPPDLL\x64\Release\TestCPPDLL.dll") + libdll.ST2DFloat.argtypes = ( \ + c_int, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p + ) + + dflimg = DFLJPG.load ( r'D:\DevelopPython\test\dflimg_1.jpg') + real_lmrks = dflimg.get_landmarks() + lmrks = dflimg.get_source_landmarks()[17:].astype(np.float32) + + pts1_bytes = lmrks.reshape ( np.prod(lmrks.shape) ).tobytes() + pts2_bytes = landmarks_2D.reshape ( np.prod(landmarks_2D.shape) ).tobytes() + rot_buf = create_string_buffer( 4*4 ) + trans_buf = create_string_buffer( 4*2 ) + scale_buf = create_string_buffer( 4*1 ) + + libdll.ST2DFloat ( len(lmrks), pts1_bytes, pts2_bytes, rot_buf, trans_buf, scale_buf ) + rot = np.frombuffer(rot_buf, dtype=np.float32) + trans = np.frombuffer(trans_buf, dtype=np.float32) + scale = np.frombuffer(scale_buf, dtype=np.float32) + + + mat = np.concatenate ([ rot.reshape ( (2,2) ), + trans.reshape ( (2,1) )*(1/scale) ], axis=-1 ) + new_lmrks = LandmarksProcessor.transform_points(lmrks, mat) + + #import code + #code.interact(local=dict(globals(), **locals())) + new_lmrks *= 3 + new_lmrks += [127,127] + + + img = np.zeros ( (256,256,3), dtype=np.uint8 ) + + for pt in new_lmrks: + x,y = pt + cv2.circle(img, (x, y), 1, (255,0,0) ) + + for pt in real_lmrks: + x,y = pt.astype(np.int) + cv2.circle(img, (x, y), 1, (0,0,255) ) + + cv2.imshow ("", img) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + img_src = cv2.imread(r'D:\DevelopPython\test\ct_trg1.jpg')/255.0 + + img_trg = cv2.imread(r'D:\DevelopPython\test\ct_src1.jpg')/255.0 + + + screen1 = (color_transfer_mix(img_src, img_trg)*255.0).astype(np.uint8) + screen2 = (color_transfer_mix2(img_src, img_trg)*255.0).astype(np.uint8) + screen = np.concatenate([screen1,screen2], axis=1) + cv2.imshow ("", screen ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + img_src = cv2.imread(r'D:\DevelopPython\test\ct_trg1.jpg') + img_src_float = img_src.astype(np.float32) + img_src_shape = img_src.shape + + img_src_lab = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB) + img_src_l = img_src_lab.copy() + img_src_l[...,0] = (np.ones_like (img_src_l[...,0])*100).astype(np.uint8) + img_src_l = cv2.cvtColor(img_src_l, cv2.COLOR_LAB2BGR) + + img_trg = cv2.imread(r'D:\DevelopPython\test\ct_src1.jpg') + img_trg_float = img_trg.astype(np.float32) + img_trg_shape = img_trg.shape + + + img_trg_lab = cv2.cvtColor(img_trg, cv2.COLOR_BGR2LAB) + img_trg_l = img_trg_lab.copy() + img_trg_l[...,0] = (np.ones_like (img_trg_l[...,0])*100).astype(np.uint8) + img_trg_l = cv2.cvtColor(img_trg_l, cv2.COLOR_LAB2BGR) + + + t = time.time() + + + #img_rct_light = imagelib.color_transfer_sot( img_src_lab[...,0:1].astype(np.float32), img_trg_lab[...,0:1].astype(np.float32) ) + + img_src_light = img_src_lab[...,0:1] + img_trg_light = img_trg_lab[...,0:1] + + +# +# + #img_rct_light = img_src_light * 1.0#( img_trg_light.mean()/img_src_light.mean() ) + #img_rct_light = np.clip (img_rct_light, 0, 100).astype(np.uint8) + + img_rct_light = imagelib.linear_color_transfer( img_src_lab[...,0:1].astype(np.float32)/255.0, + img_trg_lab[...,0:1].astype(np.float32)/255.0 )[...,0:1] *255.0 + + + img_rct_light = np.clip (img_rct_light, 0, 255).astype(np.uint8) + + + + + img_rct = imagelib.color_transfer_sot( img_src_l.astype(np.float32), img_trg_l.astype(np.float32) ) + img_rct = np.clip(img_rct, 0, 255) + img_rct_l = img_rct.astype(np.uint8) + + img_rct = cv2.cvtColor(img_rct_l, cv2.COLOR_BGR2LAB) + img_rct[...,0] = img_rct_light[...,0]#img_src_lab[...,0] + img_rct = cv2.cvtColor(img_rct, cv2.COLOR_LAB2BGR) + + #import code + #code.interact(local=dict(globals(), **locals())) + + screen1 = np.concatenate ([img_src, img_src_l, img_trg, img_trg_l, img_rct_l, + ], axis=1) + screen2 = np.concatenate ([ + np.repeat(img_src_lab[...,0:1], 3, -1), + np.repeat(img_trg_lab[...,0:1], 3, -1), + np.repeat(img_rct_light, 3, -1), + + img_rct,img_rct], axis=1) + screen = np.concatenate([screen1, screen2], axis=0 ) + cv2.imshow ("", screen ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + img_src = cv2.imread(r'D:\DevelopPython\test\ct_src.jpg') + img_src_shape = img_src.shape + + img = cv2.cvtColor(img_src, cv2.COLOR_BGR2LAB) + img[...,0] = (np.ones_like (img[...,0])*100).astype(np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_LAB2BGR) + + + #screen = np.concatenate ([img_src, img_trg, img_rct], axis=1) + cv2.imshow ("", img ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + image_paths = pathex.get_image_paths(r"E:\FakeFaceVideoSources\Datasets\CelebA\aligned_def\aligned") + + out_path = Path("E:\FakeFaceVideoSources\Datasets\CelebA\aligned_def\aligned_out") + + + while True: + + filename1 = image_paths[ np.random.randint(len(image_paths)) ] + filename2 = image_paths[ np.random.randint(len(image_paths)) ] + + + img1 = cv2_imread(filename1).astype(np.float32) + img1_mask = LandmarksProcessor.get_image_hull_mask (img1.shape , DFLJPG.load (filename1).get_landmarks() ) + + img2 = cv2_imread(filename2).astype(np.float32) + img2_mask = LandmarksProcessor.get_image_hull_mask (img2.shape , DFLJPG.load (filename2).get_landmarks() ) + + mask = img1_mask*img2_mask + + img1_masked = np.clip(img1*mask, 0,255) + img2_masked = np.clip(img2*mask, 0,255) + + + + img1_sot = imagelib.color_transfer_sot (img1_masked, img2_masked) + img1_sot = np.clip(img1_sot, 0, 255) + + l,t,w,h = cv2.boundingRect( (mask*255).astype(np.uint8) ) + + img_ct = cv2.seamlessClone( img1_sot.astype(np.uint8), img2.astype(np.uint8), (mask*255).astype(np.uint8), (int(l+w/2),int(t+h/2)) , cv2.NORMAL_CLONE ) + + #img_ct = out_img.astype(dtype=np.float32) / 255.0 + + #img_ct = imagelib.color_transfer_sot (img1_masked, img2_masked) + #img_ct = imagelib.linear_color_transfer ( img1_masked/255.0, img2_masked/255.0) * 255.0 + #img_ct = np.clip(img_ct, 0, 255) + + #img1_mask_blur = cv2.blur(img1_mask, (21,21) )[...,None] + + #img_ct = img1*(1-img1_mask)+img_ct*img1_mask + + screen = np.concatenate ([img1, img2, img1_sot, img_ct], axis=1) + + + cv2.imshow("", screen.astype(np.uint8) ) + cv2.waitKey(0) + + + import code + code.interact(local=dict(globals(), **locals())) + + img_src = cv2.imread(r'D:\DevelopPython\test\ct_src2.jpg') + img_src_float = img_src.astype(np.float32) + img_src_shape = img_src.shape + + img_trg = cv2.imread(r'D:\DevelopPython\test\ct_trg2.jpg') + img_trg_float = img_trg.astype(np.float32) + img_trg_shape = img_trg.shape + + + t = time.time() + + img_rct = imagelib.linear_color_transfer( (img_src/255.0).astype(np.float32), (img_trg/255.0).astype(np.float32) ) * 255.0 + img_rct = np.clip(img_rct, 0, 255) + img_rct = img_rct.astype(np.uint8) + + screen = np.concatenate ([img_src, img_trg, img_rct], axis=1) + cv2.imshow ("", screen ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + #=============================================================================== + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(cpu_only=True) ), locals(), globals() ) + + import torch + import torch.nn as nn + import torch.nn.functional as F + + + import face_alignment + fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,enable_cuda=False,enable_cudnn=False,use_cnn_face_detector=False).face_alignemnt_net + fa.eval() + + + def ConvBlock(out_planes, input, srctorch): + in_planes = K.int_shape(input)[-1] + x = input + x = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d(srctorch.bn1) )(x) + x = ReLU() (x) + x = out1 = Conv2D( int(out_planes/2), kernel_size=3, strides=1, padding='valid', use_bias = False, weights=t2kw_conv2d(srctorch.conv1) ) (ZeroPadding2D(1)(x)) + + x = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d(srctorch.bn2) )(x) + x = ReLU() (x) + x = out2 = Conv2D( int(out_planes/4), kernel_size=3, strides=1, padding='valid', use_bias = False, weights=t2kw_conv2d(srctorch.conv2) ) (ZeroPadding2D(1)(x)) + + x = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d(srctorch.bn3) )(x) + x = ReLU() (x) + x = out3 = Conv2D( int(out_planes/4), kernel_size=3, strides=1, padding='valid', use_bias = False, weights=t2kw_conv2d(srctorch.conv3) ) (ZeroPadding2D(1)(x)) + + x = Concatenate()([out1, out2, out3]) + + if in_planes != out_planes: + downsample = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d(srctorch.downsample[0]) )(input) + downsample = ReLU() (downsample) + downsample = Conv2D( out_planes, kernel_size=1, strides=1, padding='valid', use_bias = False, weights=t2kw_conv2d(srctorch.downsample[2]) ) (downsample) + x = Add ()([x, downsample]) + else: + x = Add ()([x, input]) + + + return x + + def HourGlass (depth, input, srctorch): + up1 = ConvBlock(256, input, srctorch._modules['b1_%d' % (depth)]) + + low1 = AveragePooling2D (pool_size=2, strides=2, padding='valid' )(input) + low1 = ConvBlock (256, low1, srctorch._modules['b2_%d' % (depth)]) + + if depth > 1: + low2 = HourGlass (depth-1, low1, srctorch) + else: + low2 = ConvBlock(256, low1, srctorch._modules['b2_plus_%d' % (depth)]) + + low3 = ConvBlock(256, low2, srctorch._modules['b3_%d' % (depth)]) + + up2 = UpSampling2D(size=2) (low3) + return Add() ( [up1, up2] ) + + FAN_Input = Input ( (256, 256, 3) ) + + x = FAN_Input + + x = Conv2D (64, kernel_size=7, strides=2, padding='valid', weights=t2kw_conv2d(fa.conv1))(ZeroPadding2D(3)(x)) + x = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d(fa.bn1))(x) + x = ReLU()(x) + + x = ConvBlock (128, x, fa.conv2) + x = AveragePooling2D (pool_size=2, strides=2, padding='valid') (x) + x = ConvBlock (128, x, fa.conv3) + x = ConvBlock (256, x, fa.conv4) + + outputs = [] + previous = x + for i in range(4): + ll = HourGlass (4, previous, fa._modules['m%d' % (i) ]) + ll = ConvBlock (256, ll, fa._modules['top_m_%d' % (i)]) + + ll = Conv2D(256, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d( fa._modules['conv_last%d' % (i)] ) ) (ll) + ll = BatchNormalization(momentum=0.1, epsilon=1e-05, weights=t2kw_bn2d( fa._modules['bn_end%d' % (i)] ) )(ll) + ll = ReLU() (ll) + + tmp_out = Conv2D(68, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d( fa._modules['l%d' % (i)] ) ) (ll) + outputs.append(tmp_out) + + if i < 4 - 1: + ll = Conv2D(256, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d( fa._modules['bl%d' % (i)] ) ) (ll) + previous = Add() ( [previous, ll, KL.Conv2D(256, kernel_size=1, strides=1, padding='valid', weights=t2kw_conv2d( fa._modules['al%d' % (i)] ) ) (tmp_out) ] ) + + + + rnd_data = np.random.randint (256, size=(1,256,256,3) ).astype(np.float32) + + with torch.no_grad(): + fa_out_tensor = fa( torch.autograd.Variable( torch.from_numpy(rnd_data.transpose(0,3,1,2) ) ) )[-1].data.cpu() + fa_out = fa_out_tensor.numpy() + + FAN_model = Model(FAN_Input, outputs[-1] ) + FAN_model.save_weights (r"D:\DevelopPython\test\2DFAN-4.h5") + FAN_model_func = K.function (FAN_model.inputs, FAN_model.outputs) + + m_out, = FAN_model_func([ rnd_data ]) + + m_out = m_out.transpose(0,3,1,2) + + diff = np.sum(np.abs(np.ndarray.flatten(fa_out)-np.ndarray.flatten(m_out))) + print (f"====== diff {diff} =======") + + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(cpu_only=True) ), locals(), globals() ) + + import torch + import torch.nn as nn + import torch.nn.functional as F + + + + + def conv_bn(inp, oup, kernel, stride, padding=1): + return nn.Sequential( + nn.Conv2d(inp, oup, kernel, stride, padding, bias=False), + nn.BatchNorm2d(oup), + nn.ReLU(inplace=True)) + + class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, use_res_connect, expand_ratio=6): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + self.use_res_connect = use_res_connect + + self.conv = nn.Sequential( + nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), + nn.BatchNorm2d(inp * expand_ratio), + nn.ReLU(inplace=True), + nn.Conv2d( + inp * expand_ratio, + inp * expand_ratio, + 3, + stride, + 1, + groups=inp * expand_ratio, + bias=False), + nn.BatchNorm2d(inp * expand_ratio), + nn.ReLU(inplace=True), + nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + + class PFLDInference(nn.Module): + def __init__(self): + super(PFLDInference, self).__init__() + + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d( + 64, 64, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + + self.conv3_1 = InvertedResidual(64, 64, 2, False, 2) + + self.block3_2 = InvertedResidual(64, 64, 1, True, 2) + self.block3_3 = InvertedResidual(64, 64, 1, True, 2) + self.block3_4 = InvertedResidual(64, 64, 1, True, 2) + self.block3_5 = InvertedResidual(64, 64, 1, True, 2) + + self.conv4_1 = InvertedResidual(64, 128, 2, False, 2) + + self.conv5_1 = InvertedResidual(128, 128, 1, False, 4) + self.block5_2 = InvertedResidual(128, 128, 1, True, 4) + self.block5_3 = InvertedResidual(128, 128, 1, True, 4) + self.block5_4 = InvertedResidual(128, 128, 1, True, 4) + self.block5_5 = InvertedResidual(128, 128, 1, True, 4) + self.block5_6 = InvertedResidual(128, 128, 1, True, 4) + + self.conv6_1 = InvertedResidual(128, 16, 1, False, 2) # [16, 14, 14] + + self.conv7 = conv_bn(16, 32, 3, 2) # [32, 7, 7] + self.conv8 = nn.Conv2d(32, 128, 7, 1, 0) # [128, 1, 1] + self.bn8 = nn.BatchNorm2d(128) + + self.avg_pool1 = nn.AvgPool2d(14) + self.avg_pool2 = nn.AvgPool2d(7) + self.fc = nn.Linear(176, 196) + + def forward(self, x): # x: 3, 112, 112 + x = self.relu(self.bn1(self.conv1(x))) # [64, 56, 56] + x = self.relu(self.bn2(self.conv2(x))) # [64, 56, 56] + x = self.conv3_1(x) + x = self.block3_2(x) + x = self.block3_3(x) + x = self.block3_4(x) + out1 = self.block3_5(x) + + x = self.conv4_1(out1) + x = self.conv5_1(x) + x = self.block5_2(x) + x = self.block5_3(x) + x = self.block5_4(x) + x = self.block5_5(x) + x = self.block5_6(x) + x = self.conv6_1(x) + + x1 = self.avg_pool1(x) + x1 = x1.view(x1.size(0), -1) + + x = self.conv7(x) + x2 = self.avg_pool2(x) + x2 = x2.view(x2.size(0), -1) + + x3 = self.relu(self.conv8(x)) + x3 = x3.view(x1.size(0), -1) + + multi_scale = torch.cat([x1, x2, x3], 1) + landmarks = self.fc(multi_scale) + + return out1, landmarks + + mw = torch.load(r"D:\DevelopPython\test\PFLD.pth.tar", map_location=torch.device('cpu')) + mw = mw['plfd_backbone'] + + + """ + + class TorchBatchNorm2D(keras.engine.Layer): + def __init__(self, axis=-1, momentum=0.1, epsilon=1e-5, **kwargs): + super(TorchBatchNorm2D, self).__init__(**kwargs) + self.supports_masking = True + self.axis = axis + self.momentum = momentum + self.epsilon = epsilon + + def build(self, input_shape): + dim = input_shape[self.axis] + if dim is None: + raise ValueError('Axis ' + str(self.axis) + ' of ' + 'input tensor should have a defined dimension ' + 'but the layer received an input with shape ' + + str(input_shape) + '.') + shape = (dim,) + self.gamma = self.add_weight(shape=shape, name='gamma', initializer='ones', regularizer=None, constraint=None) + self.beta = self.add_weight(shape=shape, name='beta', initializer='zeros', regularizer=None, constraint=None) + self.moving_mean = self.add_weight(shape=shape, name='moving_mean', initializer='zeros', trainable=False) + self.moving_variance = self.add_weight(shape=shape, name='moving_variance', initializer='ones', trainable=False) + self.built = True + + def call(self, inputs, training=None): + input_shape = K.int_shape(inputs) + + broadcast_shape = [1] * len(input_shape) + broadcast_shape[self.axis] = input_shape[self.axis] + + reduction_axes = list(range(len(input_shape))) + del reduction_axes[self.axis] + + broadcast_mean = K.mean(inputs, reduction_axes, keepdims=True) + broadcast_variance = K.var(inputs, reduction_axes, keepdims=True) + broadcast_gamma = K.reshape(self.gamma, broadcast_shape) + broadcast_beta = K.reshape(self.beta, broadcast_shape) + + return (inputs - broadcast_mean) / ( K.sqrt(broadcast_variance + K.constant(self.epsilon, dtype=K.floatx() )) ) * broadcast_gamma + broadcast_beta + + def get_config(self): + config = { 'axis': self.axis, 'momentum': self.momentum, 'epsilon': self.epsilon } + base_config = super(TorchBatchNorm2D, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + + def InvertedResidual(out_dim, strides, use_res_connect, expand_ratio=6, name_prefix=""): + + def func(inp): + c = K.int_shape(inp)[-1] + x = inp + x = Conv2D (c*expand_ratio, kernel_size=1, strides=1, padding='valid', use_bias=False, name=name_prefix+'_conv0')(x) + x = TorchBatchNorm2D(name=name_prefix+'_conv1')(x) + x = ReLU()(x) + + x = DepthwiseConv2D ( kernel_size=3, strides=strides, padding='valid', use_bias=False, name=name_prefix+'_conv3')(ZeroPadding2D(1)(x)) + x = TorchBatchNorm2D(name=name_prefix+'_conv4')(x) + x = ReLU()(x) + + x = Conv2D (out_dim, kernel_size=1, strides=1, padding='valid', use_bias=False, name=name_prefix+'_conv6')(x) + x = TorchBatchNorm2D(name=name_prefix+'_conv7')(x) + + + if use_res_connect: + x = Add()([inp, x]) + + return x + + return func + + PFLD_Input = Input ( (112, 112, 3) ) + + x = PFLD_Input + + x = Conv2D (64, kernel_size=3, strides=2, padding='valid', use_bias=False, name='conv1')(ZeroPadding2D(1)(x)) + x = TorchBatchNorm2D(name='bn1')(x) + x = ReLU()(x) + + + x = Conv2D (64, kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv2')(ZeroPadding2D(1)(x)) + x = TorchBatchNorm2D(name='bn2')(x) + x = ReLU()(x) + + + + x = InvertedResidual(64, 2, False, 2, 'conv3_1')(x) + x = InvertedResidual(64, 1, True, 2, 'block3_2')(x) + x = InvertedResidual(64, 1, True, 2, 'block3_3')(x) + x = InvertedResidual(64, 1, True, 2, 'block3_4')(x) + x = InvertedResidual(64, 1, True, 2, 'block3_5')(x) + x = InvertedResidual(128, 2, False, 2, 'conv4_1')(x) + x = InvertedResidual(128, 1, False, 4, 'conv5_1')(x) + + x = InvertedResidual(128, 1, True, 4, 'block5_2')(x) + x = InvertedResidual(128, 1, True, 4, 'block5_3')(x) + x = InvertedResidual(128, 1, True, 4, 'block5_4')(x) + x = InvertedResidual(128, 1, True, 4, 'block5_5')(x) + x = InvertedResidual(128, 1, True, 4, 'block5_6')(x) + + x = InvertedResidual(16, 1, False, 2, 'conv6_1')(x) + + x1 = AveragePooling2D(14)(x) + x1 = x1 = Flatten()(x1) + + x = Conv2D (32, kernel_size=3, strides=2, padding='valid', use_bias=False, name='conv7_0')(ZeroPadding2D(1)(x)) + x = TorchBatchNorm2D(name='conv7_1')(x) + x = ReLU()(x) + + x2 = AveragePooling2D(7)(x) + x2 = x2 = Flatten()(x2) + + x3 = Conv2D (128, kernel_size=7, strides=1, padding='valid', name='conv8')(x) + + x3 = ReLU()(x3) + x3 = Flatten()(x3) + + x = Concatenate(axis=-1)([x1,x2,x3]) + x = Dense(196, name='fc')(x) + + PFLD_model = Model(PFLD_Input, x ) + + try: + PFLD_model.get_layer('conv1').set_weights ( tdict2kw_conv2d (mw['conv1.weight']) ) + PFLD_model.get_layer('bn1').set_weights ( tdict2kw_bn2d (mw, 'bn1') ) + PFLD_model.get_layer('conv2').set_weights ( tdict2kw_conv2d (mw['conv2.weight']) ) + PFLD_model.get_layer('bn2').set_weights ( tdict2kw_bn2d (mw, 'bn2') ) + + for block_name in ['conv3_1', 'block3_2', 'block3_3', 'block3_4','block3_5','conv4_1','conv5_1','block5_2', \ + 'block5_3','block5_4','block5_5','block5_6','conv6_1']: + PFLD_model.get_layer(block_name+'_conv0').set_weights ( tdict2kw_conv2d (mw[block_name+'.conv.0.weight']) ) + PFLD_model.get_layer(block_name+'_conv1').set_weights ( tdict2kw_bn2d (mw, block_name+'.conv.1') ) + PFLD_model.get_layer(block_name+'_conv3').set_weights ( tdict2kw_depconv2d (mw[block_name+'.conv.3.weight']) ) + PFLD_model.get_layer(block_name+'_conv4').set_weights ( tdict2kw_bn2d (mw, block_name+'.conv.4') ) + PFLD_model.get_layer(block_name+'_conv6').set_weights ( tdict2kw_conv2d (mw[block_name+'.conv.6.weight']) ) + PFLD_model.get_layer(block_name+'_conv7').set_weights ( tdict2kw_bn2d (mw, block_name+'.conv.7') ) + + PFLD_model.get_layer('conv7_0').set_weights ( tdict2kw_conv2d (mw['conv7.0.weight']) ) + PFLD_model.get_layer('conv7_1').set_weights ( tdict2kw_bn2d (mw, 'conv7.1') ) + PFLD_model.get_layer('conv8').set_weights ( tdict2kw_conv2d (mw['conv8.weight'], mw['conv8.bias']) ) + + PFLD_model.get_layer('fc').set_weights ( [ np.transpose(mw['fc.weight'].numpy(), [1,0] ), mw['fc.bias'] ] ) + except: + pass + + PFLD_model.save_weights (r"D:\DevelopPython\test\PFLD.h5") + PFLD_model_func = K.function (PFLD_model.inputs, PFLD_model.outputs) + + + q, = PFLD_model_func([ image[None,...] ]) + """ + + #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + #w = np.transpose(w, [0,2,3,1]) + #diff = np.sum(np.abs(np.ndarray.flatten(q)-np.ndarray.flatten(w))) + #print (f"====== diff {diff} =======") + #lmrks_98 = lmrks_98[:,::-1] + #lmrks_68 = LandmarksProcessor.convert_98_to_68 (lmrks_98) + + torchpfld = PFLDInference() + torchpfld.load_state_dict(mw) + torchpfld.eval() + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + image = cv2.resize (image, (112,112) ) + #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_torch = torch.autograd.Variable(torch.from_numpy( np.transpose(image[None,...], [0,3,1,2] ) ) ) + + with torch.no_grad(): + _, w = torchpfld( image_torch ) + w = w.cpu().numpy() + + + + lmrks_98 = w.reshape ( (-1,2) )*112 + for x, y in lmrks_98.astype(np.int): + cv2.circle(image, (x, y), 1, (0,1,0) ) + + #LandmarksProcessor.draw_landmarks (image, lmrks_68) + + #cv2.imshow ("1", image.astype(np.uint8) ) + cv2.imshow ("1", (image*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + """ + + + + + import code + code.interact(local=dict(globals(), **locals())) + + import mxnet as mx + sym = mx.symbol.load (r"D:\DevelopPython\test\al.json") + + module = mx.module.Module(symbol=sym, + data_names=['data'], + label_names=None, + context=mx.cpu(), + work_load_list=None) + save_dict = mx.nd.load(r"D:\DevelopPython\test\al.params") + + import code + code.interact(local=dict(globals(), **locals())) + + + def process(w,h, data ): + d = {} + cur_lc = 0 + all_lines = [] + for s, pts_loop_ar in data: + lines = [] + for pts, loop in pts_loop_ar: + pts_len = len(pts) + lines.append ( [ [ pts[i], pts[(i+1) % pts_len ] ] for i in range(pts_len - (0 if loop else 1) ) ] ) + lines = np.concatenate (lines) + + lc = lines.shape[0] + all_lines.append(lines) + d[s] = cur_lc, cur_lc+lc + cur_lc += lc + all_lines = np.concatenate (all_lines, 0) + + #calculate signed distance for all points and lines + line_count = all_lines.shape[0] + pts_count = w*h + + all_lines = np.repeat ( all_lines[None,...], pts_count, axis=0 ).reshape ( (pts_count*line_count,2,2) ) + + pts = np.empty( (h,w,line_count,2), dtype=np.float32 ) + pts[...,1] = np.arange(h)[:,None,None] + pts[...,0] = np.arange(w)[:,None] + pts = pts.reshape ( (h*w*line_count, -1) ) + + a = all_lines[:,0,:] + b = all_lines[:,1,:] + pa = pts-a + ba = b-a + ph = np.clip ( np.einsum('ij,ij->i', pa, ba) / np.einsum('ij,ij->i', ba, ba), 0, 1 ) + dists = npla.norm ( pa - ba*ph[...,None], axis=1).reshape ( (h,w,line_count) ) + + def get_dists(name, thickness=0): + s,e = d[name] + result = dists[...,s:e] + if thickness != 0: + result = np.abs(result)-thickness + return np.min (result, axis=-1) + + return get_dists + + t = time.time() + + gdf = process ( 256,256, + ( + ('x', ( ( [ [0,0],[150,50],[30,180],[255,255] ], True), ) ), + ) + ) + + mask = gdf('x',3) + mask = np.clip ( 1- ( np.sqrt( np.maximum(mask,0) ) / 5 ), 0, 1) + # mask = 1-np.clip( np.cbrt(mask) / 15, 0, 1) + + + def alpha_to_color (img_alpha, color): + if len(img_alpha.shape) == 2: + img_alpha = img_alpha[...,None] + h,w,c = img_alpha.shape + result = np.zeros( (h,w, len(color) ), dtype=np.float32 ) + result[:,:] = color + + return result * img_alpha + + mask = alpha_to_color(mask, (0,1,0) ) + + + print(f"time took {time.time() - t}") + + cv2.imshow ("1", (mask*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + ct_1_filepath = r'D:\DevelopPython\test\ct_00003.jpg' + ct_1_img = cv2.imread(ct_1_filepath).astype(np.float32) / 255.0 + ct_1_img_shape = ct_1_img.shape + ct_1_dflimg = DFLJPG.load ( ct_1_filepath) + + ct_2_filepath = r'D:\DevelopPython\test\ct_trg.jpg' + ct_2_img = cv2.imread(ct_2_filepath).astype(np.float32) / 255.0 + ct_2_img_shape = ct_2_img.shape + ct_2_dflimg = DFLJPG.load ( ct_2_filepath) + + result = cv2.bilateralFilter( ct_1_img , 0, 1000,1) + + + #result = color_transfer_mkl ( ct_2_img, ct_1_img ) + + + + #import code + #code.interact(local=dict(globals(), **locals())) + cv2.imshow ("1", (result*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + #t = time.time() + ct_1_mask = LandmarksProcessor.get_image_hull_mask (ct_1_img_shape , ct_1_dflimg.get_landmarks() ) + ct_2_mask = LandmarksProcessor.get_image_hull_mask (ct_2_img_shape , ct_2_dflimg.get_landmarks() ) + + #ct_1_cmask = ( LandmarksProcessor.get_cmask (ct_1_img_shape , ct_1_dflimg.get_landmarks() ) *255).astype(np.uint8) + #ct_2_cmask = ( LandmarksProcessor.get_cmask (ct_2_img_shape , ct_2_dflimg.get_landmarks() ) *255).astype(np.uint8) + #print (f"time took:{time.time()-t}") + + #LandmarksProcessor.draw_landmarks (ct_1_img, ct_1_dflimg.get_landmarks(), color=(0,255,0), transparent_mask=False, ie_polys=None) + #cv2.imshow ("asd", (ct_1_cmask*255).astype(np.uint8) ) + #cv2.waitKey(0) + #cv2.imshow ("asd", (ct_2_cmask*255).astype(np.uint8) ) + #cv2.waitKey(0) + + + + + import ebsynth + while True: + + mask = (np.ones_like(ct_2_img)*255).astype(np.uint8) + + mask[:,0:16,:] = 0 + mask[:,-16:0,:] = 0 + mask[0:16,:,:] = 0 + mask[-16:0,:,:] = 0 + t = time.time() + img = imagelib.seamless_clone( ct_2_img.copy(), ct_1_img.copy(), ct_1_mask[...,0] ) + + print (f"time took: {time.time()-t}") + screen = np.concatenate ( (ct_1_img, ct_2_img, img), axis=1) + cv2.imshow ("1", (screen*255).astype(np.uint8) ) + cv2.waitKey(0) + + #import code + #code.interact(local=dict(globals(), **locals())) + import ebsynth + while True: + t = time.time() + + img = ebsynth.color_transfer(ct_2_img, ct_1_img) + print (f"time took: {time.time()-t}") + screen = np.concatenate ( (ct_1_img, ct_2_img, img), axis=1) + cv2.imshow ("1", screen ) + cv2.waitKey(0) + + import ebsynth + while True: + t = time.time() + + img = ebsynth.color_transfer(ct_1_img, ct_2_img) + img2 = ebsynth.color_transfer(ct_1_img, ct_2_img, ct_1_cmask, ct_2_cmask) + print (f"time took: {time.time()-t}") + screen = np.concatenate ( (ct_1_img, ct_1_cmask, ct_2_img, ct_2_cmask, img, img2), axis=1) + cv2.imshow ("asd", screen ) + cv2.waitKey(0) + + + import code + code.interact(local=dict(globals(), **locals())) + + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(cpu_only=True) ), locals(), globals() ) + + class PixelShufflerTorch(KL.Layer): + def __init__(self, size=(2, 2), data_format='channels_last', **kwargs): + super(PixelShufflerTorch, self).__init__(**kwargs) + self.data_format = data_format + self.size = size + + def call(self, inputs): + input_shape = K.shape(inputs) + if K.int_shape(input_shape)[0] != 4: + raise ValueError('Inputs should have rank 4; Received input shape:', str(K.int_shape(inputs))) + + batch_size, h, w, c = input_shape[0], input_shape[1], input_shape[2], K.int_shape(inputs)[-1] + rh, rw = self.size + oh, ow = h * rh, w * rw + oc = c // (rh * rw) + + out = inputs + out = K.permute_dimensions(out, (0, 3, 1, 2)) #NCHW + + out = K.reshape(out, (batch_size, oc, rh, rw, h, w)) + out = K.permute_dimensions(out, (0, 1, 4, 2, 5, 3)) + out = K.reshape(out, (batch_size, oc, oh, ow)) + + out = K.permute_dimensions(out, (0, 2, 3, 1)) + return out + + def compute_output_shape(self, input_shape): + if len(input_shape) != 4: + raise ValueError('Inputs should have rank ' + str(4) + '; Received input shape:', str(input_shape)) + + height = input_shape[1] * self.size[0] if input_shape[1] is not None else None + width = input_shape[2] * self.size[1] if input_shape[2] is not None else None + channels = input_shape[3] // self.size[0] // self.size[1] + + if channels * self.size[0] * self.size[1] != input_shape[3]: + raise ValueError('channels of input and size are incompatible') + + return (input_shape[0], + height, + width, + channels) + + def get_config(self): + config = {'size': self.size, + 'data_format': self.data_format} + base_config = super(PixelShufflerTorch, self).get_config() + + return dict(list(base_config.items()) + list(config.items())) + + import torch + import torch.nn as nn + import torch.nn.functional as F + model_weights = torch.load(r"D:\DevelopPython\test\RankSRGAN_NIQE.pth") + + def res_block(inp, name_prefix): + x = inp + x = Conv2D (ndf, kernel_size=3, strides=1, padding='same', activation="relu", name=name_prefix+"0")(x) + x = Conv2D (ndf, kernel_size=3, strides=1, padding='same', name=name_prefix+"2")(x) + return Add()([inp,x]) + + RankSRGAN_Input = Input ( (None, None,3) ) + ndf = 64 + x = RankSRGAN_Input + + x = x0 = Conv2D (ndf, kernel_size=3, strides=1, padding='same', name="model0")(x) + for i in range(16): + x = res_block(x, "model1%.2d" %i ) + x = Conv2D (ndf, kernel_size=3, strides=1, padding='same', name="model1160")(x) + x = Add()([x0,x]) + + x = ReLU() ( PixelShufflerTorch() ( Conv2D (ndf*4, kernel_size=3, strides=1, padding='same', name="model2")(x) ) ) + x = ReLU() ( PixelShufflerTorch() ( Conv2D (ndf*4, kernel_size=3, strides=1, padding='same', name="model5")(x) ) ) + + x = Conv2D (ndf, kernel_size=3, strides=1, padding='same', activation="relu", name="model8")(x) + x = Conv2D (3, kernel_size=3, strides=1, padding='same', name="model10")(x) + RankSRGAN_model = Model(RankSRGAN_Input, x ) + + RankSRGAN_model.get_layer("model0").set_weights (tdict2kw_conv2d (model_weights['model.0.weight'], model_weights['model.0.bias'])) + + for i in range(16): + RankSRGAN_model.get_layer("model1%.2d0" %i).set_weights (tdict2kw_conv2d (model_weights['model.1.sub.%d.res.0.weight' % i], model_weights['model.1.sub.%d.res.0.bias' % i])) + RankSRGAN_model.get_layer("model1%.2d2" %i).set_weights (tdict2kw_conv2d (model_weights['model.1.sub.%d.res.2.weight' % i], model_weights['model.1.sub.%d.res.2.bias' % i])) + + RankSRGAN_model.get_layer("model1160").set_weights (tdict2kw_conv2d (model_weights['model.1.sub.16.weight'], model_weights['model.1.sub.16.bias'])) + RankSRGAN_model.get_layer("model2").set_weights (tdict2kw_conv2d (model_weights['model.2.weight'], model_weights['model.2.bias'])) + RankSRGAN_model.get_layer("model5").set_weights (tdict2kw_conv2d (model_weights['model.5.weight'], model_weights['model.5.bias'])) + RankSRGAN_model.get_layer("model8").set_weights (tdict2kw_conv2d (model_weights['model.8.weight'], model_weights['model.8.bias'])) + RankSRGAN_model.get_layer("model10").set_weights (tdict2kw_conv2d (model_weights['model.10.weight'], model_weights['model.10.bias'])) + + RankSRGAN_model.save (r"D:\DevelopPython\test\RankSRGAN.h5") + + RankSRGAN_model_func = K.function (RankSRGAN_model.inputs, RankSRGAN_model.outputs) + + image = cv2.imread(r'D:\DevelopPython\test\00002.jpg').astype(np.float32) / 255.0 + + q, = RankSRGAN_model_func([ image[None,...] ]) + + cv2.imshow ("", np.clip ( q[0]*255, 0, 255).astype(np.uint8) ) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + image_shape = image.shape + + def apply_motion_blur(image, size, angle): + k = np.zeros((size, size), dtype=np.float32) + k[ (size-1)// 2 , :] = np.ones(size, dtype=np.float32) + k = cv2.warpAffine(k, cv2.getRotationMatrix2D( (size / 2 -0.5 , size / 2 -0.5 ) , angle, 1.0), (size, size) ) + k = k * ( 1.0 / np.sum(k) ) + return cv2.filter2D(image, -1, k) + + for i in range(0, 9999): + img = np.clip ( apply_motion_blur(image, 15, i % 360), 0, 1 ) + + cv2.imshow ("", ( img*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(cpu_only=True) ), locals(), globals() ) + PMLTile = nn.PMLTile + PMLK = nn.PMLK + + def rgb_to_lab(inp): + rgb_pixels = (inp / 12.92 * K.cast(inp <= 0.04045, dtype=K.floatx() ) ) \ + + K.pow( (inp + 0.055) / 1.055, 2.4) * K.cast(inp > 0.04045 , dtype=K.floatx() ) + + xyz_pixels = K.dot(rgb_pixels, K.constant( np.array([ + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B + ]))) / K.constant([0.950456, 1.0, 1.088754]) + + fxfyfz_pixels = (xyz_pixels * 7.787 + 16/116 ) * K.cast(xyz_pixels <= 0.008856, dtype=K.floatx() ) \ + + K.pow(xyz_pixels, 1/3) * K.cast(xyz_pixels > 0.008856 , dtype=K.floatx() ) + + x = K.dot(fxfyfz_pixels, K.constant( np.array([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]))) + K.constant([-16.0, 0.0, 0.0]) + + x = K.round ( x ) + return x + + def lab_to_rgb(inp): + fxfyfz_pixels = K.dot(inp + K.constant([16.0, 0.0, 0.0]), K.constant(np.array([ + # fx fy fz + [1/116.0, 1/116.0, 1/116.0], # l + [1/500.0, 0.0, 0.0], # a + [ 0.0, 0.0, -1/200.0], # b + ]))) + + xyz_pixels = ( ( (fxfyfz_pixels - 16/116 ) / 7.787 ) * K.cast(fxfyfz_pixels <= 6/29, dtype=K.floatx() ) \ + + K.pow(fxfyfz_pixels, 3) * K.cast(fxfyfz_pixels > 6/29, dtype=K.floatx() ) \ + ) * K.constant([0.950456, 1.0, 1.088754]) + + rgb_pixels = K.dot(xyz_pixels, K.constant(np.array([ + # r g b + [ 3.2404542, -0.9692660, 0.0556434], # x + [-1.5371385, 1.8760108, -0.2040259], # y + [-0.4985314, 0.0415560, 1.0572252], # z + ]))) + rgb_pixels = K.clip(rgb_pixels, 0.0, 1.0) + + return (rgb_pixels * 12.92 * K.cast(rgb_pixels <= 0.0031308, dtype=K.floatx() ) ) \ + + ( (K.pow(rgb_pixels, 1/2.4) * 1.055) - 0.055) * K.cast(rgb_pixels > 0.0031308, dtype=K.floatx() ) + + def rct_flow(img_src_t, img_trg_t): + if len(K.int_shape(img_src_t)) != len(K.int_shape(img_trg_t)): + raise ValueError( len(img_src_t.shape) != len(img_trg_t.shape) ) + + + initial_shape = K.shape(img_src_t) + h,w,c = K.int_shape(img_src_t)[-3::] + + img_src_t = K.reshape ( img_src_t, (-1,h,w,c) ) + img_trg_t = K.reshape ( img_trg_t, (-1,h,w,c) ) + + + + img_src_lab_t = rgb_to_lab(img_src_t) + img_src_lab_L_t = img_src_lab_t[...,0:1] + img_src_lab_a_t = img_src_lab_t[...,1:2] + img_src_lab_b_t = img_src_lab_t[...,2:3] + + img_src_lab_L_mean_t = K.mean(img_src_lab_L_t, axis=(-1,-2,-3), keepdims=True ) + img_src_lab_L_std_t = K.std(img_src_lab_L_t, axis=(-1,-2,-3), keepdims=True ) + img_src_lab_a_mean_t = K.mean(img_src_lab_a_t, axis=(-1,-2,-3), keepdims=True ) + img_src_lab_a_std_t = K.std(img_src_lab_a_t, axis=(-1,-2,-3), keepdims=True ) + img_src_lab_b_mean_t = K.mean(img_src_lab_b_t, axis=(-1,-2,-3), keepdims=True ) + img_src_lab_b_std_t = K.std(img_src_lab_b_t, axis=(-1,-2,-3), keepdims=True ) + + img_trg_lab_t = rgb_to_lab(img_trg_t) + img_trg_lab_L_t = img_trg_lab_t[...,0:1] + img_trg_lab_a_t = img_trg_lab_t[...,1:2] + img_trg_lab_b_t = img_trg_lab_t[...,2:3] + img_trg_lab_L_mean_t = K.mean(img_trg_lab_L_t, axis=(-1,-2,-3), keepdims=True ) + img_trg_lab_L_std_t = K.std(img_trg_lab_L_t, axis=(-1,-2,-3), keepdims=True ) + img_trg_lab_a_mean_t = K.mean(img_trg_lab_a_t, axis=(-1,-2,-3), keepdims=True ) + img_trg_lab_a_std_t = K.std(img_trg_lab_a_t, axis=(-1,-2,-3), keepdims=True ) + img_trg_lab_b_mean_t = K.mean(img_trg_lab_b_t, axis=(-1,-2,-3), keepdims=True ) + img_trg_lab_b_std_t = K.std(img_trg_lab_b_t, axis=(-1,-2,-3), keepdims=True ) + + img_new_lab_L_t = (img_src_lab_L_std_t / img_trg_lab_L_std_t)*(img_trg_lab_L_t-img_trg_lab_L_mean_t) + img_src_lab_L_mean_t + img_new_lab_a_t = (img_src_lab_a_std_t / img_trg_lab_a_std_t)*(img_trg_lab_a_t-img_trg_lab_a_mean_t) + img_src_lab_a_mean_t + img_new_lab_b_t = (img_src_lab_b_std_t / img_trg_lab_b_std_t)*(img_trg_lab_b_t-img_trg_lab_b_mean_t) + img_src_lab_b_mean_t + + img_new_t = lab_to_rgb( K.concatenate ( [img_new_lab_L_t, img_new_lab_a_t, img_new_lab_b_t], -1) ) + + img_new_t = K.reshape ( img_new_t, initial_shape ) + + return img_new_t + + + # class ImagePatches(PMLTile.Operation): + # def __init__(self, images, ksizes, strides, rates=(1,1,1,1), padding="VALID"): + # """ + # Compatible to tensorflow.extract_image_patches. + # Extract patches from images and put them in the "depth" output dimension. + # Args: + # images: A tensor with a shape of [batch, rows, cols, depth] + # ksizes: The size of the oatches with a shape of [1, patch_rows, patch_cols, 1] + # strides: How far the center of two patches are in the image with a shape of [1, stride_rows, stride_cols, 1] + # rates: How far two consecutive pixel are in the input. Equivalent to dilation. Expect shape of [1, rate_rows, rate_cols, 1] + # padding: A string of "VALID" or "SAME" defining padding. + + # Does not work with symbolic height and width. + # """ + # i_shape = images.shape.dims + # patch_row_eff = ksizes[1] + ((ksizes[1] - 1) * (rates[1] -1)) + # patch_col_eff = ksizes[2] + ((ksizes[2] - 1) * (rates[2] -1)) + + # if padding.upper() == "VALID": + # out_rows = math.ceil((i_shape[1] - patch_row_eff + 1.) / float(strides[1])) + # out_cols = math.ceil((i_shape[2] - patch_col_eff + 1.) / float(strides[2])) + # pad_str = "PAD = I;" + # else: + # out_rows = math.ceil( i_shape[1] / float(strides[1]) ) + # out_cols = math.ceil( i_shape[2] / float(strides[2]) ) + # dim_calc = "NY={NY}; NX={NX};".format(NY=out_rows, NX=out_cols) + # pad_top = max(0, ( (out_rows - 1) * strides[1] + patch_row_eff - i_shape[1] ) // 2) + # pad_left = max(0, ( (out_cols - 1) * strides[2] + patch_col_eff - i_shape[2] ) // 2) + # # we simply assume padding right == padding left + 1 (same for top/down). + # # This might lead to us padding more as we would need but that won't matter. + # # TF splits padding between both sides so left_pad +1 should keep us on the safe side. + # pad_str = """PAD[b, y, x, d : B, Y + {PT} * 2 + 1, X + {PL} * 2 + 1, D] = + # =(I[b, y - {PT}, x - {PL}, d]);""".format(PT=pad_top, PL=pad_left) + + # o_shape = (i_shape[0], out_rows, out_cols, ksizes[1]*ksizes[2]*i_shape[-1]) + # code = """function (I[B,Y,X,D]) -> (O) {{ + # {PAD} + # TMP[b, ny, nx, y, x, d: B, {NY}, {NX}, {KY}, {KX}, D] = + # =(PAD[b, ny * {SY} + y * {RY}, nx * {SX} + x * {RX}, d]); + # O = reshape(TMP, B, {NY}, {NX}, {KY} * {KX} * D); + # }} + # """.format( + # PAD=pad_str, + # NY=out_rows, NX=out_cols, + # KY=ksizes[1], KX=ksizes[2], + # SY=strides[1], SX=strides[2], + # RY=rates[1], RX=rates[2] + # ) + # super(ImagePatches, self).__init__(code, + # [('I', images),], + # [('O', PMLTile.Shape(images.shape.dtype, o_shape))]) + + img_src = cv2.imread(r'D:\DevelopPython\test\ct_src.jpg').astype(np.float32) / 255.0 + img_src = np.expand_dims (img_src, 0) + img_src_shape = img_src.shape + + img_trg = cv2.imread(r'D:\DevelopPython\test\ct_trg.jpg').astype(np.float32) / 255.0 + img_trg = np.expand_dims (img_trg, 0) + img_trg_shape = img_trg.shape + + img_src_t = Input ( img_src_shape[1:] ) + img_trg_t = Input ( img_src_shape[1:] ) + + img_rct_t = rct_flow (img_src_t, img_trg_t) + + img_rct = K.function ( [img_src_t, img_trg_t], [ img_rct_t ]) ( [img_src[...,::-1], img_trg[...,::-1]]) [0][0][...,::-1] + + + img_rct_true = imagelib.reinhard_color_transfer ( np.clip( (img_trg[0]*255).astype(np.uint8), 0, 255), + np.clip( (img_src[0]*255).astype(np.uint8), 0, 255) ) + + img_rct_true = img_rct_true / 255.0 + + print("diff ", np.sum(np.abs(img_rct-img_rct_true)) ) + + cv2.imshow ("", ( img_rct*255).astype(np.uint8) ) + cv2.waitKey(0) + cv2.imshow ("", ( img_rct_true*255).astype(np.uint8) ) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + + wnd_size = 15#img_src.shape[1] // 8 - 1 + pad_size = wnd_size // 2 + sh = img_src.shape[1] + pad_size*2 + + step_size = 1 + k = (sh-wnd_size) // step_size + 1 + + img_src_padded_t = K.spatial_2d_padding (img_src_t, ((pad_size,pad_size), (pad_size,pad_size)) ) + img_trg_padded_t = K.spatial_2d_padding (img_trg_t, ((pad_size,pad_size), (pad_size,pad_size)) ) + #ImagePatches.function + img_src_patches_t = nn.tf.extract_image_patches ( img_src_padded_t, [1,k,k,1], [1,1,1,1], [1,step_size,step_size,1], "VALID") + img_trg_patches_t = nn.tf.extract_image_patches ( img_trg_padded_t, [1,k,k,1], [1,1,1,1], [1,step_size,step_size,1], "VALID") + + + img_src_patches_t = \ + K.concatenate ([ K.expand_dims( K.permute_dimensions ( img_src_patches_t[...,2::3], (0,3,1,2) ), -1), + K.expand_dims( K.permute_dimensions ( img_src_patches_t[...,1::3], (0,3,1,2) ), -1), + K.expand_dims( K.permute_dimensions ( img_src_patches_t[...,0::3], (0,3,1,2) ), -1) ], -1 ) + + img_trg_patches_t = \ + K.concatenate ([ K.expand_dims( K.permute_dimensions ( img_trg_patches_t[...,2::3], (0,3,1,2) ), -1), + K.expand_dims( K.permute_dimensions ( img_trg_patches_t[...,1::3], (0,3,1,2) ), -1), + K.expand_dims( K.permute_dimensions ( img_trg_patches_t[...,0::3], (0,3,1,2) ), -1) ], -1 ) + + #img_src_patches_lab_t = bgr_to_lab(img_src_patches_t) + #img_src_patches_lab = K.function ( [img_src_t], [ img_src_patches_lab_t ]) ([img_src]) [0][0] + + img_rct_patches_t = rct_flow (img_src_patches_t, img_trg_patches_t) + + img_rct_t = K.reshape ( img_rct_patches_t[...,pad_size,pad_size,:], (-1,256,256,3) ) + + img_rct = K.function ( [img_src_t, img_trg_t], [ img_rct_t ]) ( [img_src, img_trg]) [0][0][...,::-1] + + #import code + #code.interact(local=dict(globals(), **locals())) + + #for i in range( img_rct.shape[0] ): + cv2.imshow ("", ( img_rct*255).astype(np.uint8) ) + cv2.waitKey(0) + + + + #cv2.imshow ("", (img_new*255).astype(np.uint8) ) + #cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + bgr_to_lab_f = K.function ( [image_tensor], [ image_lab_t ]) + lab_to_bgr_f = K.function ( [image_tensor], [ lab_to_bgr(image_tensor) ]) + + img_src_lab = bgr_to_lab_f( [img_src[...,::-1]] ) + img_src_lab_bgr, = lab_to_bgr_f( [img_src_lab] ) + + diff = np.sum ( np.abs(img_src-img_src_lab_bgr[...,::-1] ) ) + print ("bgr->lab->bgr diff ", diff) + + #image_cv_lab = cv2.cvtColor( img_src[0], cv2.COLOR_BGR2LAB) + #print ("lab and cv lab diff ", np.sum(np.abs(image_lab[0].astype(np.int8)-image_cv_lab.astype(np.int8))) ) + #print ("lab and cv lab diff ", np.sum(np.abs(image_lab[0]-image_cv_lab)) ) + + #cv2.imshow ("", image_bgr.astype(np.uint8) ) + #cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + ########### + """ + img_src_lab_t = bgr_to_lab(img_src_t) + img_src_lab_L_t = img_src_lab_t[...,0:1] + img_src_lab_a_t = img_src_lab_t[...,1:2] + img_src_lab_b_t = img_src_lab_t[...,2:3] + img_src_lab_L_mean_t = K.mean(img_src_lab_L_t) + img_src_lab_L_std_t = K.std(img_src_lab_L_t) + img_src_lab_a_mean_t = K.mean(img_src_lab_a_t) + img_src_lab_a_std_t = K.std(img_src_lab_a_t) + img_src_lab_b_mean_t = K.mean(img_src_lab_b_t) + img_src_lab_b_std_t = K.std(img_src_lab_b_t) + + img_trg_lab_t = bgr_to_lab(img_trg_t) + img_trg_lab_L_t = img_trg_lab_t[...,0:1] + img_trg_lab_a_t = img_trg_lab_t[...,1:2] + img_trg_lab_b_t = img_trg_lab_t[...,2:3] + img_trg_lab_L_mean_t = K.mean(img_trg_lab_L_t) + img_trg_lab_L_std_t = K.std(img_trg_lab_L_t) + img_trg_lab_a_mean_t = K.mean(img_trg_lab_a_t) + img_trg_lab_a_std_t = K.std(img_trg_lab_a_t) + img_trg_lab_b_mean_t = K.mean(img_trg_lab_b_t) + img_trg_lab_b_std_t = K.std(img_trg_lab_b_t) + img_new_lab_L_t = (img_src_lab_L_std_t / img_trg_lab_L_std_t)*(img_trg_lab_L_t-img_trg_lab_L_mean_t) + img_src_lab_L_mean_t + img_new_lab_a_t = (img_src_lab_a_std_t / img_trg_lab_a_std_t)*(img_trg_lab_a_t-img_trg_lab_a_mean_t) + img_src_lab_a_mean_t + img_new_lab_b_t = (img_src_lab_b_std_t / img_trg_lab_b_std_t)*(img_trg_lab_b_t-img_trg_lab_b_mean_t) + img_src_lab_b_mean_t + img_new_t = lab_to_bgr( K.concatenate ( [img_new_lab_L_t, img_new_lab_a_t, img_new_lab_b_t], -1) ) + rct_f = K.function ( [img_src_t, img_trg_t], [ img_new_t ]) + img_new, = rct_f([ img_src[...,::-1], img_trg[...,::-1] ])[0][...,::-1] + """ + + """ + def bgr_to_lab(inp): + rgb_pixels = (inp / 12.92 * K.cast(inp <= 0.04045, dtype=K.floatx() ) ) \ + + K.pow( (inp + 0.055) / 1.055, 2.4) * K.cast(inp > 0.04045 , dtype=K.floatx() ) + + xyz_pixels = K.dot(rgb_pixels, K.constant( np.array([ + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B + ]))) / K.constant([0.950456, 1.0, 1.088754]) + + fxfyfz_pixels = (xyz_pixels * 7.787 + 16/116 ) * K.cast(xyz_pixels <= 0.008856, dtype=K.floatx() ) \ + + K.pow(xyz_pixels, 1/3) * K.cast(xyz_pixels > 0.008856 , dtype=K.floatx() ) + + return K.dot(fxfyfz_pixels, K.constant( np.array([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]))) + K.constant([-16.0, 0.0, 0.0]) + + def lab_to_bgr(inp): + fxfyfz_pixels = K.dot(inp + K.constant([16.0, 0.0, 0.0]), K.constant(np.array([ + # fx fy fz + [1/116.0, 1/116.0, 1/116.0], # l + [1/500.0, 0.0, 0.0], # a + [ 0.0, 0.0, -1/200.0], # b + ]))) + + xyz_pixels = ( ( (fxfyfz_pixels - 16/116 ) / 7.787 ) * K.cast(fxfyfz_pixels <= 6/29, dtype=K.floatx() ) \ + + K.pow(fxfyfz_pixels, 3) * K.cast(fxfyfz_pixels > 6/29, dtype=K.floatx() ) \ + ) * K.constant([0.950456, 1.0, 1.088754]) + + rgb_pixels = K.dot(xyz_pixels, K.constant(np.array([ + # r g b + [ 3.2404542, -0.9692660, 0.0556434], # x + [-1.5371385, 1.8760108, -0.2040259], # y + [-0.4985314, 0.0415560, 1.0572252], # z + ]))) + rgb_pixels = K.clip(rgb_pixels, 0.0, 1.0) + + return (rgb_pixels * 12.92 * K.cast(rgb_pixels <= 0.0031308, dtype=K.floatx() ) ) \ + + ( (K.pow(rgb_pixels, 1/2.4) * 1.055) - 0.055) * K.cast(rgb_pixels > 0.0031308, dtype=K.floatx() ) + """ + + ###### + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(force_gpu_idx=0) ), locals(), globals() ) + + shape = (64, 64, 3) + def encflow(x): + x = keras.layers.Conv2D(128, 5, strides=2, padding="same")(x) + x = keras.layers.Conv2D(256, 5, strides=2, padding="same")(x) + x = keras.layers.Dense(3)(keras.layers.Flatten()(x)) + return x + + def modelify(model_functor): + def func(tensor): + return keras.models.Model (tensor, model_functor(tensor)) + return func + + encoder = modelify (encflow)( keras.Input(shape) ) + + inp = x = keras.Input(shape) + code_t = encoder(x) + loss = K.mean(code_t) + + train_func = K.function ([inp],[loss], keras.optimizers.Adam().get_updates(loss, encoder.trainable_weights) ) + train_func ([ np.zeros ( (1, 64, 64, 3) ) ]) + + import code + code.interact(local=dict(globals(), **locals())) + + ########### + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32)# / 255.0 + image = (image - image.mean( (0,1)) ) / image.std( (0,1) ) + cv2.imshow ("", ((image +127)).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + ########### + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config() ), locals(), globals() ) + + def gaussian_blur(radius=2.0): + def gaussian(x, mu, sigma): + return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2)) + + def make_kernel(sigma): + kernel_size = max(3, int(2 * 2 * sigma + 1)) + mean = np.floor(0.5 * kernel_size) + kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)]) + np_kernel = np.outer(kernel_1d, kernel_1d).astype(dtype=K.floatx()) + kernel = np_kernel / np.sum(np_kernel) + return kernel + + gauss_kernel = make_kernel(radius) + gauss_kernel = gauss_kernel[:, :,np.newaxis, np.newaxis] + + def func(input): + inputs = [ input[:,:,:,i:i+1] for i in range( K.int_shape( input )[-1] ) ] + + outputs = [] + for i in range(len(inputs)): + outputs += [ K.conv2d( inputs[i] , K.constant(gauss_kernel) , strides=(1,1), padding="same") ] + + return K.concatenate (outputs, axis=-1) + return func + + def style_loss_test(gaussian_blur_radius=0.0, loss_weight=1.0, wnd_size=0, step_size=1): + if gaussian_blur_radius > 0.0: + gblur = gaussian_blur(gaussian_blur_radius) + + def bgr_to_lab(inp): + linear_mask = K.cast(inp <= 0.04045, dtype=K.floatx() ) + exponential_mask = K.cast(inp > 0.04045, dtype=K.floatx() ) + rgb_pixels = (inp / 12.92 * linear_mask) + (((inp + 0.055) / 1.055) ** 2.4) * exponential_mask + rgb_to_xyz = K.constant([ + # X Y Z + [0.180423, 0.072169, 0.950227], # B + [0.357580, 0.715160, 0.119193], # G + [0.412453, 0.212671, 0.019334], # R + ]) + + xyz_pixels = K.dot(rgb_pixels, rgb_to_xyz) + xyz_normalized_pixels = xyz_pixels * [1/0.950456, 1.0, 1/1.088754] + + epsilon = 6/29 + linear_mask = K.cast(xyz_normalized_pixels <= (epsilon**3), dtype=K.floatx() ) + exponential_mask = K.cast(xyz_normalized_pixels > (epsilon**3), dtype=K.floatx() ) + fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask + + # convert to lab + fxfyfz_to_lab = K.constant([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]) + lab = K.dot(fxfyfz_pixels, fxfyfz_to_lab) + K.constant([-16.0, 0.0, 0.0]) + return lab[...,0:1], lab[...,1:2], lab[...,2:3] + + def sd(content, style, loss_weight): + content_nc = K.int_shape(content)[-1] + style_nc = K.int_shape(style)[-1] + if content_nc != style_nc: + raise Exception("style_loss() content_nc != style_nc") + + cl,ca,cb = bgr_to_lab(content) + sl,sa,sb = bgr_to_lab(style) + axes = [1,2] + cl_mean, cl_std = K.mean(cl, axis=axes, keepdims=True), K.var(cl, axis=axes, keepdims=True)+ 1e-5 + ca_mean, ca_std = K.mean(ca, axis=axes, keepdims=True), K.var(ca, axis=axes, keepdims=True)+ 1e-5 + cb_mean, cb_std = K.mean(cb, axis=axes, keepdims=True), K.var(cb, axis=axes, keepdims=True)+ 1e-5 + + sl_mean, sl_std = K.mean(sl, axis=axes, keepdims=True), K.var(sl, axis=axes, keepdims=True)+ 1e-5 + sa_mean, sa_std = K.mean(sa, axis=axes, keepdims=True), K.var(sa, axis=axes, keepdims=True)+ 1e-5 + sb_mean, sb_std = K.mean(sb, axis=axes, keepdims=True), K.var(sb, axis=axes, keepdims=True)+ 1e-5 + + + loss = K.mean( K.square( cl - ( (sl - sl_mean) * ( cl_std / sl_std ) + cl_mean ) ) ) + \ + K.mean( K.square( ca - ( (sa - sa_mean) * ( ca_std / sa_std ) + ca_mean ) ) ) + \ + K.mean( K.square( cb - ( (sb - sb_mean) * ( cb_std / sb_std ) + cb_mean ) ) ) + + + #import code + #code.interact(local=dict(globals(), **locals())) + + + return loss * ( loss_weight / float(content_nc) ) + + def func(target, style): + if wnd_size == 0: + if gaussian_blur_radius > 0.0: + return sd( gblur(target), gblur(style), loss_weight=loss_weight) + else: + return sd( target, style, loss_weight=loss_weight ) + return func + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + image2 = cv2.imread(r'D:\DevelopPython\test\00000.jpg').astype(np.float32) / 255.0 + + inp_t = Input ( (256,256,3) ) + inp2_t = Input ( (256,256,3) ) + + loss_t = style_loss_test(gaussian_blur_radius=16.0, loss_weight=0.01 )(inp_t, inp2_t) + + loss, = K.function ([inp_t,inp2_t], [loss_t]) ( [ image[np.newaxis,...], image2[np.newaxis,...] ] ) + + + + import code + code.interact(local=dict(globals(), **locals())) + + + ########### + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + + from core.imagelib import LinearMotionBlur + + image = LinearMotionBlur(image, 5, 135) + + cv2.imshow("", (image*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + ########### + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(force_gpu_idx=0) ), locals(), globals() ) + + from core.imagelib import DCSCN + + dc = DCSCN() + + image = cv2.imread(r'D:\DevelopPython\test\sr1.png').astype(np.float32) / 255.0 + + image_up = dc.upscale(image) + cv2.imwrite (r'D:\DevelopPython\test\sr1_result.png', (image_up*255).astype(np.uint8) ) + + + import code + code.interact(local=dict(globals(), **locals())) + + ########### + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config(force_gpu_idx=0) ), locals(), globals() ) + PMLTile = nn.PMLTile + PMLK = nn.PMLK + + shape = (64, 64, 3) + def encflow(x): + x = LeakyReLU()(keras.layers.Conv2D(128, 5, strides=2, padding="same")(x)) + x = keras.layers.Conv2D(256, 5, strides=2, padding="same")(x) + x = keras.layers.Conv2D(512, 5, strides=2, padding="same")(x) + x = keras.layers.Conv2D(1024,5, strides=2, padding="same")(x) + x = keras.layers.Dense(64)(keras.layers.Flatten()(x)) + x = keras.layers.Dense(4 * 4 * 1024)(x) + x = keras.layers.Reshape((4, 4, 1024))(x) + x = keras.layers.Conv2DTranspose(512, 3, strides=2, padding="same")(x) + return x + + def decflow(x): + x = x[0] + x = LeakyReLU()(keras.layers.Conv2DTranspose(512, 3, strides=2, padding="same")(x)) + x = keras.layers.Conv2DTranspose(256, 3, strides=2, padding="same")(x) + x = keras.layers.Conv2DTranspose(128, 3, strides=2, padding="same")(x) + x = keras.layers.Conv2D(3, 5, strides=1, padding="same")(x) + return x + + def modelify(model_functor): + def func(tensor): + return keras.models.Model (tensor, model_functor(tensor)) + return func + + encoder = modelify (encflow)( keras.Input(shape) ) + decoder1 = modelify (decflow)( [ Input(K.int_shape(x)[1:]) for x in encoder.outputs ] ) + decoder2 = modelify (decflow)( [ Input(K.int_shape(x)[1:]) for x in encoder.outputs ] ) + + inp = x = keras.Input(shape) + code = encoder(x) + x1 = decoder1(code) + x2 = decoder2(code) + + loss = K.mean(K.square(inp-x1))+K.mean(K.square(inp-x2)) + train_func = K.function ([inp],[loss], keras.optimizers.Adam().get_updates(loss, encoder.trainable_weights+decoder1.trainable_weights+decoder2.trainable_weights) ) + view_func1 = K.function ([inp],[x1]) + view_func2 = K.function ([inp],[x2]) + + for i in range(100): + print("Loop %i" % i) + data = np.zeros ( (1, 64, 64, 3) ) + train_func ( [data]) + view_func1 ([data]) + view_func2 ([data]) + print("Saving weights") + encoder.save_weights(r"D:\DevelopPython\test\testweights.h5") + decoder1.save_weights(r"D:\DevelopPython\test\testweights1.h5") + decoder2.save_weights(r"D:\DevelopPython\test\testweights2.h5") + + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + exec( nn.import_all( device_config=nn.device.Config() ), locals(), globals() ) + PMLTile = nn.PMLTile + PMLK = nn.PMLK + + import tensorflow as tf + tfkeras = tf.keras + tfK = tfkeras.backend + lin = np.broadcast_to( np.linspace(1,10,10), (10,10) ) + + #a = np.broadcast_to ( np.concatenate( [np.linspace(1,4,4), np.linspace(5,1,5)] ), (9,9) ) + #a = (a + a.T)-1 + #a = a[np.newaxis,:,:,np.newaxis] + + class ReflectionPadding2D(): + class TileOP(PMLTile.Operation): + def __init__(self, input, hpad, w_pad): + if K.image_data_format() == 'channels_last': + if input.shape.ndims == 4: + H, W = input.shape.dims[1:3] + if (type(H) == int and hpad >= H) or \ + (type(W) == int and w_pad >= W): + raise ValueError("Paddings must be less than dimensions.") + + c = """ function (I[B, H, W, C] ) -> (O) {{ + WE = W + {w_pad}*2; + HE = H + {hpad}*2; + """.format(hpad=hpad, w_pad=w_pad) + if w_pad > 0: + c += """ + LEFT_PAD [b, h, w , c : B, H, WE, C ] = =(I[b, h, {w_pad}-w, c]), w < {w_pad} ; + HCENTER [b, h, w , c : B, H, WE, C ] = =(I[b, h, w-{w_pad}, c]), w < W+{w_pad}-1 ; + RIGHT_PAD[b, h, w , c : B, H, WE, C ] = =(I[b, h, 2*W - (w-{w_pad}) -2, c]); + LCR = LEFT_PAD+HCENTER+RIGHT_PAD; + """.format(hpad=hpad, w_pad=w_pad) + else: + c += "LCR = I;" + + if hpad > 0: + c += """ + TOP_PAD [b, h, w , c : B, HE, WE, C ] = =(LCR[b, {hpad}-h, w, c]), h < {hpad}; + VCENTER [b, h, w , c : B, HE, WE, C ] = =(LCR[b, h-{hpad}, w, c]), h < H+{hpad}-1 ; + BOTTOM_PAD[b, h, w , c : B, HE, WE, C ] = =(LCR[b, 2*H - (h-{hpad}) -2, w, c]); + TVB = TOP_PAD+VCENTER+BOTTOM_PAD; + """.format(hpad=hpad, w_pad=w_pad) + else: + c += "TVB = LCR;" + + c += "O = TVB; }" + + inp_dims = input.shape.dims + out_dims = (inp_dims[0], inp_dims[1]+hpad*2, inp_dims[2]+w_pad*2, inp_dims[3]) + else: + raise NotImplemented + else: + raise NotImplemented + + super(ReflectionPadding2D.TileOP, self).__init__(c, [('I', input) ], + [('O', PMLTile.Shape(input.shape.dtype, out_dims ) )]) + + def __init__(self, hpad, w_pad): + self.hpad, self.w_pad = hpad, w_pad + + def __call__(self, inp): + return ReflectionPadding2D.TileOP.function(inp, self.hpad, self.w_pad) + + sh_w = 9 + sh_h = 9 + sh = (1,sh_h,sh_w,1) + w_pad, hpad = 8,8 + + + t1 = tfK.placeholder (sh ) + t2 = tf.pad(t1, [ [0,0], [hpad,hpad], [w_pad,w_pad], [0,0] ], 'REFLECT') + + pt1 = K.placeholder (sh ) + pt2 = ReflectionPadding2D(hpad, w_pad )(pt1) + + tfunc = tfK.function ([t1],[t2]) + ptfunc = K.function([pt1],[pt2]) + + for i in range(100): + a = np.random.uniform( size=sh) + # a = np.broadcast_to ( np.concatenate( [np.linspace(1,4,4), np.linspace(5,1,5)] ), (9,9) ) + # a = np.broadcast_to (np.linspace(1,9,9), (9,9) ) + # a = (a + a.T)-1 + # a = a[np.newaxis,:,:,np.newaxis] + + t = tfunc([a]) [0][0,:,:,0] + pt = ptfunc ([a])[0][0,:,:,0] + if np.allclose(t, pt): + print ("all_close = True") + else: + print ("all_close = False\r\n") + print(t,"") + print(pt) + import code + code.interact(local=dict(globals(), **locals())) + + image = cv2.imread(r'D:\DevelopPython\test\00000.png').astype(np.float32) / 255.0 + image = np.expand_dims (image, 0) + image_shape = image.shape + + image2 = cv2.imread(r'D:\DevelopPython\test\00001.png').astype(np.float32) / 255.0 + image2 = np.expand_dims (image2, 0) + image2_shape = image2.shape + + + # class ReflectionPadding2D(): + # def __init__(self, hpad, w_pad): + # self.hpad, self.w_pad = hpad, w_pad + + # def __call__(self, inp): + # hpad, w_pad = self.hpad, self.w_pad + # if K.image_data_format() == 'channels_last': + # if inp.shape.ndims == 4: + # w = K.concatenate ([ inp[:,:,w_pad:0:-1,:], + # inp, + # inp[:,:,-2:-w_pad-2:-1,:] ], axis=2 ) + + # h = K.concatenate ([ w[:,hpad:0:-1,:,:], + # w, + # w[:,-2:-hpad-2:-1,:,:] ], axis=1 ) + + # return h + # else: + # raise NotImplemented + # else: + # raise NotImplemented + #f = ReflectionPadding2D.function(t, [1,65,65,1], [1,1,1,1], [1,1,1,1]) + + #x, = K.function ([t],[f]) ([ image ]) + + #image = np.random.uniform ( size=(1,256,256,3) ) + #image2 = np.random.uniform ( size=(1,256,256,3) ) + + #t1 = K.placeholder ( (None,) + image_shape[1:], name="t1" ) + #t2 = K.placeholder ( (None,None,None,None), name="t2" ) + + #l1_t = DSSIMObjective() (t1,t2 ) + #l1, = K.function([t1, t2],[l1_t]) ([image, image2]) + # + #print (l1) + #t[:,0:64,64::2,:].source.op.code + """ +t1[:,0:64,64::2,128:] + +function (I[N0, N1, N2, N3]) -> (O) + {\n + Start0 = max(0, 0); + Offset0 = Start0; + O[ + i0, i1, i2, i3: + ((N0 - (Start0)) + 1 - 1)/1, + (64 + 1 - 1)/1, + (192 + 2 - 1)/2, + (-125 + 1 - 1)/1] + = + =(I[1*i0+Offset0, 1*i1+0, 2*i2+64, 1*i3+128]); + + } + + """ + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + ''' + >>> t[:,0:64,64::2,:].source.op.code +function (I[N0, N1, N2, N3]) -> (O) { + +O[i0, i1, i2, i3: (1 + 1 - 1)/1, (64 + 1 - 1)/1, (64 + 2 - 1)/2, (1 + 1 - 1)/1] = + =(I[1*i0+0, 1*i1+0, 2*i2+64, 1*i3+0]); + + + Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, + int64 dilation_rate, int64 stride, + Padding padding_type, int64* output_size, + int64* padding_before, + int64* padding_after) { + if (stride <= 0) { + return errors::InvalidArgument("Stride must be > 0, but got ", stride); + } + if (dilation_rate < 1) { + return errors::InvalidArgument("Dilation rate must be >= 1, but got ", + dilation_rate); + } + + // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. + int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; + switch (padding_type) { + case Padding::VALID: + *output_size = (input_size - effective_filter_size + stride) / stride; + *padding_before = *padding_after = 0; + break; + case Padding::EXPLICIT: + *output_size = (input_size + *padding_before + *padding_after - + effective_filter_size + stride) / + stride; + break; + case Padding::SAME: + *output_size = (input_size + stride - 1) / stride; + const int64 padding_needed = + std::max(int64{0}, (*output_size - 1) * stride + + effective_filter_size - input_size); + // For odd values of total padding, add more padding at the 'right' + // side of the given dimension. + *padding_before = padding_needed / 2; + *padding_after = padding_needed - *padding_before; + break; + } + if (*output_size < 0) { + return errors::InvalidArgument( + "Computed output size would be negative: ", *output_size, + " [input_size: ", input_size, + ", effective_filter_size: ", effective_filter_size, + ", stride: ", stride, "]"); + } + return Status::OK(); + } + ''' + class ExtractImagePatchesOP(PMLTile.Operation): + def __init__(self, input, ksizes, strides, rates, padding='valid'): + + batch, in_rows, in_cols, depth = input.shape.dims + + ksize_rows = ksizes[1]; + ksize_cols = ksizes[2]; + + stride_rows = strides[1]; + stride_cols = strides[2]; + + rate_rows = rates[1]; + rate_cols = rates[2]; + + ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); + ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); + + #if padding == 'valid': + + out_rows = (in_rows - ksize_rows_eff + stride_rows) / stride_rows; + out_cols = (in_cols - ksize_cols_eff + stride_cols) / stride_cols; + + out_sizes = (batch, out_rows, out_cols, ksize_rows * ksize_cols * depth); + + + + B, H, W, CI = input.shape.dims + + RATE = PMLK.constant ([1,rate,rate,1], dtype=PMLK.floatx() ) + + #print (target_dims) + code = """function (I[B, {H}, {W}, {CI} ], RATES[RB, RH, RW, RC] ) -> (O) { + + O[b, {wnd_size}, {wnd_size}, ] = =(I[b, h, w, ci]); + + }""".format(H=H, W=W, CI=CI, RATES=rates, wnd_size=wnd_size) + + super(ExtractImagePatchesOP, self).__init__(code, [('I', input) ], + [('O', PMLTile.Shape(input.shape.dtype, out_sizes ) )]) + + + + + f = ExtractImagePatchesOP.function(t, [1,65,65,1], [1,1,1,1], [1,1,1,1]) + + x, = K.function ([t],[f]) ([ image ]) + print(x.shape) + + import code + code.interact(local=dict(globals(), **locals())) + + #from core.leras import nn + #exec( nn.import_all( device_config=nn.device.Config(cpu_only=True) ), locals(), globals() ) + # + #rnd_data = np.random.uniform( size=(1,64,64,3) ) + #bgr_shape = (64, 64, 3) + #input_layer = Input(bgr_shape) + #x = input_layer + #x = Conv2D(64, 3, padding='same')(x) + #x = Conv2D(128, 3, padding='same')(x) + #x = Conv2D(256, 3, padding='same')(x) + #x = Conv2D(512, 3, padding='same')(x) + #x = Conv2D(1024, 3, padding='same')(x) + #x = Conv2D(3, 3, padding='same')(x) + # + #model = Model (input_layer, [x]) + #model.compile(optimizer=Adam(), loss='mse') + #model.train_on_batch ([rnd_data], [rnd_data]) + ##model.save (r"D:\DevelopPython\test\test_model.h5") + # + #import code + #code.interact(local=dict(globals(), **locals())) + + + + import ffmpeg + + path = Path('D:/deepfacelab/test') + input_path = str(path / 'input.mp4') + #stream = ffmpeg.input(str(path / 'input.mp4') ) + #stream = ffmpeg.hflip(stream) + #stream = ffmpeg.output(stream, str(path / 'output.mp4') ) + #ffmpeg.run(stream) + ( + ffmpeg + .input( str(path / 'input.mp4')) + .hflip() + .output( str(path / 'output.mp4'), r="23000/1001" ) + .run() + ) + + #probe = ffmpeg.probe(str(path / 'input.mp4')) + + #out, _ = ( + # ffmpeg + # .input( input_path ) + # .output('pipe:', format='rawvideo', pix_fmt='rgb24') + # .run(capture_stdout=True) + #) + #video = ( + # np + # .frombuffer(out, np.uint8) + # .reshape([-1, height, width, 3]) + #) + + import code + code.interact(local=dict(globals(), **locals())) + + + + + from core.leras import nn + exec( nn.import_all(), locals(), globals() ) + + #ch = 3 + #def softmax(x, axis=-1): #from K numpy backend + # y = np.exp(x - np.max(x, axis, keepdims=True)) + # return y / np.sum(y, axis, keepdims=True) + # + #def gauss_kernel(size, sigma): + # coords = np.arange(0,size, dtype=K.floatx() ) + # coords -= (size - 1 ) / 2.0 + # g = coords**2 + # g *= ( -0.5 / (sigma**2) ) + # g = np.reshape (g, (1,-1)) + np.reshape(g, (-1,1) ) + # g = np.reshape (g, (1,-1)) + # g = softmax(g) + # g = np.reshape (g, (size, size, 1, 1)) + # g = np.tile (g, (1,1,ch, size*size*ch)) + # return K.constant(g, dtype=K.floatx() ) + # + ##kernel = gauss_kernel(11,1.5) + #kernel = K.constant( np.ones ( (246,246, 3, 1) ) , dtype=K.floatx() ) + ##g = np.eye(9).reshape((3, 3, 1, 9)) + ##g = np.tile (g, (1,1,3,1)) + ##kernel = K.constant(g , dtype=K.floatx() ) + # + #def reducer(x): + # shape = K.shape(x) + # x = K.reshape(x, (-1, shape[-3] , shape[-2], shape[-1]) ) + # + # y = K.depthwise_conv2d(x, kernel, strides=(1, 1), padding='valid') + # + # y_shape = K.shape(y) + # return y#K.reshape(y, (shape[0], y_shape[1], y_shape[2], y_shape[3] ) ) + + image = cv2.imread('D:\\DeepFaceLab\\test\\00000.png').astype(np.float32) / 255.0 + image = cv2.resize ( image, (128,128) ) + + image = cv2.cvtColor (image, cv2.COLOR_BGR2GRAY) + image = np.expand_dims (image, -1) + image_shape = image.shape + + image2 = cv2.imread('D:\\DeepFaceLab\\test\\00001.png').astype(np.float32) / 255.0 + #image2 = cv2.cvtColor (image2, cv2.COLOR_BGR2GRAY) + #image2 = np.expand_dims (image2, -1) + image2_shape = image2.shape + + image_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + image2_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + + #loss = reducer(image_tensor) + #loss = K.reshape (loss, (-1,246,246, 11,11,3) ) + tf = nn.tf + + sh = K.int_shape(image_tensor)[1] + wnd_size = 16 + step_size = 8 + k = (sh-wnd_size) // step_size + 1 + + loss = tf.image.extract_image_patches(image_tensor, [1,k,k,1], [1,1,1,1], [1,step_size,step_size,1], 'VALID') + print(loss) + + f = K.function ( [image_tensor], [loss] ) + x = f ( [ np.expand_dims(image,0) ] )[0][0] + + import code + code.interact(local=dict(globals(), **locals())) + + for i in range( x.shape[2] ): + img = x[:,:,i:i+1] + + cv2.imshow('', (img*255).astype(np.uint8) ) + cv2.waitKey(0) + + #for i in range( len(x) ): + # for j in range ( len(x) ): + # img = x[i,j] + # import code + # code.interact(local=dict(globals(), **locals())) + # + # cv2.imshow('', (x[i,j]*255).astype(np.uint8) ) + # cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + from core.leras import nn + exec( nn.import_all(), locals(), globals() ) + + PNet_Input = Input ( (None, None,3) ) + x = PNet_Input + x = Conv2D (10, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="PReLU1" )(x) + x = MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x) + x = Conv2D (16, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="PReLU2" )(x) + x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="PReLU3" )(x) + prob = Conv2D (2, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv41")(x) + prob = Softmax()(prob) + x = Conv2D (4, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv42")(x) + + PNet_model = Model(PNet_Input, [x,prob] ) + PNet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_pnet.h5').__str__() ) + + RNet_Input = Input ( (24, 24, 3) ) + x = RNet_Input + x = Conv2D (28, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="prelu1" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x) + x = Conv2D (48, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="prelu2" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x) + x = Conv2D (64, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="prelu3" )(x) + x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x) + x = Dense (128, name='conv4')(x) + x = PReLU (name="prelu4" )(x) + prob = Dense (2, name='conv51')(x) + prob = Softmax()(prob) + x = Dense (4, name='conv52')(x) + RNet_model = Model(RNet_Input, [x,prob] ) + RNet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_rnet.h5').__str__() ) + + ONet_Input = Input ( (48, 48, 3) ) + x = ONet_Input + x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="prelu1" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x) + x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="prelu2" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x) + x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="prelu3" )(x) + x = MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x) + x = Conv2D (128, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv4")(x) + x = PReLU (shared_axes=[1,2], name="prelu4" )(x) + x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x) + x = Dense (256, name='conv5')(x) + x = PReLU (name="prelu5" )(x) + prob = Dense (2, name='conv61')(x) + prob = Softmax()(prob) + x1 = Dense (4, name='conv62')(x) + x2 = Dense (10, name='conv63')(x) + ONet_model = Model(ONet_Input, [x1,x2,prob] ) + ONet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_onet.h5').__str__() ) + + pnet_fun = K.function ( PNet_model.inputs, PNet_model.outputs ) + rnet_fun = K.function ( RNet_model.inputs, RNet_model.outputs ) + onet_fun = K.function ( ONet_model.inputs, ONet_model.outputs ) + + pnet_test_data = np.random.uniform ( size=(1, 64,64,3) ) + pnet_result1, pnet_result2 = pnet_fun ([pnet_test_data]) + + rnet_test_data = np.random.uniform ( size=(1,24,24,3) ) + rnet_result1, rnet_result2 = rnet_fun ([rnet_test_data]) + + onet_test_data = np.random.uniform ( size=(1,48,48,3) ) + onet_result1, onet_result2, onet_result3 = onet_fun ([onet_test_data]) + + import code + code.interact(local=dict(globals(), **locals())) + + from core.leras import nn + #exec( nn.import_all( nn.device.Config(cpu_only=True) ), locals(), globals() )# nn.device.Config(cpu_only=True) + exec( nn.import_all(), locals(), globals() )# nn.device.Config(cpu_only=True) + + #det1_Input = Input ( (None, None,3) ) + #x = det1_Input + #x = Conv2D (10, kernel_size=(3,3), strides=(1,1), padding='valid')(x) + # + #import code + #code.interact(local=dict(globals(), **locals())) + + tf = nn.tf + tf_session = nn.tf_sess + + with tf.variable_scope('pnet2'): + data = tf.placeholder(tf.float32, (None,None,None,3), 'input') + pnet2 = mtcnn.PNet(tf, {'data':data}) + pnet2.load( (Path(mtcnn.__file__).parent / 'det1.npy').__str__(), tf_session) + with tf.variable_scope('rnet2'): + data = tf.placeholder(tf.float32, (None,24,24,3), 'input') + rnet2 = mtcnn.RNet(tf, {'data':data}) + rnet2.load( (Path(mtcnn.__file__).parent / 'det2.npy').__str__(), tf_session) + with tf.variable_scope('onet2'): + data = tf.placeholder(tf.float32, (None,48,48,3), 'input') + onet2 = mtcnn.ONet(tf, {'data':data}) + onet2.load( (Path(mtcnn.__file__).parent / 'det3.npy').__str__(), tf_session) + + + + pnet_fun = K.function([pnet2.layers['data']],[pnet2.layers['conv4-2'], pnet2.layers['prob1']]) + rnet_fun = K.function([rnet2.layers['data']],[rnet2.layers['conv5-2'], rnet2.layers['prob1']]) + onet_fun = K.function([onet2.layers['data']],[onet2.layers['conv6-2'], onet2.layers['conv6-3'], onet2.layers['prob1']]) + + det1_dict = np.load((Path(mtcnn.__file__).parent / 'det1.npy').__str__(), encoding='latin1').item() + det2_dict = np.load((Path(mtcnn.__file__).parent / 'det2.npy').__str__(), encoding='latin1').item() + det3_dict = np.load((Path(mtcnn.__file__).parent / 'det3.npy').__str__(), encoding='latin1').item() + + PNet_Input = Input ( (None, None,3) ) + x = PNet_Input + x = Conv2D (10, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="PReLU1" )(x) + x = MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x) + x = Conv2D (16, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="PReLU2" )(x) + x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="PReLU3" )(x) + prob = Conv2D (2, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv41")(x) + prob = Softmax()(prob) + x = Conv2D (4, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv42")(x) + + + PNet_model = Model(PNet_Input, [x,prob] ) + + #PNet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_pnet.h5').__str__() ) + PNet_model.get_layer("conv1").set_weights ( [ det1_dict['conv1']['weights'], det1_dict['conv1']['biases'] ] ) + PNet_model.get_layer("PReLU1").set_weights ( [ np.reshape(det1_dict['PReLU1']['alpha'], (1,1,-1)) ] ) + PNet_model.get_layer("conv2").set_weights ( [ det1_dict['conv2']['weights'], det1_dict['conv2']['biases'] ] ) + PNet_model.get_layer("PReLU2").set_weights ( [ np.reshape(det1_dict['PReLU2']['alpha'], (1,1,-1)) ] ) + PNet_model.get_layer("conv3").set_weights ( [ det1_dict['conv3']['weights'], det1_dict['conv3']['biases'] ] ) + PNet_model.get_layer("PReLU3").set_weights ( [ np.reshape(det1_dict['PReLU3']['alpha'], (1,1,-1)) ] ) + PNet_model.get_layer("conv41").set_weights ( [ det1_dict['conv4-1']['weights'], det1_dict['conv4-1']['biases'] ] ) + PNet_model.get_layer("conv42").set_weights ( [ det1_dict['conv4-2']['weights'], det1_dict['conv4-2']['biases'] ] ) + PNet_model.save ( (Path(mtcnn.__file__).parent / 'mtcnn_pnet.h5').__str__() ) + + pnet_test_data = np.random.uniform ( size=(1, 64,64,3) ) + pnet_result1, pnet_result2 = pnet_fun ([pnet_test_data]) + pnet2_result1, pnet2_result2 = K.function ( PNet_model.inputs, PNet_model.outputs ) ([pnet_test_data]) + + pnet_diff1 = np.mean ( np.abs(pnet_result1 - pnet2_result1) ) + pnet_diff2 = np.mean ( np.abs(pnet_result2 - pnet2_result2) ) + print ("pnet_diff1 = %f, pnet_diff2 = %f, " % (pnet_diff1, pnet_diff2) ) + + RNet_Input = Input ( (24, 24, 3) ) + x = RNet_Input + x = Conv2D (28, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="prelu1" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x) + x = Conv2D (48, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="prelu2" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x) + x = Conv2D (64, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="prelu3" )(x) + x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x) + x = Dense (128, name='conv4')(x) + x = PReLU (name="prelu4" )(x) + prob = Dense (2, name='conv51')(x) + prob = Softmax()(prob) + x = Dense (4, name='conv52')(x) + + RNet_model = Model(RNet_Input, [x,prob] ) + + #RNet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_rnet.h5').__str__() ) + RNet_model.get_layer("conv1").set_weights ( [ det2_dict['conv1']['weights'], det2_dict['conv1']['biases'] ] ) + RNet_model.get_layer("prelu1").set_weights ( [ np.reshape(det2_dict['prelu1']['alpha'], (1,1,-1)) ] ) + RNet_model.get_layer("conv2").set_weights ( [ det2_dict['conv2']['weights'], det2_dict['conv2']['biases'] ] ) + RNet_model.get_layer("prelu2").set_weights ( [ np.reshape(det2_dict['prelu2']['alpha'], (1,1,-1)) ] ) + RNet_model.get_layer("conv3").set_weights ( [ det2_dict['conv3']['weights'], det2_dict['conv3']['biases'] ] ) + RNet_model.get_layer("prelu3").set_weights ( [ np.reshape(det2_dict['prelu3']['alpha'], (1,1,-1)) ] ) + RNet_model.get_layer("conv4").set_weights ( [ det2_dict['conv4']['weights'], det2_dict['conv4']['biases'] ] ) + RNet_model.get_layer("prelu4").set_weights ( [ det2_dict['prelu4']['alpha'] ] ) + RNet_model.get_layer("conv51").set_weights ( [ det2_dict['conv5-1']['weights'], det2_dict['conv5-1']['biases'] ] ) + RNet_model.get_layer("conv52").set_weights ( [ det2_dict['conv5-2']['weights'], det2_dict['conv5-2']['biases'] ] ) + RNet_model.save ( (Path(mtcnn.__file__).parent / 'mtcnn_rnet.h5').__str__() ) + + #import code + #code.interact(local=dict(globals(), **locals())) + + rnet_test_data = np.random.uniform ( size=(1,24,24,3) ) + rnet_result1, rnet_result2 = rnet_fun ([rnet_test_data]) + rnet2_result1, rnet2_result2 = K.function ( RNet_model.inputs, RNet_model.outputs ) ([rnet_test_data]) + + rnet_diff1 = np.mean ( np.abs(rnet_result1 - rnet2_result1) ) + rnet_diff2 = np.mean ( np.abs(rnet_result2 - rnet2_result2) ) + print ("rnet_diff1 = %f, rnet_diff2 = %f, " % (rnet_diff1, rnet_diff2) ) + + + ################# + ''' + (self.feed('data') #pylint: disable=no-value-for-parameter, no-member + .conv(3, 3, 32, 1, 1, padding='VALID', relu=False, name='conv1') + .prelu(name='prelu1') + .max_pool(3, 3, 2, 2, name='pool1') + .conv(3, 3, 64, 1, 1, padding='VALID', relu=False, name='conv2') + .prelu(name='prelu2') + .max_pool(3, 3, 2, 2, padding='VALID', name='pool2') + .conv(3, 3, 64, 1, 1, padding='VALID', relu=False, name='conv3') + .prelu(name='prelu3') + .max_pool(2, 2, 2, 2, name='pool3') + .conv(2, 2, 128, 1, 1, padding='VALID', relu=False, name='conv4') + .prelu(name='prelu4') + .fc(256, relu=False, name='conv5') + .prelu(name='prelu5') + .fc(2, relu=False, name='conv6-1') + .softmax(1, name='prob1')) + + (self.feed('prelu5') #pylint: disable=no-value-for-parameter + .fc(4, relu=False, name='conv6-2')) + + (self.feed('prelu5') #pylint: disable=no-value-for-parameter + .fc(10, relu=False, name='conv6-3')) + ''' + ONet_Input = Input ( (48, 48, 3) ) + x = ONet_Input + x = Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + x = PReLU (shared_axes=[1,2], name="prelu1" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='same' ) (x) + x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + x = PReLU (shared_axes=[1,2], name="prelu2" )(x) + x = MaxPooling2D( pool_size=(3,3), strides=(2,2), padding='valid' ) (x) + x = Conv2D (64, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x) + x = PReLU (shared_axes=[1,2], name="prelu3" )(x) + x = MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x) + x = Conv2D (128, kernel_size=(2,2), strides=(1,1), padding='valid', name="conv4")(x) + x = PReLU (shared_axes=[1,2], name="prelu4" )(x) + x = Lambda ( lambda x: K.reshape (x, (-1, np.prod(K.int_shape(x)[1:]),) ), output_shape=(np.prod(K.int_shape(x)[1:]),) ) (x) + x = Dense (256, name='conv5')(x) + x = PReLU (name="prelu5" )(x) + prob = Dense (2, name='conv61')(x) + prob = Softmax()(prob) + x1 = Dense (4, name='conv62')(x) + x2 = Dense (10, name='conv63')(x) + + ONet_model = Model(ONet_Input, [x1,x2,prob] ) + + #ONet_model.load_weights ( (Path(mtcnn.__file__).parent / 'mtcnn_onet.h5').__str__() ) + ONet_model.get_layer("conv1").set_weights ( [ det3_dict['conv1']['weights'], det3_dict['conv1']['biases'] ] ) + ONet_model.get_layer("prelu1").set_weights ( [ np.reshape(det3_dict['prelu1']['alpha'], (1,1,-1)) ] ) + ONet_model.get_layer("conv2").set_weights ( [ det3_dict['conv2']['weights'], det3_dict['conv2']['biases'] ] ) + ONet_model.get_layer("prelu2").set_weights ( [ np.reshape(det3_dict['prelu2']['alpha'], (1,1,-1)) ] ) + ONet_model.get_layer("conv3").set_weights ( [ det3_dict['conv3']['weights'], det3_dict['conv3']['biases'] ] ) + ONet_model.get_layer("prelu3").set_weights ( [ np.reshape(det3_dict['prelu3']['alpha'], (1,1,-1)) ] ) + ONet_model.get_layer("conv4").set_weights ( [ det3_dict['conv4']['weights'], det3_dict['conv4']['biases'] ] ) + ONet_model.get_layer("prelu4").set_weights ( [ np.reshape(det3_dict['prelu4']['alpha'], (1,1,-1)) ] ) + ONet_model.get_layer("conv5").set_weights ( [ det3_dict['conv5']['weights'], det3_dict['conv5']['biases'] ] ) + ONet_model.get_layer("prelu5").set_weights ( [ det3_dict['prelu5']['alpha'] ] ) + ONet_model.get_layer("conv61").set_weights ( [ det3_dict['conv6-1']['weights'], det3_dict['conv6-1']['biases'] ] ) + ONet_model.get_layer("conv62").set_weights ( [ det3_dict['conv6-2']['weights'], det3_dict['conv6-2']['biases'] ] ) + ONet_model.get_layer("conv63").set_weights ( [ det3_dict['conv6-3']['weights'], det3_dict['conv6-3']['biases'] ] ) + ONet_model.save ( (Path(mtcnn.__file__).parent / 'mtcnn_onet.h5').__str__() ) + + onet_test_data = np.random.uniform ( size=(1,48,48,3) ) + onet_result1, onet_result2, onet_result3 = onet_fun ([onet_test_data]) + onet2_result1, onet2_result2, onet2_result3 = K.function ( ONet_model.inputs, ONet_model.outputs ) ([onet_test_data]) + + onet_diff1 = np.mean ( np.abs(onet_result1 - onet2_result1) ) + onet_diff2 = np.mean ( np.abs(onet_result2 - onet2_result2) ) + onet_diff3 = np.mean ( np.abs(onet_result3 - onet2_result3) ) + print ("onet_diff1 = %f, onet_diff2 = %f, , onet_diff3 = %f " % (onet_diff1, onet_diff2, onet_diff3) ) + + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + #class MTCNNSoftmax(keras.Layer): + # + # def __init__(self, axis=-1, **kwargs): + # super(MTCNNSoftmax, self).__init__(**kwargs) + # self.supports_masking = True + # self.axis = axis + # + # def call(self, inputs): + # + # def softmax(self, target, axis, name=None): + # max_axis = self.tf.reduce_max(target, axis, keepdims=True) + # target_exp = self.tf.exp(target-max_axis) + # normalize = self.tf.reduce_sum(target_exp, axis, keepdims=True) + # softmax = self.tf.div(target_exp, normalize, name) + # return softmax + # #return activations.softmax(inputs, axis=self.axis) + # + # def get_config(self): + # config = {'axis': self.axis} + # base_config = super(MTCNNSoftmax, self).get_config() + # return dict(list(base_config.items()) + list(config.items())) + # + # def compute_output_shape(self, input_shape): + # return input_shape + + from core.leras import nn + exec( nn.import_all(), locals(), globals() ) + + + + + image = cv2.imread('D:\\DeepFaceLab\\test\\00000.png').astype(np.float32) / 255.0 + image = cv2.cvtColor (image, cv2.COLOR_BGR2GRAY) + image = np.expand_dims (image, -1) + image_shape = image.shape + + image2 = cv2.imread('D:\\DeepFaceLab\\test\\00001.png').astype(np.float32) / 255.0 + image2 = cv2.cvtColor (image2, cv2.COLOR_BGR2GRAY) + image2 = np.expand_dims (image2, -1) + image2_shape = image2.shape + + #cv2.imshow('', image) + + + image_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + image2_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + + blurred_image_tensor = gaussian_blur(16.0)(image_tensor) + x, = nn.tf_sess.run ( blurred_image_tensor, feed_dict={image_tensor: np.expand_dims(image,0)} ) + cv2.imshow('', (x*255).astype(np.uint8) ) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + + #os.environ['plaidML'] = '1' + from core.leras import nn + + dvc = nn.device.Config(force_gpu_idx=1) + exec( nn.import_all(dvc), locals(), globals() ) + + tf = nn.tf + + image = cv2.imread('D:\\DeepFaceLab\\test\\00000.png').astype(np.float32) / 255.0 + image = cv2.cvtColor (image, cv2.COLOR_BGR2GRAY) + image = np.expand_dims (image, -1) + image_shape = image.shape + + image2 = cv2.imread('D:\\DeepFaceLab\\test\\00001.png').astype(np.float32) / 255.0 + image2 = cv2.cvtColor (image2, cv2.COLOR_BGR2GRAY) + image2 = np.expand_dims (image2, -1) + image2_shape = image2.shape + + image1_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + image2_tensor = K.placeholder(shape=[ 1, image_shape[0], image_shape[1], image_shape[2] ], dtype="float32" ) + + + + #import code + #code.interact(local=dict(globals(), **locals())) + def manual_conv(input, filter, strides, padding): + h_f, w_f, c_in, c_out = filter.get_shape().as_list() + input_patches = tf.extract_image_patches(input, ksizes=[1, h_f, w_f, 1 ], strides=strides, rates=[1, 1, 1, 1], padding=padding) + return input_patches + filters_flat = tf.reshape(filter, shape=[h_f*w_f*c_in, c_out]) + return tf.einsum("ijkl,lm->ijkm", input_patches, filters_flat) + + def extract_image_patches(x, ksizes, ssizes, padding='SAME', + data_format='channels_last'): + """Extract the patches from an image. + # Arguments + x: The input image + ksizes: 2-d tuple with the kernel size + ssizes: 2-d tuple with the strides size + padding: 'same' or 'valid' + data_format: 'channels_last' or 'channels_first' + # Returns + The (k_w,k_h) patches extracted + TF ==> (batch_size,w,h,k_w,k_h,c) + TH ==> (batch_size,w,h,c,k_w,k_h) + """ + kernel = [1, ksizes[0], ksizes[1], 1] + strides = [1, ssizes[0], ssizes[1], 1] + if data_format == 'channels_first': + x = K.permute_dimensions(x, (0, 2, 3, 1)) + bs_i, w_i, h_i, ch_i = K.int_shape(x) + patches = tf.extract_image_patches(x, kernel, strides, [1, 1, 1, 1], + padding) + # Reshaping to fit Theano + bs, w, h, ch = K.int_shape(patches) + reshaped = tf.reshape(patches, [-1, w, h, tf.floordiv(ch, ch_i), ch_i]) + final_shape = [-1, w, h, ch_i, ksizes[0], ksizes[1]] + patches = tf.reshape(tf.transpose(reshaped, [0, 1, 2, 4, 3]), final_shape) + if data_format == 'channels_last': + patches = K.permute_dimensions(patches, [0, 1, 2, 4, 5, 3]) + return patches + + m = 32 + c_in = 3 + c_out = 16 + + filter_sizes = [5, 11] + strides = [1] + #paddings = ["VALID", "SAME"] + + for fs in filter_sizes: + h = w = 128 + h_f = w_f = fs + stri = 2 + #print "Testing for", imsize, fs, stri, pad + + #tf.reset_default_graph() + X = tf.constant(1.0+np.random.rand(m, h, w, c_in), tf.float32) + W = tf.constant(np.ones([h_f, w_f, c_in, h_f*w_f*c_in]), tf.float32) + + + Z = tf.nn.conv2d(X, W, strides=[1, stri, stri, 1], padding="VALID") + Z_manual = manual_conv(X, W, strides=[1, stri, stri, 1], padding="VALID") + Z_2 = extract_image_patches (X, (fs,fs), (stri,stri), padding="VALID") + import code + code.interact(local=dict(globals(), **locals())) + # + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + Z_, Z_manual_ = sess.run([Z, Z_manual]) + #self.assertEqual(Z_.shape, Z_manual_.shape) + #self.assertTrue(np.allclose(Z_, Z_manual_, rtol=1e-05)) + sess.close() + + + import code + code.interact(local=dict(globals(), **locals())) + + + + + + #k_loss_t = keras_style_loss()(image1_tensor, image2_tensor) + #k_loss_run = K.function( [image1_tensor, image2_tensor],[k_loss_t]) + #import code + #code.interact(local=dict(globals(), **locals())) + #image = np.expand_dims(image,0) + #image2 = np.expand_dims(image2,0) + #k_loss = k_loss_run([image, image2]) + #t_loss = t_loss_run([image, image2]) + + + + + #x, = tf_sess_run ([np.expand_dims(image,0)]) + #x = x[0] + ##import code + ##code.interact(local=dict(globals(), **locals())) + + + + image = cv2.imread('D:\\DeepFaceLab\\test\\00000.png').astype(np.float32) / 255.0 + image = cv2.cvtColor (image, cv2.COLOR_BGR2GRAY) + image = np.expand_dims (image, -1) + image_shape = image.shape + + image2 = cv2.imread('D:\\DeepFaceLab\\test\\00001.png').astype(np.float32) / 255.0 + image2 = cv2.cvtColor (image2, cv2.COLOR_BGR2GRAY) + image2 = np.expand_dims (image2, -1) + image2_shape = image2.shape + + image_tensor = tf.placeholder(tf.float32, shape=[1, image_shape[0], image_shape[1], image_shape[2] ]) + image2_tensor = tf.placeholder(tf.float32, shape=[1, image2_shape[0], image2_shape[1], image2_shape[2] ]) + + blurred_image_tensor = sl(image_tensor, image2_tensor) + x = tf_sess.run ( blurred_image_tensor, feed_dict={image_tensor: np.expand_dims(image,0), image2_tensor: np.expand_dims(image2,0) } ) + + cv2.imshow('', x[0]) + cv2.waitKey(0) + import code + code.interact(local=dict(globals(), **locals())) + + while True: + image = cv2.imread('D:\\DeepFaceLab\\workspace\\data_src\\aligned\\00000.png').astype(np.float32) / 255.0 + image = cv2.resize(image, (256,256)) + image = random_transform( image ) + warped_img, target_img = random_warp( image ) + + #cv2.imshow('', image) + #cv2.waitKey(0) + + cv2.imshow('', warped_img) + cv2.waitKey(0) + cv2.imshow('', target_img) + cv2.waitKey(0) + + import code + code.interact(local=dict(globals(), **locals())) + + import code + code.interact(local=dict(globals(), **locals())) + + return + + + def keras_gaussian_blur(radius=2.0): + def gaussian(x, mu, sigma): + return np.exp(-(float(x) - float(mu)) ** 2 / (2 * sigma ** 2)) + + def make_kernel(sigma): + kernel_size = max(3, int(2 * 2 * sigma + 1)) + mean = np.floor(0.5 * kernel_size) + kernel_1d = np.array([gaussian(x, mean, sigma) for x in range(kernel_size)]) + np_kernel = np.outer(kernel_1d, kernel_1d).astype(dtype=K.floatx()) + kernel = np_kernel / np.sum(np_kernel) + return kernel + + gauss_kernel = make_kernel(radius) + gauss_kernel = gauss_kernel[:, :,np.newaxis, np.newaxis] + + #import code + #code.interact(local=dict(globals(), **locals())) + def func(input): + inputs = [ input[:,:,:,i:i+1] for i in range( K.int_shape( input )[-1] ) ] + + outputs = [] + for i in range(len(inputs)): + outputs += [ K.conv2d( inputs[i] , K.constant(gauss_kernel) , strides=(1,1), padding="same") ] + + return K.concatenate (outputs, axis=-1) + return func + + def keras_style_loss(gaussian_blur_radius=0.0, loss_weight=1.0, epsilon=1e-5): + if gaussian_blur_radius > 0.0: + gblur = keras_gaussian_blur(gaussian_blur_radius) + + def sd(content, style): + content_nc = K.int_shape(content)[-1] + style_nc = K.int_shape(style)[-1] + if content_nc != style_nc: + raise Exception("keras_style_loss() content_nc != style_nc") + + axes = [1,2] + c_mean, c_var = K.mean(content, axis=axes, keepdims=True), K.var(content, axis=axes, keepdims=True) + s_mean, s_var = K.mean(style, axis=axes, keepdims=True), K.var(style, axis=axes, keepdims=True) + c_std, s_std = K.sqrt(c_var + epsilon), K.sqrt(s_var + epsilon) + + mean_loss = K.sum(K.square(c_mean-s_mean)) + std_loss = K.sum(K.square(c_std-s_std)) + + return (mean_loss + std_loss) * loss_weight + + def func(target, style): + if gaussian_blur_radius > 0.0: + return sd( gblur(target), gblur(style)) + else: + return sd( target, style ) + return func + + data = tf.placeholder(tf.float32, (None,None,None,3), 'input') + pnet2 = mtcnn.PNet(tf, {'data':data}) + filename = str(Path(mtcnn.__file__).parent/'det1.npy') + pnet2.load(filename, tf_sess) + + pnet_fun = K.function([pnet2.layers['data']],[pnet2.layers['conv4-2'], pnet2.layers['prob1']]) + + import code + code.interact(local=dict(globals(), **locals())) + + return + + + while True: + img_bgr = np.random.rand ( 268, 640, 3 ) + img_size = img_bgr.shape[1], img_bgr.shape[0] + + mat = np.array( [[ 1.99319629e+00, -1.81504324e-01, -3.62479778e+02], + [ 1.81504324e-01, 1.99319629e+00, -8.05396709e+01]] ) + + tmp_0 = np.random.rand ( 128,128 ) - 0.1 + tmp = np.expand_dims (tmp_0, axis=-1) + + mask = np.ones ( tmp.shape, dtype=np.float32) + mask_border_size = int ( mask.shape[1] * 0.0625 ) + mask[:,0:mask_border_size,:] = 0 + mask[:,-mask_border_size:,:] = 0 + + x = cv2.warpAffine( mask, mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT ) + + if len ( np.argwhere( np.isnan(x) ) ) == 0: + print ("fine") + else: + print ("wtf") + + import code + code.interact(local=dict(globals(), **locals())) + + return + + aligned_path_image_paths = pathex.get_image_paths("E:\\FakeFaceVideoSources\\Datasets\\CelebA aligned") + + a = [] + r_vec = np.array([[0.01891013], [0.08560084], [-3.14392813]]) + t_vec = np.array([[-14.97821226], [-10.62040383], [-2053.03596872]]) + + yaws = [] + pitchs = [] + for filepath in tqdm(aligned_path_image_paths, desc="test", ascii=True ): + filepath = Path(filepath) + + if filepath.suffix == '.png': + dflimg = DFLPNG.load( str(filepath), print_on_no_embedded_data=True ) + elif filepath.suffix == '.jpg': + dflimg = DFLJPG.load ( str(filepath), print_on_no_embedded_data=True ) + else: + print ("%s is not a dfl image file" % (filepath.name) ) + + #source_filename_stem = Path( dflimg.get_source_filename() ).stem + #if source_filename_stem not in alignments.keys(): + # alignments[ source_filename_stem ] = [] + + + #focal_length = dflimg.shape[1] + #camera_center = (dflimg.shape[1] / 2, dflimg.shape[0] / 2) + #camera_matrix = np.array( + # [[focal_length, 0, camera_center[0]], + # [0, focal_length, camera_center[1]], + # [0, 0, 1]], dtype=np.float32) + # + landmarks = dflimg.get_landmarks() + # + #lm = landmarks.astype(np.float32) + + img = cv2_imread (str(filepath)) / 255.0 + + img = LandmarksProcessor.draw_landmarks(img, landmarks, (1,1,1) ) + + + #(_, rotation_vector, translation_vector) = cv2.solvePnP( + # LandmarksProcessor.landmarks_68_3D, + # lm, + # camera_matrix, + # np.zeros((4, 1)) ) + # + #rme = mathlib.rotationMatrixToEulerAngles( cv2.Rodrigues(rotation_vector)[0] ) + #import code + #code.interact(local=dict(globals(), **locals())) + + #rotation_vector = rotation_vector / np.linalg.norm(rotation_vector) + + + #img2 = image_utils.get_text_image ( (256,10, 3), str(rotation_vector) ) + pitch, yaw = LandmarksProcessor.estimate_pitch_yaw (landmarks) + yaws += [yaw] + #print(pitch, yaw) + #cv2.imshow ("", (img * 255).astype(np.uint8) ) + #cv2.waitKey(0) + #a += [ rotation_vector] + yaws = np.array(yaws) + import code + code.interact(local=dict(globals(), **locals())) + + + + + + + #alignments[ source_filename_stem ].append (dflimg.get_source_landmarks()) + #alignments.append (dflimg.get_source_landmarks()) + + + + + + + + o = np.ones ( (128,128,3), dtype=np.float32 ) + cv2.imwrite ("D:\\temp\\z.jpg", o) + + #DFLJPG.x ("D:\\temp\\z.jpg", ) + + dfljpg = DFLJPG.load("D:\\temp\\z.jpg") + + import code + code.interact(local=dict(globals(), **locals())) + + return + + + + #import sys, numpy; print(numpy.__version__, sys.version) + sq = multiprocessing.Queue() + cq = multiprocessing.Queue() + + p = multiprocessing.Process(target=subprocess, args=(sq,cq,)) + p.start() + + while True: + cq.get() #waiting numpy array + sq.put (1) #send message we are ready to get more + + #import code + #code.interact(local=dict(globals(), **locals())) + + os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2' + + from core.leras import nn + exec( nn.import_all(), locals(), globals() ) + + + + + #import tensorflow as tf + #tf_module = tf + # + #config = tf_module.ConfigProto() + #config.gpu_options.force_gpu_compatible = True + #tf_session = tf_module.Session(config=config) + # + #srgb_tensor = tf.placeholder("float", [None, None, 3]) + # + #filename = Path(__file__).parent / '00050.png' + #img = cv2.imread(str(filename)).astype(np.float32) / 255.0 + # + #lab_tensor = rgb_to_lab (tf_module, srgb_tensor) + # + #rgb_tensor = lab_to_rgb (tf_module, lab_tensor) + # + #rgb = tf_session.run(rgb_tensor, feed_dict={srgb_tensor: img}) + #cv2.imshow("", rgb) + #cv2.waitKey(0) + + #from skimage import io, color + #def_lab = color.rgb2lab(img) + # + #t = time.time() + #def_lab = color.rgb2lab(img) + #print ( time.time() - t ) + # + #lab = tf_session.run(lab_tensor, feed_dict={srgb_tensor: img}) + # + #t = time.time() + #lab = tf_session.run(lab_tensor, feed_dict={srgb_tensor: img}) + #print ( time.time() - t ) + + + + + + + #lab_clr = color.rgb2lab(img_bgr) + #lab_bw = color.rgb2lab(out_img) + #tmp_channel, a_channel, b_channel = cv2.split(lab_clr) + #l_channel, tmp2_channel, tmp3_channel = cv2.split(lab_bw) + #img_LAB = cv2.merge((l_channel,a_channel, b_channel)) + #out_img = color.lab2rgb(lab.astype(np.float64)) + # + #cv2.imshow("", out_img) + #cv2.waitKey(0) + + #import code + #code.interact(local=dict(globals(), **locals())) + + + +if __name__ == "__main__": + + #import os + #os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" + #os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_gtx_1060_6gb.0" + #import keras + #import numpy as np + #import cv2 + #import time + #K = keras.backend + # + # + # + #PNet_Input = keras.layers.Input ( (None, None,3) ) + #x = PNet_Input + #x = keras.layers.Conv2D (10, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv1")(x) + #x = keras.layers.PReLU (shared_axes=[1,2], name="PReLU1" )(x) + #x = keras.layers.MaxPooling2D( pool_size=(2,2), strides=(2,2), padding='same' ) (x) + #x = keras.layers.Conv2D (16, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv2")(x) + #x = keras.layers.PReLU (shared_axes=[1,2], name="PReLU2" )(x) + #x = keras.layers.Conv2D (32, kernel_size=(3,3), strides=(1,1), padding='valid', name="conv3")(x) + #x = keras.layers.PReLU (shared_axes=[1,2], name="PReLU3" )(x) + #prob = keras.layers.Conv2D (2, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv41")(x) + #x = keras.layers.Conv2D (4, kernel_size=(1,1), strides=(1,1), padding='valid', name="conv42")(x) + # + #pnet = K.function ([PNet_Input], [x,prob] ) + # + #img = np.random.uniform ( size=(1920,1920,3) ) + #minsize=80 + #factor=0.95 + #factor_count=0 + #h=img.shape[0] + #w=img.shape[1] + # + #minl=np.amin([h, w]) + #m=12.0/minsize + #minl=minl*m + ## create scale pyramid + #scales=[] + #while minl>=12: + # scales += [m*np.power(factor, factor_count)] + # minl = minl*factor + # factor_count += 1 + # # first stage + # for scale in scales: + # hs=int(np.ceil(h*scale)) + # ws=int(np.ceil(w*scale)) + # im_data = cv2.resize(img, (ws, hs), interpolation=cv2.INTER_LINEAR) + # im_data = (im_data-127.5)*0.0078125 + # img_x = np.expand_dims(im_data, 0) + # img_x = np.transpose(img_x, (0,2,1,3)) + # t = time.time() + # out = pnet([img_x]) + # t = time.time() - t + # print (img_x.shape, t) + # + #import code + #code.interact(local=dict(globals(), **locals())) + + #os.environ["KERAS_BACKEND"] = "plaidml.keras.backend" + #os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_gtx_1060_6gb.0" + #import keras + #K = keras.backend + # + #image = np.random.uniform ( size=(1,256,256,3) ) + #image2 = np.random.uniform ( size=(1,256,256,3) ) + # + #y_true = K.placeholder ( (None,) + image.shape[1:] ) + #y_pred = K.placeholder ( (None,) + image2.shape[1:] ) + # + #def reducer(x): + # shape = K.shape(x) + # x = K.reshape(x, (-1, shape[-3] , shape[-2], shape[-1]) ) + # y = K.depthwise_conv2d(x, K.constant(np.ones( (11,11,3,1) )), strides=(1, 1), padding='valid' ) + # y_shape = K.shape(y) + # return K.reshape(y, (shape[0], y_shape[1], y_shape[2], y_shape[3] ) ) + # + #mean0 = reducer(y_true) + #mean1 = reducer(y_pred) + #luminance = mean0 * mean1 + #cs = y_true * y_pred + # + #result = K.function([y_true, y_pred],[luminance, cs]) ([image, image2]) + # + #print (result) + #import code + #code.interact(local=dict(globals(), **locals())) + + + main() + +""" + +MobileNetV2 + +class Upscale(nn.ModelBase): + def on_build(self, in_ch, out_ch, kernel_size=3 ): + self.conv1 = nn.Conv2D( in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME') + + def forward(self, x): + x = self.conv1(x) + x = tf.nn.leaky_relu(x, 0.1) + x = nn.depth_to_space(x, 2) + return x + + + class BottleNeck(nn.ModelBase): + def on_build(self, in_ch, ch, kernel_size, t, strides, r=False, **kwargs ): + + dc = in_ch*t + + self.conv1 = nn.Conv2D (in_ch, dc, kernel_size=1, strides=1, padding='SAME') + self.frn1 = nn.FRNorm2D(dc) + self.tlu1 = nn.TLU(dc) + + self.conv2 = nn.DepthwiseConv2D (dc, kernel_size=kernel_size, strides=strides, padding='SAME') + self.frn2 = nn.FRNorm2D(dc) + self.tlu2 = nn.TLU(dc) + + + self.conv3 = nn.Conv2D (dc, ch, kernel_size=1, strides=1, padding='SAME') + self.frn3 = nn.FRNorm2D(ch) + + self.r = r + + def forward(self, inp): + x = inp + + x = self.conv1(x) + x = self.frn1(x) + x = self.tlu1(x) + + x = self.conv2(x) + x = self.frn2(x) + x = self.tlu2(x) + + x = self.conv3(x) + x = self.frn3(x) + + if self.r: + x = x + inp + + return x + + + class InvResidualBlock(nn.ModelBase): + def on_build(self, in_ch, ch, kernel_size, t, strides, n, **kwargs ): + self.b1 = BottleNeck(in_ch, ch, kernel_size, t, strides) + + self.b_list = [] + for i in range(1, n): + self.b_list.append ( BottleNeck(ch, ch, kernel_size, t, 1, r=True) ) + + def forward(self, inp): + x = inp + x = self.b1(x) + + for i in range(len(self.b_list)): + x = self.b_list[i](x) + + return x + + class Encoder(nn.ModelBase): + def on_build(self, in_ch, e_ch, **kwargs): + e_ch = e_ch // 8 + + self.conv1 = nn.Conv2D( in_ch, e_ch, kernel_size=3, strides=2, padding='SAME') + self.frn1 = nn.FRNorm2D(e_ch) + self.tlu1 = nn.TLU(e_ch) + + self.ir1 = InvResidualBlock(e_ch, e_ch*2, kernel_size=3, t=1, strides=1, n=1) + self.ir2 = InvResidualBlock(e_ch*2, e_ch*3, kernel_size=3, t=6, strides=2, n=2) + self.ir3 = InvResidualBlock(e_ch*3, e_ch*4, kernel_size=3, t=6, strides=2, n=3) + self.ir4 = InvResidualBlock(e_ch*4, e_ch*8, kernel_size=3, t=6, strides=2, n=4) + self.ir5 = InvResidualBlock(e_ch*8, e_ch*12, kernel_size=3, t=6, strides=1, n=3) + self.ir6 = InvResidualBlock(e_ch*12, e_ch*20, kernel_size=3, t=6, strides=2, n=3) + self.ir7 = InvResidualBlock(e_ch*20, e_ch*40, kernel_size=3, t=6, strides=1, n=1) + + def forward(self, inp): + x = inp + x = self.conv1(x) + x = self.frn1(x) + x = self.tlu1(x) + + x = self.ir1(x) + x = self.ir2(x) + x = self.ir3(x) + x = self.ir4(x) + x = self.ir5(x) + x = self.ir6(x) + x = self.ir7(x) + + return x + + lowest_dense_res = resolution // 32 + + class Inter(nn.ModelBase): + def __init__(self, in_ch, ae_ch, ae_out_ch, is_hd=False, **kwargs): + self.in_ch, self.ae_ch, self.ae_out_ch = in_ch, ae_ch, ae_out_ch + super().__init__(**kwargs) + + def on_build(self): + in_ch, ae_ch, ae_out_ch = self.in_ch, self.ae_ch, self.ae_out_ch + + self.conv2 = nn.Conv2D( in_ch, ae_ch, kernel_size=3, strides=1, padding='SAME') + self.frn2 = nn.FRNorm2D(ae_ch) + self.tlu2 = nn.TLU(ae_ch) + + + + self.dense1 = nn.Dense( ae_ch, ae_ch ) + self.dense2 = nn.Dense( ae_ch, lowest_dense_res * lowest_dense_res * ae_out_ch ) + + def forward(self, inp): + x = self.conv2(inp) + x = self.frn2(x) + x = self.tlu2(x) + + x = nn.tf.reduce_mean (x, axis=nn.conv2d_spatial_axes, keepdims=True) + + x = nn.flatten(x) + + x = self.dense1(x) + x = self.dense2(x) + x = nn.reshape_4D (x, lowest_dense_res, lowest_dense_res, self.ae_out_ch) + return x + + @staticmethod + def get_code_res(): + return lowest_dense_res + + def get_out_ch(self): + return self.ae_out_ch + + class Decoder(nn.ModelBase): + def on_build(self, in_ch, d_ch, d_mask_ch, **kwargs ): + d_ch = d_ch // 8 + d_mask_ch = d_mask_ch // 8 + + self.conv2 = nn.Conv2D( in_ch, d_ch*40, kernel_size=3, strides=1, padding='SAME') + self.frn2 = nn.FRNorm2D(d_ch*40) + self.tlu2 = nn.TLU(d_ch*40) + + self.ir7 = InvResidualBlock(d_ch*40, d_ch*20, kernel_size=3, t=6, strides=1, n=1) + self.ir6 = InvResidualBlock(d_ch*20, d_ch*12, kernel_size=3, t=6, strides=1, n=3) + self.ir5 = InvResidualBlock(d_ch*12, d_ch*8, kernel_size=3, t=6, strides=1, n=3) + self.ir4 = InvResidualBlock(d_ch*8, d_ch*4, kernel_size=3, t=6, strides=1, n=4) + self.ir3 = InvResidualBlock(d_ch*4, d_ch*3, kernel_size=3, t=6, strides=1, n=3) + self.ir2 = InvResidualBlock(d_ch*3, d_ch*2, kernel_size=3, t=6, strides=1, n=2) + self.ir1 = InvResidualBlock(d_ch*2, d_ch, kernel_size=3, t=1, strides=1, n=1) + self.out_conv = nn.Conv2D( d_ch, 3, kernel_size=1, padding='SAME') + + + self.mir7 = InvResidualBlock(d_ch*40, d_mask_ch*20, kernel_size=3, t=6, strides=1, n=1) + self.mir6 = InvResidualBlock(d_mask_ch*20, d_mask_ch*12, kernel_size=3, t=6, strides=1, n=1) + self.mir5 = InvResidualBlock(d_mask_ch*12, d_mask_ch*8, kernel_size=3, t=6, strides=1, n=1) + self.mir4 = InvResidualBlock(d_mask_ch*8, d_mask_ch*4, kernel_size=3, t=6, strides=1, n=1) + self.mir3 = InvResidualBlock(d_mask_ch*4, d_mask_ch*3, kernel_size=3, t=6, strides=1, n=1) + self.mir2 = InvResidualBlock(d_mask_ch*3, d_mask_ch*2, kernel_size=3, t=6, strides=1, n=1) + self.mir1 = InvResidualBlock(d_mask_ch*2, d_mask_ch, kernel_size=3, t=1, strides=1, n=1) + self.out_convm = nn.Conv2D( d_mask_ch, 1, kernel_size=1, padding='SAME') + + + def forward(self, inp): + x = inp + + x = self.conv2(x) + x = self.frn2(x) + x = z = self.tlu2(x) + + x = self.ir7(x) + x = nn.upsample2d(x) + x = self.ir6(x) + x = self.ir5(x) + x = nn.upsample2d(x) + x = self.ir4(x) + x = nn.upsample2d(x) + x = self.ir3(x) + x = nn.upsample2d(x) + x = self.ir2(x) + x = nn.upsample2d(x) + x = self.ir1(x) + + m = self.mir7(z) + m = nn.upsample2d(m) + m = self.mir6(m) + m = self.mir5(m) + m = nn.upsample2d(m) + m = self.mir4(m) + m = nn.upsample2d(m) + m = self.mir3(m) + m = nn.upsample2d(m) + m = self.mir2(m) + m = nn.upsample2d(m) + m = self.mir1(m) + + return tf.nn.sigmoid(self.out_conv(x)), \ + tf.nn.sigmoid(self.out_convm(m)) + +""" \ No newline at end of file