Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
from flask import Flask, request, jsonify
|
2 |
from sentence_transformers import SentenceTransformer, util
|
3 |
import logging
|
4 |
import sys
|
5 |
import signal
|
6 |
|
7 |
-
# 初始化Flask应用
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
-
#
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
app.logger = logging.getLogger("CodeSearchAPI")
|
13 |
|
@@ -31,42 +31,48 @@ def handle_shutdown(signum, frame):
|
|
31 |
signal.signal(signal.SIGTERM, handle_shutdown)
|
32 |
signal.signal(signal.SIGINT, handle_shutdown)
|
33 |
|
34 |
-
#
|
35 |
try:
|
36 |
-
|
37 |
model = SentenceTransformer(
|
38 |
"flax-sentence-embeddings/st-codesearch-distilroberta-base",
|
39 |
cache_folder="/model-cache"
|
40 |
)
|
41 |
-
|
42 |
-
|
43 |
-
code_emb = model.encode(CODE_SNIPPETS,
|
44 |
-
convert_to_tensor=True,
|
45 |
-
device="cpu")
|
46 |
-
|
47 |
service_ready = True
|
48 |
app.logger.info("服务初始化完成")
|
49 |
except Exception as e:
|
50 |
-
app.logger.error(
|
51 |
raise
|
52 |
|
53 |
-
# Hugging Face
|
54 |
@app.route('/')
|
55 |
def hf_health_check():
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
if service_ready:
|
58 |
return jsonify({"status": "ready"}), 200
|
59 |
else:
|
60 |
return jsonify({"status": "initializing"}), 503
|
61 |
|
62 |
-
#
|
63 |
@app.route('/search', methods=['GET', 'POST'])
|
64 |
def handle_search():
|
65 |
if not service_ready:
|
|
|
66 |
return jsonify({"error": "服务正在初始化"}), 503
|
67 |
|
68 |
try:
|
69 |
-
#
|
70 |
if request.method == 'GET':
|
71 |
query = request.args.get('query', '').strip()
|
72 |
else:
|
@@ -76,14 +82,12 @@ def handle_search():
|
|
76 |
if not query:
|
77 |
app.logger.info("收到空的查询请求")
|
78 |
return jsonify({"error": "查询不能为空"}), 400
|
79 |
-
|
80 |
# 记录接收到的查询
|
81 |
app.logger.info("收到查询请求: %s", query)
|
82 |
|
83 |
-
#
|
84 |
-
query_emb = model.encode(query,
|
85 |
-
convert_to_tensor=True,
|
86 |
-
device="cpu")
|
87 |
hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
|
88 |
best = hits[0]
|
89 |
|
@@ -91,15 +95,15 @@ def handle_search():
|
|
91 |
"code": CODE_SNIPPETS[best['corpus_id']],
|
92 |
"score": round(float(best['score']), 4)
|
93 |
}
|
|
|
94 |
# 记录返回结果
|
95 |
app.logger.info("返回结果: %s", result)
|
96 |
-
|
97 |
return jsonify(result)
|
98 |
-
|
99 |
except Exception as e:
|
100 |
app.logger.error("请求处理失败: %s", str(e))
|
101 |
return jsonify({"error": "服务器内部错误"}), 500
|
102 |
|
103 |
if __name__ == "__main__":
|
104 |
-
# Hugging Face Spaces
|
105 |
app.run(host='0.0.0.0', port=7860)
|
|
|
1 |
+
from flask import Flask, request, jsonify, render_template_string
|
2 |
from sentence_transformers import SentenceTransformer, util
|
3 |
import logging
|
4 |
import sys
|
5 |
import signal
|
6 |
|
7 |
+
# 初始化 Flask 应用
|
8 |
app = Flask(__name__)
|
9 |
|
10 |
+
# 配置日志,级别设为 INFO
|
11 |
logging.basicConfig(level=logging.INFO)
|
12 |
app.logger = logging.getLogger("CodeSearchAPI")
|
13 |
|
|
|
31 |
signal.signal(signal.SIGTERM, handle_shutdown)
|
32 |
signal.signal(signal.SIGINT, handle_shutdown)
|
33 |
|
34 |
+
# 初始化模型和预计算编码
|
35 |
try:
|
36 |
+
app.logger.info("开始加载模型...")
|
37 |
model = SentenceTransformer(
|
38 |
"flax-sentence-embeddings/st-codesearch-distilroberta-base",
|
39 |
cache_folder="/model-cache"
|
40 |
)
|
41 |
+
# 预计算代码片段的编码(强制使用 CPU)
|
42 |
+
code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
|
|
|
|
|
|
|
|
|
43 |
service_ready = True
|
44 |
app.logger.info("服务初始化完成")
|
45 |
except Exception as e:
|
46 |
+
app.logger.error("初始化失败: %s", str(e))
|
47 |
raise
|
48 |
|
49 |
+
# Hugging Face 健康检查端点,必须响应根路径
|
50 |
@app.route('/')
|
51 |
def hf_health_check():
|
52 |
+
# 如果请求接受 HTML,则返回一个简单的 HTML 页面(包含测试链接)
|
53 |
+
if request.accept_mimetypes.accept_html:
|
54 |
+
html = """
|
55 |
+
<h2>CodeSearch API</h2>
|
56 |
+
<p>服务状态:{{ status }}</p>
|
57 |
+
<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
|
58 |
+
"""
|
59 |
+
status = "ready" if service_ready else "initializing"
|
60 |
+
return render_template_string(html, status=status)
|
61 |
+
# 否则返回 JSON 格式的健康检查
|
62 |
if service_ready:
|
63 |
return jsonify({"status": "ready"}), 200
|
64 |
else:
|
65 |
return jsonify({"status": "initializing"}), 503
|
66 |
|
67 |
+
# 搜索 API 端点,同时支持 GET 和 POST 请求
|
68 |
@app.route('/search', methods=['GET', 'POST'])
|
69 |
def handle_search():
|
70 |
if not service_ready:
|
71 |
+
app.logger.info("服务未就绪")
|
72 |
return jsonify({"error": "服务正在初始化"}), 503
|
73 |
|
74 |
try:
|
75 |
+
# 根据请求方法提取查询内容
|
76 |
if request.method == 'GET':
|
77 |
query = request.args.get('query', '').strip()
|
78 |
else:
|
|
|
82 |
if not query:
|
83 |
app.logger.info("收到空的查询请求")
|
84 |
return jsonify({"error": "查询不能为空"}), 400
|
85 |
+
|
86 |
# 记录接收到的查询
|
87 |
app.logger.info("收到查询请求: %s", query)
|
88 |
|
89 |
+
# 对查询进行编码,并进行语义搜索
|
90 |
+
query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
|
|
|
|
|
91 |
hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
|
92 |
best = hits[0]
|
93 |
|
|
|
95 |
"code": CODE_SNIPPETS[best['corpus_id']],
|
96 |
"score": round(float(best['score']), 4)
|
97 |
}
|
98 |
+
|
99 |
# 记录返回结果
|
100 |
app.logger.info("返回结果: %s", result)
|
|
|
101 |
return jsonify(result)
|
102 |
+
|
103 |
except Exception as e:
|
104 |
app.logger.error("请求处理失败: %s", str(e))
|
105 |
return jsonify({"error": "服务器内部错误"}), 500
|
106 |
|
107 |
if __name__ == "__main__":
|
108 |
+
# 本地测试用,Hugging Face Spaces 通常通过 gunicorn 启动
|
109 |
app.run(host='0.0.0.0', port=7860)
|