File size: 2,902 Bytes
90f5392
 
892f484
90f5392
 
 
892f484
 
 
90f5392
892f484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aacc39b
0db0051
 
892f484
 
0db0051
 
892f484
0db0051
892f484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0db0051
892f484
 
 
 
 
 
 
aacc39b
892f484
 
90f5392
0db0051
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging

# Configure root logging at INFO level; done before the app object exists
# since nothing in Flask construction needs it and it reads top-down.
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)
# Swap Flask's default app logger for a named one. It has no handlers of its
# own, so its records propagate to the root handler installed above.
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined corpus of code snippets (pre-validated data) that incoming
# queries are matched against. Multi-line snippets are spelled out with
# explicit "\n" plus a 4-space body indent.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    (
        "def count_above_threshold(elements, threshold=0):\n"
        "    return sum(1 for e in elements if e > threshold)"
    ),
    (
        "def find_min_max(elements):\n"
        "    return min(elements), max(elements)"
    ),
]

# Input data validation
def validate_snippets(snippets):
    """Normalise an iterable of code snippets into a list of non-empty strings.

    Each entry that is not a ``str`` is converted with ``str()`` and a
    warning naming its index is logged. Every ``"..."`` substring is then
    removed, surrounding whitespace is stripped, and entries that end up
    empty are dropped. Input order is preserved.

    Args:
        snippets: Iterable of snippet values (normally strings).

    Returns:
        list[str]: The cleaned, non-empty snippets.
    """
    kept = []
    for position, item in enumerate(snippets):
        if isinstance(item, str):
            text = item
        else:
            app.logger.warning(f"索引 {position} 类型错误,已转换为字符串")
            text = str(item)
        # Presumably strips placeholder ellipses; note this also removes any
        # legitimate "..." occurring inside the code itself.
        text = text.replace("...", "").strip()
        if text:
            kept.append(text)
    return kept

# Load the embedding model and pre-encode the validated snippet corpus once,
# at import time. Any failure is logged and re-raised, so importing this
# module fails rather than leaving model/code_emb undefined.
try:
    model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
    valid_snippets = validate_snippets(CODE_SNIPPETS)
    code_emb = model.encode(valid_snippets, convert_to_tensor=True)
except Exception as exc:
    app.logger.error(f"初始化失败: {exc}")
    raise
else:
    app.logger.info(f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段")

@app.route('/search', methods=['POST'])
def handle_search():
    """Return the stored code snippet that best matches a query.

    Expects a JSON object body such as ``{"query": "sort a list"}``. The
    query is embedded with the module-level ``model`` and compared against
    the precomputed ``code_emb`` corpus. Only the single best hit is
    returned.

    Returns:
        200 with ``{"code": str, "score": float}`` (score rounded to 4 places);
        400 if the body is not a JSON object, ``query`` is not a string, or
            the query is empty after stripping whitespace;
        404 if the search yields no hit (presumably only when the corpus is
            empty);
        415 if the request Content-Type is not JSON;
        500 if encoding the query fails or any other unexpected error occurs.
    """
    try:
        # Request validation
        if not request.is_json:
            app.logger.warning("无效的 Content-Type")
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True yields None for a malformed body instead of raising
        # BadRequest, which the catch-all below would misreport as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            app.logger.warning("请求体不是有效的 JSON 对象")
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        query = data.get('query', '')
        # A non-string query (number, list, null) would otherwise crash on
        # .strip() and surface as a 500 instead of a client error.
        if not isinstance(query, str):
            app.logger.warning("query 字段类型错误")
            return jsonify({"error": "query 必须是字符串"}), 400
        query = query.strip()

        if not query:
            app.logger.warning("收到空查询")
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query
        try:
            query_emb = model.encode(query, convert_to_tensor=True)
        except Exception as e:
            app.logger.error(f"编码失败: {str(e)}")
            return jsonify({"error": "查询处理失败"}), 500

        # Semantic search: one query, so take its hit list and the top-1 hit.
        # Only the lookup is guarded, so unrelated IndexErrors (e.g. a
        # corpus/snippet mismatch) surface as 500s rather than a bogus 404.
        try:
            hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
            best = hits[0]
        except IndexError:
            app.logger.error("无匹配结果")
            return jsonify({"error": "无可用匹配"}), 404

        result = {
            "code": valid_snippets[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
        app.logger.info(f"成功处理查询: '{query}'")
        return jsonify(result)

    except Exception as e:
        app.logger.error(f"未知错误: {str(e)}", exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500

if __name__ == "__main__":
    # Run Flask's built-in development server on port 8080, listening on all
    # network interfaces (host 0.0.0.0), only when executed as a script.
    app.run(port=8080, host='0.0.0.0')