Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,83 @@
|
|
1 |
from flask import Flask, request, jsonify
|
2 |
from sentence_transformers import SentenceTransformer, util
|
|
|
3 |
|
4 |
app = Flask(__name__)
|
5 |
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
code_embeddings = model.encode(code_snippets, convert_to_tensor=True)
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
try:
|
15 |
# 请求验证
|
16 |
if not request.is_json:
|
17 |
-
|
|
|
18 |
|
19 |
data = request.get_json()
|
20 |
-
|
21 |
-
return jsonify({"error": "Missing query parameter"}), 400
|
22 |
-
|
23 |
-
# 执行语义搜索
|
24 |
-
query = data['query']
|
25 |
-
query_emb = model.encode(query, convert_to_tensor=True)
|
26 |
-
results = util.semantic_search(query_emb, code_embeddings, top_k=1)[0]
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
"
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
}
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
36 |
except Exception as e:
|
37 |
-
app.logger.error(f"
|
38 |
-
return jsonify({"error": "
|
39 |
|
40 |
if __name__ == "__main__":
|
41 |
app.run(host='0.0.0.0', port=8080)
|
|
|
1 |
from flask import Flask, request, jsonify
|
2 |
from sentence_transformers import SentenceTransformer, util
|
3 |
+
import logging
|
4 |
|
5 |
app = Flask(__name__)
|
6 |
|
7 |
+
# 配置日志
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
+
app.logger = logging.getLogger("CodeSearchAPI")
|
|
|
10 |
|
11 |
+
# 预定义代码片段(已验证数据)
|
12 |
+
CODE_SNIPPETS = [
|
13 |
+
"""def sort_list(x): return sorted(x)""",
|
14 |
+
"""def count_above_threshold(elements, threshold=0):
|
15 |
+
return sum(1 for e in elements if e > threshold)""",
|
16 |
+
"""def find_min_max(elements):
|
17 |
+
return min(elements), max(elements)"""
|
18 |
+
]
|
19 |
+
|
20 |
+
# 输入数据验证
|
21 |
+
def validate_snippets(snippets):
|
22 |
+
cleaned = []
|
23 |
+
for idx, s in enumerate(snippets):
|
24 |
+
if not isinstance(s, str):
|
25 |
+
app.logger.warning(f"索引 {idx} 类型错误,已转换为字符串")
|
26 |
+
s = str(s)
|
27 |
+
cleaned.append(s.replace("...", "").strip())
|
28 |
+
return [s for s in cleaned if len(s) > 0]
|
29 |
+
|
30 |
+
# 初始化模型和编码
|
31 |
+
try:
|
32 |
+
model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
|
33 |
+
valid_snippets = validate_snippets(CODE_SNIPPETS)
|
34 |
+
code_emb = model.encode(valid_snippets, convert_to_tensor=True)
|
35 |
+
app.logger.info(f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段")
|
36 |
+
except Exception as e:
|
37 |
+
app.logger.error(f"初始化失败: {str(e)}")
|
38 |
+
raise
|
39 |
+
|
40 |
+
@app.route('/search', methods=['POST'])
|
41 |
+
def handle_search():
|
42 |
+
"""API 处理端点"""
|
43 |
try:
|
44 |
# 请求验证
|
45 |
if not request.is_json:
|
46 |
+
app.logger.warning("无效的 Content-Type")
|
47 |
+
return jsonify({"error": "需要 application/json"}), 415
|
48 |
|
49 |
data = request.get_json()
|
50 |
+
query = data.get('query', '').strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
+
if not query:
|
53 |
+
app.logger.warning("收到空查询")
|
54 |
+
return jsonify({"error": "查询不能为空"}), 400
|
55 |
+
|
56 |
+
# 编码查询
|
57 |
+
try:
|
58 |
+
query_emb = model.encode(query, convert_to_tensor=True)
|
59 |
+
except Exception as e:
|
60 |
+
app.logger.error(f"编码失败: {str(e)}")
|
61 |
+
return jsonify({"error": "查询处理失败"}), 500
|
62 |
+
|
63 |
+
# 语义搜索
|
64 |
+
try:
|
65 |
+
hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
|
66 |
+
best = hits[0]
|
67 |
+
result = {
|
68 |
+
"code": valid_snippets[best['corpus_id']],
|
69 |
+
"score": round(float(best['score']), 4)
|
70 |
}
|
71 |
+
app.logger.info(f"成功处理查询: '{query}'")
|
72 |
+
return jsonify(result)
|
73 |
+
|
74 |
+
except IndexError:
|
75 |
+
app.logger.error("无匹配结果")
|
76 |
+
return jsonify({"error": "无可用匹配"}), 404
|
77 |
+
|
78 |
except Exception as e:
|
79 |
+
app.logger.error(f"未知错误: {str(e)}", exc_info=True)
|
80 |
+
return jsonify({"error": "服务器内部错误"}), 500
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
app.run(host='0.0.0.0', port=8080)
|