File size: 15,328 Bytes
600759a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
// pybind11 comes first: it includes Python.h, which must precede standard headers.
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace py = pybind11;
using namespace std;

namespace {
// 内部数据结构,避免重复的buffer获取和指针设置
struct MeshData {
    int texture_height, texture_width, texture_channel;
    int vtx_num;
    float* texture_ptr;
    uint8_t* mask_ptr;
    float* vtx_pos_ptr;
    float* vtx_uv_ptr;
    int* pos_idx_ptr;
    int* uv_idx_ptr;
    
    // 存储buffer以防止被销毁
    py::buffer_info texture_buf, mask_buf, vtx_pos_buf, vtx_uv_buf, pos_idx_buf, uv_idx_buf;
    
    MeshData(py::array_t<float>& texture, py::array_t<uint8_t>& mask, 
             py::array_t<float>& vtx_pos, py::array_t<float>& vtx_uv,
             py::array_t<int>& pos_idx, py::array_t<int>& uv_idx) {
        
        texture_buf = texture.request();
        mask_buf = mask.request();
        vtx_pos_buf = vtx_pos.request();
        vtx_uv_buf = vtx_uv.request();
        pos_idx_buf = pos_idx.request();
        uv_idx_buf = uv_idx.request();

        texture_height = texture_buf.shape[0];
        texture_width = texture_buf.shape[1];
        texture_channel = texture_buf.shape[2];
        texture_ptr = static_cast<float*>(texture_buf.ptr);
        mask_ptr = static_cast<uint8_t*>(mask_buf.ptr);

        vtx_num = vtx_pos_buf.shape[0];
        vtx_pos_ptr = static_cast<float*>(vtx_pos_buf.ptr);
        vtx_uv_ptr = static_cast<float*>(vtx_uv_buf.ptr);
        pos_idx_ptr = static_cast<int*>(pos_idx_buf.ptr);
        uv_idx_ptr = static_cast<int*>(uv_idx_buf.ptr);
    }
};

// 公共函数:计算UV坐标
pair<int, int> calculateUVCoordinates(int vtx_uv_idx, const MeshData& data) {
    int uv_v = round(data.vtx_uv_ptr[vtx_uv_idx * 2] * (data.texture_width - 1));
    int uv_u = round((1.0 - data.vtx_uv_ptr[vtx_uv_idx * 2 + 1]) * (data.texture_height - 1));
    return make_pair(uv_u, uv_v);
}

// Interpolation weight between two 3-D points: the inverse squared Euclidean
// distance 1 / d^2.
//
// d is clamped below at 1e-4, so coincident or nearly coincident vertices get a
// large but finite weight (1e8) instead of a division by zero.
float calculateDistanceWeight(const std::array<float, 3>& vtx_0, const std::array<float, 3>& vtx1) {
    // Deltas are float differences widened to double and squared by plain
    // multiplication. A float squared in double is exact, so this matches
    // pow(delta, 2) bit for bit without calling a general pow().
    const double dx = vtx_0[0] - vtx1[0];
    const double dy = vtx_0[1] - vtx1[1];
    const double dz = vtx_0[2] - vtx1[2];

    const double dist = std::max(std::sqrt(dx * dx + dy * dy + dz * dz), 1E-4);
    const float inv_dist = static_cast<float>(1.0 / dist);
    return inv_dist * inv_dist;
}

// Returns the (x, y, z) position of vertex `vtx_idx`, read from the vtx_pos
// buffer, which stores three consecutive floats per vertex.
array<float, 3> getVertexPosition(int vtx_idx, const MeshData& data) {
    const float* xyz = data.vtx_pos_ptr + 3 * vtx_idx;
    array<float, 3> position = {{xyz[0], xyz[1], xyz[2]}};
    return position;
}

// Builds the vertex adjacency list used by the colour-propagation passes and
// stores it in G, replacing any previous contents.
//
// Each triangle (a, b, c) of pos_idx contributes the directed edges a->b, b->c
// and c->a, in that order (one entry per half-edge). Reverse edges come only
// from adjacent triangles, so open-boundary edges appear in one direction.
// Repeated half-edges are not de-duplicated.
//
// Only faces present in both index arrays are used. The other passes walk
// uv_idx and pos_idx in lockstep, and this bound also prevents reading past
// the end of pos_idx when the two arrays disagree in length.
//
// Throws std::invalid_argument if a face references a vertex outside
// [0, vtx_num), which would otherwise index G out of bounds.
void buildGraph(vector<vector<int>>& G, const MeshData& data) {
    G.assign(data.vtx_num, vector<int>());

    const py::ssize_t face_count = std::min(data.pos_idx_buf.shape[0], data.uv_idx_buf.shape[0]);
    for (py::ssize_t f = 0; f < face_count; ++f) {
        const int* face = data.pos_idx_ptr + 3 * f;

        for (int k = 0; k < 3; ++k) {
            if (face[k] < 0 || face[k] >= data.vtx_num) {
                throw std::invalid_argument("pos_idx references vertex " + std::to_string(face[k]) +
                                            " but vtx_pos has only " + std::to_string(data.vtx_num) +
                                            " rows");
            }
        }

        G[face[0]].push_back(face[1]);
        G[face[1]].push_back(face[2]);
        G[face[2]].push_back(face[0]);
    }
}

// 通用初始化函数:处理两种掩码类型(float和int)
template<typename MaskType>
void initializeVertexDataGeneric(const MeshData& data, vector<MaskType>& vtx_mask, 
                                vector<vector<float>>& vtx_color, vector<int>* uncolored_vtxs = nullptr,
                                MaskType mask_value = static_cast<MaskType>(1)) {
    vtx_mask.assign(data.vtx_num, static_cast<MaskType>(0));
    vtx_color.assign(data.vtx_num, vector<float>(data.texture_channel, 0.0f));
    
    if(uncolored_vtxs) {
        uncolored_vtxs->clear();
    }

    for(int i = 0; i < data.uv_idx_buf.shape[0]; ++i) {
        for(int k = 0; k < 3; ++k) {
            int vtx_uv_idx = data.uv_idx_ptr[i * 3 + k];
            int vtx_idx = data.pos_idx_ptr[i * 3 + k];
            auto uv_coords = calculateUVCoordinates(vtx_uv_idx, data);

            if(data.mask_ptr[uv_coords.first * data.texture_width + uv_coords.second] > 0) {
                vtx_mask[vtx_idx] = mask_value;
                for(int c = 0; c < data.texture_channel; ++c) {
                    vtx_color[vtx_idx][c] = data.texture_ptr[(uv_coords.first * data.texture_width + 
                                                            uv_coords.second) * data.texture_channel + c];
                }
            } else if(uncolored_vtxs) {
                uncolored_vtxs->push_back(vtx_idx);
            }
        }
    }
}

// 通用平滑算法:支持不同的掩码类型和检查函数
template<typename MaskType>
void performSmoothingAlgorithm(const MeshData& data, const vector<vector<int>>& G,
                              vector<MaskType>& vtx_mask, vector<vector<float>>& vtx_color, 
                              const vector<int>& uncolored_vtxs,
                              function<bool(MaskType)> is_colored_func,
                              function<void(MaskType&)> set_colored_func) {
    int smooth_count = 2;
    int last_uncolored_vtx_count = 0;
    
    while(smooth_count > 0) {
        int uncolored_vtx_count = 0;

        for(int vtx_idx : uncolored_vtxs) {
            vector<float> sum_color(data.texture_channel, 0.0f);
            float total_weight = 0.0f;

            array<float, 3> vtx_0 = getVertexPosition(vtx_idx, data);
            
            for(int connected_idx : G[vtx_idx]) {
                if(is_colored_func(vtx_mask[connected_idx])) {
                    array<float, 3> vtx1 = getVertexPosition(connected_idx, data);
                    float dist_weight = calculateDistanceWeight(vtx_0, vtx1);
                    
                    for(int c = 0; c < data.texture_channel; ++c) {
                        sum_color[c] += vtx_color[connected_idx][c] * dist_weight;
                    }
                    total_weight += dist_weight;
                }
            }

            if(total_weight > 0.0f) {
                for(int c = 0; c < data.texture_channel; ++c) {
                    vtx_color[vtx_idx][c] = sum_color[c] / total_weight;
                }
                set_colored_func(vtx_mask[vtx_idx]);
            } else {
                uncolored_vtx_count++;
            }
        }

        if(last_uncolored_vtx_count == uncolored_vtx_count) {
            smooth_count--;
        } else {
            smooth_count++;
        }
        last_uncolored_vtx_count = uncolored_vtx_count;
    }
}

// 前向传播算法的通用实现
void performForwardPropagation(const MeshData& data, const vector<vector<int>>& G,
                              vector<float>& vtx_mask, vector<vector<float>>& vtx_color,
                              queue<int>& active_vtxs) {
    while(!active_vtxs.empty()) {
        queue<int> pending_active_vtxs;
        
        while(!active_vtxs.empty()) {
            int vtx_idx = active_vtxs.front();
            active_vtxs.pop();
            array<float, 3> vtx_0 = getVertexPosition(vtx_idx, data);
            
            for(int connected_idx : G[vtx_idx]) {
                if(vtx_mask[connected_idx] > 0) continue;
                
                array<float, 3> vtx1 = getVertexPosition(connected_idx, data);
                float dist_weight = calculateDistanceWeight(vtx_0, vtx1);
                
                for(int c = 0; c < data.texture_channel; ++c) {
                    vtx_color[connected_idx][c] += vtx_color[vtx_idx][c] * dist_weight;
                }
                
                if(vtx_mask[connected_idx] == 0) {
                    pending_active_vtxs.push(connected_idx);
                }
                vtx_mask[connected_idx] -= dist_weight;
            }
        }

        while(!pending_active_vtxs.empty()) {
            int vtx_idx = pending_active_vtxs.front();
            pending_active_vtxs.pop();
            
            for(int c = 0; c < data.texture_channel; ++c) {
                vtx_color[vtx_idx][c] /= -vtx_mask[vtx_idx];
            }
            vtx_mask[vtx_idx] = 1.0f;
            active_vtxs.push(vtx_idx);
        }
    }
}

// Builds the (texture, mask) pair returned by the inpainting entry points.
//
// Both outputs start as copies of the inputs. Then, for every face corner whose
// position vertex is coloured (vtx_mask == 1.0f), the vertex colour is written
// to the corner's UV texel and that texel's mask byte is set to 255.
//
// Returns a float array of shape (H, W, C), where C is the input texture's
// channel count, and a uint8 array of shape (H, W).
pair<py::array_t<float>, py::array_t<uint8_t>> createOutputArrays(
    const MeshData& data, const vector<float>& vtx_mask, 
    const vector<vector<float>>& vtx_color) {
    const int height = data.texture_height;
    const int width = data.texture_width;
    const int channels = data.texture_channel;

    // Allocate directly in the final shapes, using the real channel count.
    // Reshaping a flat buffer to a fixed (H, W, 3) would truncate or scramble
    // RGBA and single-channel textures.
    py::array_t<float> new_texture({height, width, channels});
    py::array_t<uint8_t> new_mask({height, width});
    float* new_texture_ptr = new_texture.mutable_data();
    uint8_t* new_mask_ptr = new_mask.mutable_data();

    // Start from the original texture...
    const py::ssize_t texture_values = static_cast<py::ssize_t>(height) * width * channels;
    std::copy(data.texture_ptr, data.texture_ptr + texture_values, new_texture_ptr);

    // ...and the original mask, zero-padding defensively if it holds fewer than H * W values.
    const py::ssize_t pixel_count = static_cast<py::ssize_t>(height) * width;
    const py::ssize_t mask_values = std::min(data.mask_buf.size, pixel_count);
    std::copy(data.mask_ptr, data.mask_ptr + mask_values, new_mask_ptr);
    std::fill(new_mask_ptr + mask_values, new_mask_ptr + pixel_count, static_cast<uint8_t>(0));

    const py::ssize_t face_count = data.uv_idx_buf.shape[0];
    for (py::ssize_t face_idx = 0; face_idx < face_count; ++face_idx) {
        for (int k = 0; k < 3; ++k) {
            const py::ssize_t corner = face_idx * 3 + k;
            const int vtx_idx = data.pos_idx_ptr[corner];
            if (vtx_mask[vtx_idx] != 1.0f) {
                continue;  // this vertex never received a colour
            }

            const std::pair<int, int> texel_rc = calculateUVCoordinates(data.uv_idx_ptr[corner], data);
            const int texel = texel_rc.first * width + texel_rc.second;
            std::copy_n(vtx_color[vtx_idx].begin(), channels, new_texture_ptr + texel * channels);
            new_mask_ptr[texel] = 255;
        }
    }

    return std::make_pair(std::move(new_texture), std::move(new_mask));
}

// Packs per-vertex results into NumPy arrays for Python:
//   - a float array of shape (vtx_num, texture_channel) holding each vertex colour;
//   - a uint8 array of length vtx_num holding each vertex's mask value,
//     narrowed from int.
pair<py::array_t<float>, py::array_t<uint8_t>> createVertexColorOutput(
    const MeshData& data, const vector<int>& vtx_mask, 
    const vector<vector<float>>& vtx_color) {
    const int vertex_count = data.vtx_num;
    const int channels = data.texture_channel;

    py::array_t<float> colors_out({vertex_count, channels});
    py::array_t<uint8_t> flags_out({vertex_count});

    float* colors = colors_out.mutable_data();
    uint8_t* flags = flags_out.mutable_data();

    for (int v = 0; v < vertex_count; ++v) {
        flags[v] = static_cast<uint8_t>(vtx_mask[v]);
        copy_n(vtx_color[v].begin(), channels, colors + v * channels);
    }

    return make_pair(colors_out, flags_out);
}

} // anonymous namespace

