Meismaxandmaxisme committed on
Commit
c542dab
·
verified ·
1 Parent(s): c6f24ab

Upload 2 files

Browse files
src/backend/gguf/gguf_diffusion.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Wrapper class to call the stablediffusion.cpp shared library for GGUF support
3
+ """
4
+
5
+ import ctypes
6
+ import platform
7
+ from ctypes import (
8
+ POINTER,
9
+ c_bool,
10
+ c_char_p,
11
+ c_float,
12
+ c_int,
13
+ c_int64,
14
+ c_void_p,
15
+ )
16
+ from dataclasses import dataclass
17
+ from os import path
18
+ from typing import List, Any
19
+
20
+ import numpy as np
21
+ from PIL import Image
22
+
23
+ from backend.gguf.sdcpp_types import (
24
+ RngType,
25
+ SampleMethod,
26
+ Schedule,
27
+ SDCPPLogLevel,
28
+ SDImage,
29
+ SdType,
30
+ )
31
+
32
+
33
@dataclass
class ModelConfig:
    """Settings used to create the stable-diffusion.cpp context.

    GGUFDiffusion.__init__ forwards every field, in declaration order, as the
    positional arguments of the native ``new_sd_ctx`` call, so the order of
    the fields below must stay in sync with that call. String fields are
    UTF-8 encoded before the call; an empty string is sent as ``b""``.
    """

    model_path: str = ""
    # The next four paths are mandatory: GGUFDiffusion.__init__ raises
    # ValueError when any of them is empty or does not exist on disk.
    clip_l_path: str = ""
    t5xxl_path: str = ""
    diffusion_model_path: str = ""
    vae_path: str = ""
    # The remaining paths are not validated on the Python side.
    taesd_path: str = ""
    control_net_path: str = ""
    lora_model_dir: str = ""
    embed_dir: str = ""
    stacked_id_embed_dir: str = ""
    vae_decode_only: bool = True
    vae_tiling: bool = False
    free_params_immediately: bool = False
    n_threads: int = 4  # passed to the native call as a C int
    wtype: SdType = SdType.SD_TYPE_Q4_0  # enum sd_type_t on the C side
    rng_type: RngType = RngType.CUDA_RNG  # enum rng_type_t on the C side
    schedule: Schedule = Schedule.DEFAULT  # enum schedule_t on the C side
    keep_clip_on_cpu: bool = False
    keep_control_net_cpu: bool = False
    keep_vae_on_cpu: bool = False
55
+
56
+
57
@dataclass
class Txt2ImgConfig:
    """Per-request parameters for GGUFDiffusion.generate_text2mg.

    The fields are forwarded, in declaration order, as the positional
    arguments that follow the context pointer in the native ``txt2img`` call.
    ``prompt`` and ``negative_prompt`` are UTF-8 encoded by the wrapper;
    ``input_id_images_path`` is passed through unchanged and must therefore
    already be ``bytes``.
    """

    prompt: str = "a man wearing sun glasses, highly detailed"
    negative_prompt: str = ""
    clip_skip: int = -1
    cfg_scale: float = 2.0
    guidance: float = 3.5
    width: int = 512
    height: int = 512
    sample_method: SampleMethod = SampleMethod.EULER_A
    sample_steps: int = 1
    seed: int = -1  # sent as int64_t
    # Number of images requested; the wrapper reads exactly this many entries
    # from the array returned by the native call.
    batch_count: int = 2
    # Handed to an argument declared as POINTER(SDImage): it must be None
    # (NULL, no control image) or a ctypes pointer to an SDImage. It is NOT a
    # PIL image, which is why the annotation is not the PIL ``Image`` module.
    control_cond: Any = None
    control_strength: float = 0.90
    style_strength: float = 0.5
    normalize_input: bool = False
    input_id_images_path: bytes = b""
75
+
76
+
77
class GGUFDiffusion:
    """GGUF Diffusion

    Supports GGUF diffusion models through the stablediffusion.cpp shared
    library, which is loaded with ctypes.
    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
    Implemented based on stablediffusion.h
    """

    def __init__(
        self,
        libpath: str,
        config: ModelConfig,
        logging_enabled: bool = False,
    ):
        """Load the shared library and create the native diffusion context.

        Args:
            libpath: Directory that contains the platform-specific
                stable-diffusion shared library.
            config: Model paths and context options; forwarded positionally
                to the native ``new_sd_ctx`` function.
            logging_enabled: When True, native log messages are printed to
                stdout through ``log_callback``.

        Raises:
            ValueError: If the platform is unsupported, the library cannot be
                loaded, a required model file is missing, or the native
                library fails to create a context.
        """
        sdcpp_shared_lib_path = self._get_sdcpp_shared_lib_path(libpath)
        try:
            self.libsdcpp = ctypes.CDLL(sdcpp_shared_lib_path)
        except OSError as e:
            print(f"Failed to load library {sdcpp_shared_lib_path}")
            # Keep the original OSError attached as the cause.
            raise ValueError(f"Error: {e}") from e

        # These four files are mandatory; fail early with a readable message
        # instead of letting the native loader fail.
        required_models = (
            ("CLIP", config.clip_l_path),
            ("T5XXL", config.t5xxl_path),
            ("Diffusion", config.diffusion_model_path),
            ("VAE", config.vae_path),
        )
        for label, model_file in required_models:
            if not model_file or not path.exists(model_file):
                raise ValueError(
                    f"{label} model file not found,please check readme.md for GGUF model usage"
                )

        self.model_config = config

        # Install the log callback before the context is created so that
        # messages emitted while the model loads are forwarded as well.
        # NOTE(review): assumes sd_set_log_callback registers a library-wide
        # callback (it takes no context argument) - confirm in stablediffusion.h.
        if logging_enabled:
            self._set_logcallback()

        self.libsdcpp.new_sd_ctx.argtypes = [
            c_char_p,  # const char* model_path
            c_char_p,  # const char* clip_l_path
            c_char_p,  # const char* t5xxl_path
            c_char_p,  # const char* diffusion_model_path
            c_char_p,  # const char* vae_path
            c_char_p,  # const char* taesd_path
            c_char_p,  # const char* control_net_path_c_str
            c_char_p,  # const char* lora_model_dir
            c_char_p,  # const char* embed_dir_c_str
            c_char_p,  # const char* stacked_id_embed_dir_c_str
            c_bool,  # bool vae_decode_only
            c_bool,  # bool vae_tiling
            c_bool,  # bool free_params_immediately
            c_int,  # int n_threads
            SdType,  # enum sd_type_t wtype
            RngType,  # enum rng_type_t rng_type
            Schedule,  # enum schedule_t s
            c_bool,  # bool keep_clip_on_cpu
            c_bool,  # bool keep_control_net_cpu
            c_bool,  # bool keep_vae_on_cpu
        ]

        self.libsdcpp.new_sd_ctx.restype = POINTER(c_void_p)

        cfg = self.model_config
        self.sd_ctx = self.libsdcpp.new_sd_ctx(
            self._str_to_bytes(cfg.model_path),
            self._str_to_bytes(cfg.clip_l_path),
            self._str_to_bytes(cfg.t5xxl_path),
            self._str_to_bytes(cfg.diffusion_model_path),
            self._str_to_bytes(cfg.vae_path),
            self._str_to_bytes(cfg.taesd_path),
            self._str_to_bytes(cfg.control_net_path),
            self._str_to_bytes(cfg.lora_model_dir),
            self._str_to_bytes(cfg.embed_dir),
            self._str_to_bytes(cfg.stacked_id_embed_dir),
            cfg.vae_decode_only,
            cfg.vae_tiling,
            cfg.free_params_immediately,
            cfg.n_threads,
            cfg.wtype,
            cfg.rng_type,
            cfg.schedule,
            cfg.keep_clip_on_cpu,
            cfg.keep_control_net_cpu,
            cfg.keep_vae_on_cpu,
        )

        # A NULL pointer is falsy in ctypes. Using it later would pass NULL
        # to txt2img, so report the failure here instead.
        if not self.sd_ctx:
            self.sd_ctx = None
            raise ValueError(
                "Failed to create the stable-diffusion.cpp context, "
                "please check the model files"
            )

    def _set_logcallback(self):
        """Register ``log_callback`` as the native library's log sink."""
        print("Setting logging callback")
        # Define function callback
        SdLogCallbackType = ctypes.CFUNCTYPE(
            None,
            SDCPPLogLevel,
            ctypes.c_char_p,
            ctypes.c_void_p,
        )

        self.libsdcpp.sd_set_log_callback.argtypes = [
            SdLogCallbackType,
            ctypes.c_void_p,
        ]
        self.libsdcpp.sd_set_log_callback.restype = None
        # Convert the Python callback to a C func pointer.
        # Keep it as a member variable: if the wrapper object were garbage
        # collected the native side would be left with a dangling pointer.
        self.c_log_callback = SdLogCallbackType(self.log_callback)
        self.libsdcpp.sd_set_log_callback(self.c_log_callback, None)

    def _get_sdcpp_shared_lib_path(
        self,
        root_path: str,
    ) -> str:
        """Return the path of the shared library for the current platform.

        Args:
            root_path: Directory that contains the shared library.

        Raises:
            ValueError: If the platform is not Windows, Linux or macOS.
        """
        system_name = platform.system()
        print(f"GGUF Diffusion on {system_name}")

        lib_names = {
            "Windows": "stable-diffusion.dll",
            "Linux": "libstable-diffusion.so",
            "Darwin": "libstable-diffusion.dylib",
        }
        lib_name = lib_names.get(system_name)
        if lib_name is None:
            # Previously an empty path was returned here, which only failed
            # later (and obscurely) when ctypes tried to load it.
            print("Unknown platform.")
            raise ValueError(f"Unsupported platform: {system_name}")

        return path.join(root_path, lib_name)

    @staticmethod
    def log_callback(
        level,
        text,
        data,
    ):
        """Print one log message received from the native library.

        Args:
            level: Log level passed by the library (unused).
            text: Message bytes, or None when the library passes NULL.
            data: User-data pointer given at registration (unused).
        """
        # Exceptions cannot propagate out of a ctypes callback, so never let
        # a NULL message or a non-UTF-8 byte sequence raise here.
        if text:
            print(f"{text.decode('utf-8', errors='replace')}", end="")

    def _str_to_bytes(self, in_str: str, encoding: str = "utf-8") -> bytes:
        """Encode ``in_str`` for a ``const char*`` argument.

        Empty or None input yields ``b""`` rather than NULL.
        """
        return in_str.encode(encoding) if in_str else b""

    def generate_text2mg(self, txt2img_cfg: Txt2ImgConfig) -> List[Any]:
        """Generate images from a text prompt.

        Args:
            txt2img_cfg: Prompt and sampling parameters; forwarded
                positionally to the native ``txt2img`` function.

        Returns:
            A list of PIL images. The list is empty when the native call
            returns NULL.

        Raises:
            RuntimeError: If the context has already been released with
                ``terminate``.
        """
        if not self.libsdcpp or not self.sd_ctx:
            raise RuntimeError(
                "GGUF diffusion context is not available, "
                "it may have been terminated"
            )

        self.libsdcpp.txt2img.restype = POINTER(SDImage)
        self.libsdcpp.txt2img.argtypes = [
            c_void_p,  # sd_ctx_t* sd_ctx (pointer to context object)
            c_char_p,  # const char* prompt
            c_char_p,  # const char* negative_prompt
            c_int,  # int clip_skip
            c_float,  # float cfg_scale
            c_float,  # float guidance
            c_int,  # int width
            c_int,  # int height
            SampleMethod,  # enum sample_method_t sample_method
            c_int,  # int sample_steps
            c_int64,  # int64_t seed
            c_int,  # int batch_count
            POINTER(SDImage),  # const sd_image_t* control_cond (pointer to SDImage)
            c_float,  # float control_strength
            c_float,  # float style_strength
            c_bool,  # bool normalize_input
            c_char_p,  # const char* input_id_images_path
        ]

        image_buffer = self.libsdcpp.txt2img(
            self.sd_ctx,
            self._str_to_bytes(txt2img_cfg.prompt),
            self._str_to_bytes(txt2img_cfg.negative_prompt),
            txt2img_cfg.clip_skip,
            txt2img_cfg.cfg_scale,
            txt2img_cfg.guidance,
            txt2img_cfg.width,
            txt2img_cfg.height,
            txt2img_cfg.sample_method,
            txt2img_cfg.sample_steps,
            txt2img_cfg.seed,
            txt2img_cfg.batch_count,
            txt2img_cfg.control_cond,
            txt2img_cfg.control_strength,
            txt2img_cfg.style_strength,
            txt2img_cfg.normalize_input,
            txt2img_cfg.input_id_images_path,
        )

        return self._get_sd_images_from_buffer(
            image_buffer,
            txt2img_cfg.batch_count,
        )

    def _get_sd_images_from_buffer(
        self,
        image_buffer: Any,
        batch_count: int,
    ) -> List[Any]:
        """Convert the native image array into PIL images.

        Args:
            image_buffer: ``POINTER(SDImage)`` returned by ``txt2img``; a NULL
                pointer produces an empty list.
            batch_count: Number of ``SDImage`` entries to read.

        Returns:
            One PIL image per entry that carries pixel data.

        Raises:
            ValueError: If an entry has a channel count other than 1, 3 or 4.
        """
        images = []
        if not image_buffer:
            return images

        pil_modes = {1: "L", 3: "RGB", 4: "RGBA"}

        # NOTE(review): the native buffers are not released by this wrapper;
        # presumably the caller owns them - confirm in stablediffusion.h.
        for i in range(batch_count):
            image = image_buffer[i]
            width = image.width
            height = image.height
            channels = image.channel
            print(
                f"Generated image: {width}x{height} with {channels} channels"
            )

            # A NULL data pointer cannot be wrapped by numpy; skip the entry
            # rather than dereferencing it.
            if not image.data:
                print(f"Image {i} has no pixel data, skipping")
                continue

            mode = pil_modes.get(channels)
            if mode is None:
                raise ValueError(f"Unsupported number of channels: {channels}")

            pixel_data = np.ctypeslib.as_array(
                image.data, shape=(height, width, channels)
            )
            if channels == 1:
                # Drop only the channel axis. squeeze() would also drop a
                # height or width of 1 and break the 2-D image shape.
                pixel_data = pixel_data.reshape(height, width)

            images.append(Image.fromarray(pixel_data, mode=mode))
        return images

    def terminate(self):
        """Free the native context and drop the library handle.

        Safe to call more than once, and safe on a partially constructed
        instance.
        """
        lib = getattr(self, "libsdcpp", None)
        ctx = getattr(self, "sd_ctx", None)
        if lib and ctx:
            lib.free_sd_ctx.argtypes = [c_void_p]
            lib.free_sd_ctx.restype = None
            lib.free_sd_ctx(ctx)
        self.sd_ctx = None
        self.libsdcpp = None
