Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,13 +7,13 @@ import signal
|
|
7 |
# 初始化Flask应用
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
-
#
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
app.logger = logging.getLogger("CodeSearchAPI")
|
13 |
|
14 |
# 预定义代码片段
|
15 |
CODE_SNIPPETS = [
|
16 |
-
"
|
17 |
"""def count_above_threshold(elements, threshold=0):
|
18 |
return sum(1 for e in elements if e > threshold)""",
|
19 |
"""def find_min_max(elements):
|
@@ -41,8 +41,8 @@ try:
|
|
41 |
|
42 |
# 预计算编码(强制使用CPU)
|
43 |
code_emb = model.encode(CODE_SNIPPETS,
|
44 |
-
|
45 |
-
|
46 |
|
47 |
service_ready = True
|
48 |
app.logger.info("服务初始化完成")
|
@@ -59,34 +59,47 @@ def hf_health_check():
|
|
59 |
else:
|
60 |
return jsonify({"status": "initializing"}), 503
|
61 |
|
62 |
-
#
|
63 |
-
@app.route('/search', methods=['POST'])
|
64 |
def handle_search():
|
65 |
if not service_ready:
|
66 |
return jsonify({"error": "服务正在初始化"}), 503
|
67 |
-
|
68 |
try:
|
69 |
-
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
if not query:
|
|
|
73 |
return jsonify({"error": "查询不能为空"}), 400
|
74 |
-
|
|
|
|
|
|
|
|
|
75 |
query_emb = model.encode(query,
|
76 |
-
|
77 |
-
|
78 |
hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
|
79 |
best = hits[0]
|
80 |
-
|
81 |
-
|
82 |
"code": CODE_SNIPPETS[best['corpus_id']],
|
83 |
"score": round(float(best['score']), 4)
|
84 |
-
}
|
|
|
|
|
|
|
|
|
85 |
|
86 |
except Exception as e:
|
87 |
-
app.logger.error(
|
88 |
return jsonify({"error": "服务器内部错误"}), 500
|
89 |
|
90 |
if __name__ == "__main__":
|
91 |
# Hugging Face Spaces会通过gunicorn启动,此处仅为本地测试保留
|
92 |
-
app.run(host='0.0.0.0', port=7860)
|
|
|
7 |
# 初始化Flask应用
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
+
# 配置日志,日志级别设为INFO
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
app.logger = logging.getLogger("CodeSearchAPI")
|
13 |
|
14 |
# 预定义代码片段
|
15 |
CODE_SNIPPETS = [
|
16 |
+
"def sort_list(x): return sorted(x)",
|
17 |
"""def count_above_threshold(elements, threshold=0):
|
18 |
return sum(1 for e in elements if e > threshold)""",
|
19 |
"""def find_min_max(elements):
|
|
|
41 |
|
42 |
# 预计算编码(强制使用CPU)
|
43 |
code_emb = model.encode(CODE_SNIPPETS,
|
44 |
+
convert_to_tensor=True,
|
45 |
+
device="cpu")
|
46 |
|
47 |
service_ready = True
|
48 |
app.logger.info("服务初始化完成")
|
|
|
59 |
else:
|
60 |
return jsonify({"status": "initializing"}), 503
|
61 |
|
62 |
+
# Search API endpoint; accepts both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the predefined code snippet that best matches the query.

    The query is read from the ``query`` URL parameter on GET requests and
    from the ``query`` key of the JSON body on POST requests.

    Relies on module-level names defined outside this function:
    ``service_ready``, ``model``, ``util``, ``code_emb`` and
    ``CODE_SNIPPETS``.

    Returns:
        A Flask JSON response, optionally paired with a status code:
          * 503 with an ``error`` message while ``service_ready`` is falsy.
          * 400 with an ``error`` message when the query is missing, blank,
            not a string, or the POST body is not a JSON object.
          * 500 with an ``error`` message when encoding or searching fails.
          * otherwise (default status) an object with ``code`` (the matched
            snippet) and ``score`` (match score rounded to 4 decimals).
    """
    if not service_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Input parsing happens outside the try block below so that bad client
    # input is reported as 400 rather than being swallowed into a 500.
    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True makes get_json() return None instead of raising when
        # the body is missing, malformed, or not declared as JSON.
        payload = request.get_json(silent=True)
        # The body may be valid JSON but not an object (e.g. a list).
        raw_query = payload.get('query', '') if isinstance(payload, dict) else ''

    # Non-string values (numbers, null, nested objects) count as "no query".
    query = raw_query.strip() if isinstance(raw_query, str) else ''

    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the received query.
    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and look up the single best-matching snippet.
        query_emb = model.encode(query,
                                 convert_to_tensor=True,
                                 device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
    except Exception as e:
        # Top-level boundary: log with traceback, hide details from the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

    # Log the result being returned.
    app.logger.info("返回结果: %s", result)

    return jsonify(result)
|
102 |
|
103 |
if __name__ == "__main__":
|
104 |
# Hugging Face Spaces会通过gunicorn启动,此处仅为本地测试保留
|
105 |
+
app.run(host='0.0.0.0', port=7860)
|