// "smooth" inpainting.
//
// Vertex colours are seeded from texels set in `mask`. The remaining listed
// vertices are filled by iterative neighbour averaging (coloured = mask > 0,
// marked as 1.0f). The result is splatted back into a copy of the texture and
// mask by createOutputArrays.
pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeInpaint_smooth(
    py::array_t<float> texture, py::array_t<uint8_t> mask, py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,
    py::array_t<int> pos_idx, py::array_t<int> uv_idx) {
    MeshData data(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);

    vector<float> vtx_mask;
    vector<vector<float>> vtx_color;
    vector<int> uncolored_vtxs;
    initializeVertexDataGeneric(data, vtx_mask, vtx_color, &uncolored_vtxs, 1.0f);

    vector<vector<int>> G;
    buildGraph(G, data);

    auto is_colored = [](float mask_val) { return mask_val > 0; };
    auto mark_colored = [](float& mask_val) { mask_val = 1.0f; };
    performSmoothingAlgorithm<float>(data, G, vtx_mask, vtx_color, uncolored_vtxs,
                                     is_colored, mark_colored);

    return createOutputArrays(data, vtx_mask, vtx_color);
}

// "forward" inpainting.
//
// Vertex colours are seeded from texels set in `mask`. Every seeded vertex
// (mask 1.0f) forms the first wave, visited in index order. Colours are then
// pushed outward in waves by performForwardPropagation and splatted back into
// a copy of the texture and mask.
pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeInpaint_forward(
    py::array_t<float> texture, py::array_t<uint8_t> mask, py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,
    py::array_t<int> pos_idx, py::array_t<int> uv_idx) {
    MeshData data(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);

    // The forward pass does not need the list of uncoloured vertices.
    vector<float> vtx_mask;
    vector<vector<float>> vtx_color;
    initializeVertexDataGeneric(data, vtx_mask, vtx_color, nullptr, 1.0f);

    vector<vector<int>> G;
    buildGraph(G, data);

    queue<int> first_wave;
    for (int vtx = 0; vtx < data.vtx_num; ++vtx) {
        if (vtx_mask[vtx] == 1.0f) {
            first_wave.push(vtx);
        }
    }

    performForwardPropagation(data, G, vtx_mask, vtx_color, first_wave);

    return createOutputArrays(data, vtx_mask, vtx_color);
}