src/backend/gguf/sdcpp_types.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ctypes for stablediffusion.cpp shared library
3
+ This is as per the stablediffusion.h file
4
+ """
5
+
6
+ from enum import IntEnum
7
+ from ctypes import (
8
+ c_int,
9
+ c_uint32,
10
+ c_uint8,
11
+ POINTER,
12
+ Structure,
13
+ )
14
+
15
+
16
class CtypesEnum(IntEnum):
    """IntEnum base class whose subclasses can be listed in ctypes ``argtypes``."""

    @classmethod
    def from_param(cls, obj):
        """Return the plain integer value of ``obj``.

        ctypes calls this hook on every entry of ``argtypes`` to convert the
        Python argument before the foreign function is invoked.
        """
        plain_value = int(obj)
        return plain_value
22
+
23
+
24
class RngType(CtypesEnum):
    """Random number generator selector passed to the native library.

    Used for the ``rng_type`` argument of ``new_sd_ctx`` (``enum rng_type_t``
    on the C side); the integer values must match the C header.
    """

    STD_DEFAULT_RNG = 0
    CUDA_RNG = 1
27
+
28
+
29
class SampleMethod(CtypesEnum):
    """Sampler selector passed to the native ``txt2img`` call.

    Mirrors ``enum sample_method_t``; the C enum has no explicit values, so
    the members are numbered sequentially from 0 in declaration order.
    Every member therefore needs its own value: a repeated value would turn
    the later name into an alias and shift every sampler after it by one.
    """

    EULER_A = 0
    EULER = 1
    HEUN = 2
    DPM2 = 3
    DPMPP2S_A = 4
    DPMPP2M = 5
    DPMPP2Mv2 = 6
    IPNDM = 7
    IPNDM_V = 8
    LCM = 9
    N_SAMPLE_METHODS = 10  # count sentinel, not a real sampler
41
+
42
+
43
class Schedule(CtypesEnum):
    """Sigma schedule selector passed to the native ``new_sd_ctx`` call.

    Mirrors ``enum schedule_t``; the C enum has no explicit values, so the
    members are numbered sequentially from 0 in declaration order.
    """

    DEFAULT = 0
    DISCRETE = 1
    KARRAS = 2
    EXPONENTIAL = 3
    AYS = 4
    GITS = 5
    N_SCHEDULES = 6  # count sentinel; must not share a value with GITS
51
+
52
+
53
class SdType(CtypesEnum):
    """Weight data type selector passed to the native library.

    Used for the ``wtype`` argument of ``new_sd_ctx`` (``enum sd_type_t`` on
    the C side). The values are assigned explicitly and are not contiguous:
    4 and 5 are intentionally unused.
    """

    SD_TYPE_F32 = 0
    SD_TYPE_F16 = 1
    SD_TYPE_Q4_0 = 2
    SD_TYPE_Q4_1 = 3
    # SD_TYPE_Q4_2 = 4, support has been removed
    # SD_TYPE_Q4_3 = 5, support has been removed
    SD_TYPE_Q5_0 = 6
    SD_TYPE_Q5_1 = 7
    SD_TYPE_Q8_0 = 8
    SD_TYPE_Q8_1 = 9
    SD_TYPE_Q2_K = 10
    SD_TYPE_Q3_K = 11
    SD_TYPE_Q4_K = 12
    SD_TYPE_Q5_K = 13
    SD_TYPE_Q6_K = 14
    SD_TYPE_Q8_K = 15
    SD_TYPE_IQ2_XXS = 16
    SD_TYPE_IQ2_XS = 17
    SD_TYPE_IQ3_XXS = 18
    SD_TYPE_IQ1_S = 19
    SD_TYPE_IQ4_NL = 20
    SD_TYPE_IQ3_S = 21
    SD_TYPE_IQ2_S = 22
    SD_TYPE_IQ4_XS = 23
    SD_TYPE_I8 = 24
    SD_TYPE_I16 = 25
    SD_TYPE_I32 = 26
    SD_TYPE_I64 = 27
    SD_TYPE_F64 = 28
    SD_TYPE_IQ1_M = 29
    SD_TYPE_BF16 = 30
    SD_TYPE_Q4_0_4_4 = 31
    SD_TYPE_Q4_0_4_8 = 32
    SD_TYPE_Q4_0_8_8 = 33
    SD_TYPE_COUNT = 34
89
+
90
+
91
class SDImage(Structure):
    """ctypes layout of the native image struct (``sd_image_t``).

    The native ``txt2img`` call returns a pointer to an array of these.
    The field order defines the memory layout, so it must not be changed.
    """

    _fields_ = [
        ("width", c_uint32),
        ("height", c_uint32),
        ("channel", c_uint32),
        # Pixel bytes; the wrapper reads them as a
        # (height, width, channel) array of uint8.
        ("data", POINTER(c_uint8)),
    ]
98
+
99
+
100
class SDCPPLogLevel(c_int):
    """C ``int`` type used for the log level argument of the log callback.

    This subclasses ``c_int`` rather than an enum, so the names below are
    plain integer class attributes, not enum members.
    """

    SD_LOG_LEVEL_DEBUG = 0
    SD_LOG_LEVEL_INFO = 1
    SD_LOG_LEVEL_WARNING = 2
    SD_LOG_LEVEL_ERROR = 3