Forrest99 committed on
Commit
79ce9cc
·
verified ·
1 Parent(s): 8945d20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -42
app.py CHANGED
@@ -1,6 +1,7 @@
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
 
4
 
5
  app = Flask(__name__)
6
 
@@ -8,7 +9,7 @@ app = Flask(__name__)
8
  logging.basicConfig(level=logging.INFO)
9
  app.logger = logging.getLogger("CodeSearchAPI")
10
 
11
- # 预定义代码片段(已验证数据)
12
  CODE_SNIPPETS = [
13
  """def sort_list(x): return sorted(x)""",
14
  """def count_above_threshold(elements, threshold=0):
@@ -17,66 +18,61 @@ CODE_SNIPPETS = [
17
  return min(elements), max(elements)"""
18
  ]
19
 
20
- # 输入数据验证
21
- def validate_snippets(snippets):
22
- cleaned = []
23
- for idx, s in enumerate(snippets):
24
- if not isinstance(s, str):
25
- app.logger.warning(f"索引 {idx} 类型错误,已转换为字符串")
26
- s = str(s)
27
- cleaned.append(s.replace("...", "").strip())
28
- return [s for s in cleaned if len(s) > 0]
29
 
30
- # 初始化模型和编码
31
  try:
32
- model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
33
- valid_snippets = validate_snippets(CODE_SNIPPETS)
34
- code_emb = model.encode(valid_snippets, convert_to_tensor=True)
35
- app.logger.info(f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段")
 
 
 
 
 
 
36
  except Exception as e:
37
- app.logger.error(f"初始化失败: {str(e)}")
38
  raise
39
 
 
 
 
 
 
 
 
 
40
  @app.route('/search', methods=['POST'])
41
  def handle_search():
42
- """API 处理端点"""
 
 
 
43
  try:
44
  # 请求验证
45
  if not request.is_json:
46
- app.logger.warning("无效的 Content-Type")
47
  return jsonify({"error": "需要 application/json"}), 415
48
 
49
  data = request.get_json()
50
  query = data.get('query', '').strip()
51
 
52
  if not query:
53
- app.logger.warning("收到空查询")
54
  return jsonify({"error": "查询不能为空"}), 400
55
 
56
- # 编码查询
57
- try:
58
- query_emb = model.encode(query, convert_to_tensor=True)
59
- except Exception as e:
60
- app.logger.error(f"编码失败: {str(e)}")
61
- return jsonify({"error": "查询处理失败"}), 500
62
-
63
- # 语义搜索
64
- try:
65
- hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
66
- best = hits[0]
67
- result = {
68
- "code": valid_snippets[best['corpus_id']],
69
- "score": round(float(best['score']), 4)
70
- }
71
- app.logger.info(f"成功处理查询: '{query}'")
72
- return jsonify(result)
73
-
74
- except IndexError:
75
- app.logger.error("无匹配结果")
76
- return jsonify({"error": "无可用匹配"}), 404
77
-
78
  except Exception as e:
79
- app.logger.error(f"未知错误: {str(e)}", exc_info=True)
80
  return jsonify({"error": "服务器内部错误"}), 500
81
 
82
  if __name__ == "__main__":
 
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
4
+ import os
5
 
6
  app = Flask(__name__)
7
 
 
9
  logging.basicConfig(level=logging.INFO)
10
  app.logger = logging.getLogger("CodeSearchAPI")
11
 
12
+ # 预定义代码片段
13
  CODE_SNIPPETS = [
14
  """def sort_list(x): return sorted(x)""",
15
  """def count_above_threshold(elements, threshold=0):
 
18
  return min(elements), max(elements)"""
19
  ]
20
 
# Readiness flag read by the request handlers; it only becomes True once the
# model and the snippet embeddings below have been created successfully.
model_ready = False

try:
    # The cache directory comes from the HF_HOME environment variable
    # (None when it is unset). Per the original author's note, the model
    # weights are expected to be pre-downloaded into that cache.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.environ.get("HF_HOME"),
    )

    # Encode every snippet once at import time so that each search request
    # only has to encode the incoming query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as exc:
    # Log the failure, then re-raise so the module import itself fails.
    app.logger.error(f"模型初始化失败: {exc}")
    raise
38
 
@app.route('/health')
def health_check():
    """Report whether the service is ready to answer search requests.

    Returns:
        ``({"status": "ready"}, 200)`` when the module-level ``model_ready``
        flag is true, otherwise ``({"status": "initializing"}, 503)``.
    """
    if not model_ready:
        return jsonify({"status": "initializing"}), 503
    return jsonify({"status": "ready"}), 200
46
+
@app.route('/search', methods=['POST'])
def handle_search():
    """Return the stored code snippet that best matches the posted query.

    The request body must be a JSON object with a non-empty string
    ``query`` field.

    Returns:
        * 200 with ``{"code": <snippet>, "score": <similarity rounded to 4
          places>}`` for the top semantic-search hit.
        * 400 when the body is malformed JSON, is not a JSON object, or
          ``query`` is missing, not a string, or blank.
        * 415 when the request is not sent as ``application/json``.
        * 503 when the module-level ``model_ready`` flag is false.
        * 500 for any unexpected failure while encoding or searching.
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    try:
        # Request validation.
        if not request.is_json:
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True makes a malformed body come back as None instead of
        # raising BadRequest, which the catch-all below would otherwise
        # report as a 500 even though the client is at fault.
        payload = request.get_json(silent=True)

        # Valid JSON is not necessarily an object (null, list, number, ...),
        # and "query" is not necessarily a string; anything else is treated
        # the same as an empty query.
        raw_query = payload.get('query') if isinstance(payload, dict) else None
        query = raw_query.strip() if isinstance(raw_query, str) else ''

        if not query:
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query and look up the single closest snippet among the
        # embeddings that were precomputed at start-up.
        query_emb = model.encode(query, convert_to_tensor=True)
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        })

    except Exception as e:
        # exc_info=True keeps the traceback in the log; the client only
        # ever sees the generic message.
        app.logger.error(f"请求处理失败: {str(e)}", exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500
77
 
78
  if __name__ == "__main__":