// Public entry point for texture inpainting.
//
// Dispatches to the "smooth" (default) or "forward" strategy. Any other
// `method` throws std::invalid_argument.
pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeInpaint(
    py::array_t<float> texture, py::array_t<uint8_t> mask, py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,
    py::array_t<int> pos_idx, py::array_t<int> uv_idx, const string& method = "smooth") {
    if (method == "forward") {
        return meshVerticeInpaint_forward(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);
    }
    if (method == "smooth") {
        return meshVerticeInpaint_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);
    }
    throw invalid_argument("Invalid method. Use 'smooth' or 'forward'.");
}

//============================

// Per-vertex colour extraction with the "smooth" strategy.
//
// Vertices whose UV texel is set in `mask` take that texel's colour and are
// flagged 1. The listed uncoloured vertices are then filled by neighbour
// averaging (coloured = flag > 0), and every vertex recomputed this way is
// flagged 2. Returns (vtx_color, vtx_mask) built by createVertexColorOutput.
pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeColor_smooth(
    py::array_t<float> texture, py::array_t<uint8_t> mask, py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,
    py::array_t<int> pos_idx, py::array_t<int> uv_idx) {
    MeshData data(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);

    vector<int> vtx_mask;
    vector<vector<float>> vtx_color;
    vector<int> uncolored_vtxs;
    initializeVertexDataGeneric(data, vtx_mask, vtx_color, &uncolored_vtxs, 1);

    vector<vector<int>> G;
    buildGraph(G, data);

    auto is_colored = [](int flag) { return flag > 0; };
    auto mark_smoothed = [](int& flag) { flag = 2; };
    performSmoothingAlgorithm<int>(data, G, vtx_mask, vtx_color, uncolored_vtxs,
                                   is_colored, mark_smoothed);

    return createVertexColorOutput(data, vtx_mask, vtx_color);
}

