Forrest99 committed on
Commit
892f484
·
verified ·
1 Parent(s): 73e30a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -24
app.py CHANGED
@@ -1,41 +1,83 @@
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
 
3
 
4
  app = Flask(__name__)
5
 
6
- # 预加载模型(线上部署关键)
7
- model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
8
- code_snippets = [...] # 你的代码片段
9
- code_embeddings = model.encode(code_snippets, convert_to_tensor=True)
10
 
11
- @app.route('/v1/search', methods=['POST'])
12
- def api_handler():
13
- """生产级API端点"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  try:
15
  # 请求验证
16
  if not request.is_json:
17
- return jsonify({"error": "Invalid Content-Type"}), 415
 
18
 
19
  data = request.get_json()
20
- if 'query' not in data:
21
- return jsonify({"error": "Missing query parameter"}), 400
22
-
23
- # 执行语义搜索
24
- query = data['query']
25
- query_emb = model.encode(query, convert_to_tensor=True)
26
- results = util.semantic_search(query_emb, code_embeddings, top_k=1)[0]
27
 
28
- # 构建响应
29
- return jsonify({
30
- "data": {
31
- "best_match": code_snippets[results[0]['corpus_id']],
32
- "similarity": float(results[0]['score'])
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  }
34
- })
35
-
 
 
 
 
 
36
  except Exception as e:
37
- app.logger.error(f"API Error: {str(e)}")
38
- return jsonify({"error": "Internal Server Error"}), 500
39
 
40
  if __name__ == "__main__":
41
  app.run(host='0.0.0.0', port=8080)
 
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
+ import logging
4
 
5
  app = Flask(__name__)
6
 
7
+ # 配置日志
8
+ logging.basicConfig(level=logging.INFO)
9
+ app.logger = logging.getLogger("CodeSearchAPI")
 
10
 
11
+ # 预定义代码片段(已验证数据)
12
+ CODE_SNIPPETS = [
13
+ """def sort_list(x): return sorted(x)""",
14
+ """def count_above_threshold(elements, threshold=0):
15
+ return sum(1 for e in elements if e > threshold)""",
16
+ """def find_min_max(elements):
17
+ return min(elements), max(elements)"""
18
+ ]
19
+
20
# Input data validation
def validate_snippets(snippets):
    """Normalize raw snippets into a list of non-empty strings.

    Each entry that is not already a ``str`` is converted with ``str()``
    (a warning naming its index is logged). Every ``"..."`` occurrence is
    then removed and surrounding whitespace is stripped. Entries that end
    up empty are dropped.

    Args:
        snippets: Iterable of candidate snippets.

    Returns:
        A list of the cleaned, non-empty snippets in their original order.
    """
    def _normalize(position, snippet):
        # Coerce instead of rejecting so one bad entry does not abort startup.
        if not isinstance(snippet, str):
            app.logger.warning(f"索引 {position} 类型错误,已转换为字符串")
            snippet = str(snippet)
        without_placeholder = snippet.replace("...", "")
        return without_placeholder.strip()

    normalized = (_normalize(i, raw) for i, raw in enumerate(snippets))
    return [text for text in normalized if text]
29
+
30
# Load the embedding model and pre-encode the snippet corpus once at import
# time, so request handlers only have to encode the incoming query.
try:
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base"
    )
    valid_snippets = validate_snippets(CODE_SNIPPETS)
    code_emb = model.encode(valid_snippets, convert_to_tensor=True)
    app.logger.info(
        f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段"
    )
except Exception as init_error:
    # Log for operators, then re-raise: the service cannot run without a model.
    app.logger.error(f"初始化失败: {init_error!s}")
    raise
39
+
40
@app.route('/search', methods=['POST'])
def handle_search():
    """Return the stored code snippet that best matches a text query.

    Expects a POST request whose body is a JSON object of the form
    ``{"query": "<text>"}``.

    Returns:
        A Flask JSON response, paired with a status code on error:

        * 200 -- ``{"code": <snippet>, "score": <similarity, 4 decimals>}``
        * 400 -- body is malformed or not a JSON object, or ``query`` is
          not a string, or is blank after stripping
        * 404 -- the search produced no hit
        * 415 -- the request is not ``application/json``
        * 500 -- query encoding failed or an unexpected error occurred
    """
    try:
        # Request validation
        if not request.is_json:
            app.logger.warning("无效的 Content-Type")
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True returns None for a malformed body instead of raising
        # BadRequest, which the catch-all below would misreport as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            app.logger.warning("请求体不是 JSON 对象")
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        # A non-string query (null, number, list) is a client error; reject
        # it here rather than crashing on .strip().
        raw_query = data.get('query', '')
        if not isinstance(raw_query, str):
            app.logger.warning("查询不是字符串")
            return jsonify({"error": "查询必须是字符串"}), 400

        query = raw_query.strip()
        if not query:
            app.logger.warning("收到空查询")
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query
        try:
            query_emb = model.encode(query, convert_to_tensor=True)
        except Exception as e:
            app.logger.error("编码失败: %s", e, exc_info=True)
            return jsonify({"error": "查询处理失败"}), 500

        # Semantic search: only the hit lookup is guarded for IndexError.
        try:
            hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
            best = hits[0]
        except IndexError:
            app.logger.error("无匹配结果")
            return jsonify({"error": "无可用匹配"}), 404

        result = {
            "code": valid_snippets[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
        # %r escapes control characters so user input cannot forge log lines.
        app.logger.info("成功处理查询: %r", query)
        return jsonify(result)

    except Exception as e:
        # Last-resort boundary: never leak internals to the client.
        app.logger.error(f"未知错误: {str(e)}", exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500
81
 
82
  if __name__ == "__main__":
83
  app.run(host='0.0.0.0', port=8080)