# Spaces: Running on Zero
| from typing import Callable, Optional | |
| from .triposg_transformer import TripoSGDiTModel | |
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Return the existing attention processor unchanged.

    This is the no-op callback: it has the same four-argument signature as a
    processor factory but ignores everything except the processor it is
    given.

    Args:
        name: Name of the attention processor entry (unused).
        hidden_size: Hidden width supplied by the caller (unused).
        cross_attention_dim: Cross-attention width, or ``None`` (unused).
        ori_attn_proc: The processor currently registered under ``name``.

    Returns:
        ``ori_attn_proc`` itself, not a copy.
    """
    return ori_attn_proc
def set_transformer_attn_processor(
    transformer: TripoSGDiTModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_1_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_2_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[list[str]] = None,
    set_cross_attn_1_module_names: Optional[list[str]] = None,
    set_cross_attn_2_module_names: Optional[list[str]] = None,
) -> None:
    """Rebuild the transformer's attention-processor mapping via callbacks.

    Every entry of ``transformer.attn_processors`` is classified by the
    suffix of its name:

    * ``attn1.processor``   -> self attention
    * ``attn2.processor``   -> cross attention (first)
    * ``attn2_2.processor`` -> cross attention (second)

    For a classified entry whose name passes the corresponding prefix filter,
    the matching callback is invoked as
    ``callback(name, hidden_size, cross_attention_dim, current_processor)``
    and its return value replaces the processor. Every other entry keeps its
    current processor. The resulting mapping is handed to
    ``transformer.set_attn_processor`` in a single call.

    Args:
        transformer: Model exposing ``attn_processors``, ``config`` and
            ``set_attn_processor``.
        set_self_attn_proc_func: Callback for self-attention entries; it
            receives ``None`` as ``cross_attention_dim``.
        set_cross_attn_1_proc_func: Callback for ``attn2`` entries; it
            receives ``transformer.config.cross_attention_dim``.
        set_cross_attn_2_proc_func: Callback for ``attn2_2`` entries; it
            receives ``transformer.config.cross_attention_2_dim``.
        set_self_attn_module_names: Name prefixes selecting which
            self-attention entries are passed to the callback. ``None``
            selects all of them; an empty list selects none.
        set_cross_attn_1_module_names: Same, for ``attn2`` entries.
        set_cross_attn_2_module_names: Same, for ``attn2_2`` entries.
    """

    def _as_prefixes(module_names):
        # ``str.startswith`` accepts a tuple of candidates, so convert once
        # here instead of rebuilding a list for every processor name.
        return None if module_names is None else tuple(module_names)

    # (name suffix, callback, prefix filter, config attribute holding the
    #  cross-attention width -- None for self attention)
    rules = (
        (
            "attn1.processor",
            set_self_attn_proc_func,
            _as_prefixes(set_self_attn_module_names),
            None,
        ),
        (
            "attn2.processor",
            set_cross_attn_1_proc_func,
            _as_prefixes(set_cross_attn_1_module_names),
            "cross_attention_dim",
        ),
        (
            "attn2_2.processor",
            set_cross_attn_2_proc_func,
            _as_prefixes(set_cross_attn_2_module_names),
            "cross_attention_2_dim",
        ),
    )

    hidden_size = transformer.config.width  # identical for every entry

    attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        # Start from the current processor so that entries matching no rule,
        # or filtered out by the prefix lists, are carried over untouched.
        # NOTE(review): diffusers-style ``set_attn_processor`` rejects a dict
        # whose size differs from the number of attention layers -- confirm
        # TripoSGDiTModel behaves the same way.
        attn_procs[name] = attn_processor

        for suffix, proc_func, prefixes, dim_attr in rules:
            if not name.endswith(suffix):
                continue
            if prefixes is None or name.startswith(prefixes):  # prefix match
                cross_attention_dim = (
                    None
                    if dim_attr is None
                    else getattr(transformer.config, dim_attr)
                )
                attn_procs[name] = proc_func(
                    name, hidden_size, cross_attention_dim, attn_processor
                )
            break  # the three suffixes are mutually exclusive

    transformer.set_attn_processor(attn_procs)