// Public entry point for per-vertex colour extraction.
//
// Only the "smooth" method exists. Any other `method` throws
// std::invalid_argument, and the error message lists only the method this
// function actually accepts.
pair<py::array_t<float>, py::array_t<uint8_t>> meshVerticeColor(
    py::array_t<float> texture, py::array_t<uint8_t> mask, py::array_t<float> vtx_pos, py::array_t<float> vtx_uv,
    py::array_t<int> pos_idx, py::array_t<int> uv_idx, const string& method = "smooth") {
    if (method != "smooth") {
        // Unlike meshVerticeInpaint, there is no 'forward' variant here.
        throw invalid_argument("Invalid method. Use 'smooth'.");
    }
    return meshVerticeColor_smooth(texture, mask, vtx_pos, vtx_uv, pos_idx, uv_idx);
}

// Python bindings. std::invalid_argument thrown by the bound functions is
// translated by pybind11 into a Python ValueError.
PYBIND11_MODULE(mesh_inpaint_processor, m) {
    m.def("meshVerticeInpaint", &meshVerticeInpaint,
          "Inpaint texels not covered by `mask` by propagating colours over the mesh vertex graph.\n\n"
          "Vertex colours are sampled from `texture` where `mask` is set, spread to the other\n"
          "vertices using `method` ('smooth', the default, or 'forward'), and written back at each\n"
          "face corner's UV texel. Returns (texture, mask): a float32 texture and a uint8 mask with\n"
          "the input height and width, where every texel written from a coloured vertex is 255.",
          py::arg("texture"), py::arg("mask"), py::arg("vtx_pos"), py::arg("vtx_uv"),
          py::arg("pos_idx"), py::arg("uv_idx"), py::arg("method") = "smooth");
    m.def("meshVerticeColor", &meshVerticeColor,
          "Compute per-vertex colours from `texture`.\n\n"
          "Each vertex is sampled at its UV texel where `mask` is set; colours are then smoothed\n"
          "into the remaining vertices ('smooth' is the only supported method). Returns\n"
          "(vtx_color, vtx_mask): float32 colours of shape (N, C) and uint8 flags of shape (N,),\n"
          "0 = uncoloured, 1 = sampled from the texture, 2 = colour computed by smoothing.",
          py::arg("texture"), py::arg("mask"), py::arg("vtx_pos"), py::arg("vtx_uv"),
          py::arg("pos_idx"), py::arg("uv_idx"), py::arg("method") = "smooth");
}