File size: 2,729 Bytes
90f5392
 
892f484
7c08c7b
 
90f5392
7c08c7b
90f5392
 
892f484
 
 
90f5392
79ce9cc
892f484
 
 
 
 
 
 
 
7c08c7b
 
892f484
7c08c7b
 
 
 
 
 
 
 
 
892f484
7c08c7b
79ce9cc
 
7c08c7b
79ce9cc
 
7c08c7b
 
 
 
 
 
 
892f484
7c08c7b
892f484
 
7c08c7b
 
 
 
 
79ce9cc
 
 
 
7c08c7b
892f484
 
7c08c7b
79ce9cc
7c08c7b
aacc39b
0db0051
892f484
0db0051
892f484
 
 
7c08c7b
 
 
79ce9cc
 
 
 
 
 
 
 
aacc39b
79ce9cc
892f484
90f5392
0db0051
7c08c7b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Process-wide logging: root logger at INFO level.
logging.basicConfig(level=logging.INFO)

# Flask application object that the routes in this module register on.
app = Flask(__name__)
# Send the application's log records through a dedicated named logger
# ("CodeSearchAPI") instead of Flask's default app logger.
app.logger = logging.getLogger("CodeSearchAPI")

# Fixed corpus of code snippets that search queries are matched against.
# Multi-line snippets keep a 4-space indented body after the newline.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    (
        "def count_above_threshold(elements, threshold=0):\n"
        "    return sum(1 for e in elements if e > threshold)"
    ),
    (
        "def find_min_max(elements):\n"
        "    return min(elements), max(elements)"
    ),
]

# Module-level readiness flag; starts False and is flipped once initialization succeeds.
service_ready: bool = False

# Shutdown handling for local runs of this module.
def handle_shutdown(signum, frame):
    """Log that a termination signal arrived, then exit the process with status 0.

    Args:
        signum: Number of the received signal (unused).
        frame: Stack frame passed by the ``signal`` module (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# Install the handlers only when this file is executed directly.
# - When imported by a WSGI server (gunicorn, per the deployment note at the
#   bottom of this file), gunicorn workers have already installed their own
#   SIGTERM/SIGINT handlers before loading the app. Overriding them here would
#   raise SystemExit immediately instead of letting gunicorn finish in-flight
#   requests and shut the worker down itself.
# - signal.signal() may only be called from the main thread, so registering at
#   import time also breaks any loader that imports the app from another thread.
if __name__ == "__main__":
    signal.signal(signal.SIGTERM, handle_shutdown)
    signal.signal(signal.SIGINT, handle_shutdown)

# Load the embedding model and pre-encode the snippet corpus once, at import time.
# On any failure the error is logged and re-raised, so the module does not
# finish importing in a half-initialized state.
try:
    # Weights are cached under /model-cache (the cache path used on Hugging Face Spaces).
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )

    # Corpus embeddings are computed once, explicitly on the CPU, and reused per request.
    code_emb = model.encode(
        CODE_SNIPPETS,
        device="cpu",
        convert_to_tensor=True,
    )

    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as exc:
    app.logger.error(f"初始化失败: {str(exc)}")
    raise

# Health-check endpoint on the root path (probed by Hugging Face Spaces).
@app.route('/')
def hf_health_check():
    """Report service readiness on the root path.

    Returns:
        ``({"status": "ready"}, 200)`` when ``service_ready`` is true,
        otherwise ``({"status": "initializing"}, 503)``.
    """
    status, http_code = ("ready", 200) if service_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code

# Search endpoint: returns the predefined snippet most similar to the query.
@app.route('/search', methods=['POST'])
def handle_search():
    """Return the code snippet in ``CODE_SNIPPETS`` that best matches a query.

    Expects a JSON object body such as ``{"query": "..."}``.

    Returns:
        200 with ``{"code": <snippet>, "score": <similarity rounded to 4 places>}``;
        400 when the body is not a JSON object, ``query`` is not a string,
        or the stripped query is empty;
        503 while ``service_ready`` is false;
        500 when encoding or searching fails unexpectedly.
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # silent=True makes malformed JSON or a non-JSON Content-Type yield None
    # instead of raising, so such client errors become a 400 below rather
    # than being swallowed by the generic handler and reported as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "请求体必须是JSON对象"}), 400

    query = data.get('query', '')
    # A JSON null/number/list here previously crashed on .strip() -> 500.
    if not isinstance(query, str):
        return jsonify({"error": "query必须是字符串"}), 400
    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    try:
        query_emb = model.encode(query,
                                 convert_to_tensor=True,
                                 device="cpu")
        # semantic_search returns one ranked hit list per query; [0] selects
        # the list for this single query, and top_k=1 keeps only the best hit.
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        })
    except Exception as e:
        # Request boundary: keep the traceback in the log, hide details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

if __name__ == "__main__":
    # Local-testing entry point only. When the module is imported by an
    # external WSGI server (gunicorn on Hugging Face Spaces, per the original
    # deployment note), this branch does not run.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port)