Forrest99 commited on
Commit
2a39044
·
verified ·
1 Parent(s): c79e0ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -24
app.py CHANGED
@@ -1,13 +1,13 @@
1
- from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
4
  import sys
5
  import signal
6
 
7
- # 初始化Flask应用
8
  app = Flask(__name__)
9
 
10
- # 配置日志,日志级别设为INFO
11
  logging.basicConfig(level=logging.INFO)
12
  app.logger = logging.getLogger("CodeSearchAPI")
13
 
@@ -31,42 +31,48 @@ def handle_shutdown(signum, frame):
31
  signal.signal(signal.SIGTERM, handle_shutdown)
32
  signal.signal(signal.SIGINT, handle_shutdown)
33
 
34
- # 初始化模型和编码
35
  try:
36
- # Hugging Face Spaces专用缓存路径
37
  model = SentenceTransformer(
38
  "flax-sentence-embeddings/st-codesearch-distilroberta-base",
39
  cache_folder="/model-cache"
40
  )
41
-
42
- # 预计算编码(强制使用CPU)
43
- code_emb = model.encode(CODE_SNIPPETS,
44
- convert_to_tensor=True,
45
- device="cpu")
46
-
47
  service_ready = True
48
  app.logger.info("服务初始化完成")
49
  except Exception as e:
50
- app.logger.error(f"初始化失败: {str(e)}")
51
  raise
52
 
53
- # Hugging Face健康检查端点
54
  @app.route('/')
55
  def hf_health_check():
56
- """必须响应根路径的健康检查"""
 
 
 
 
 
 
 
 
 
57
  if service_ready:
58
  return jsonify({"status": "ready"}), 200
59
  else:
60
  return jsonify({"status": "initializing"}), 503
61
 
62
- # 支持GET和POST请求的搜索API端点
63
  @app.route('/search', methods=['GET', 'POST'])
64
  def handle_search():
65
  if not service_ready:
 
66
  return jsonify({"error": "服务正在初始化"}), 503
67
 
68
  try:
69
- # 区分GET和POST请求,GET从URL参数中获取query,POST从JSON体中获取
70
  if request.method == 'GET':
71
  query = request.args.get('query', '').strip()
72
  else:
@@ -76,14 +82,12 @@ def handle_search():
76
  if not query:
77
  app.logger.info("收到空的查询请求")
78
  return jsonify({"error": "查询不能为空"}), 400
79
-
80
  # 记录接收到的查询
81
  app.logger.info("收到查询请求: %s", query)
82
 
83
- # 对查询进行编码,并搜索最匹配的代码片段
84
- query_emb = model.encode(query,
85
- convert_to_tensor=True,
86
- device="cpu")
87
  hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
88
  best = hits[0]
89
 
@@ -91,15 +95,15 @@ def handle_search():
91
  "code": CODE_SNIPPETS[best['corpus_id']],
92
  "score": round(float(best['score']), 4)
93
  }
 
94
  # 记录返回结果
95
  app.logger.info("返回结果: %s", result)
96
-
97
  return jsonify(result)
98
-
99
  except Exception as e:
100
  app.logger.error("请求处理失败: %s", str(e))
101
  return jsonify({"error": "服务器内部错误"}), 500
102
 
103
  if __name__ == "__main__":
104
- # Hugging Face Spaces会通过gunicorn启动,此处仅为本地测试保留
105
  app.run(host='0.0.0.0', port=7860)
 
1
+ from flask import Flask, request, jsonify, render_template_string
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
4
  import sys
5
  import signal
6
 
7
+ # 初始化 Flask 应用
8
  app = Flask(__name__)
9
 
10
+ # 配置日志,级别设为 INFO
11
  logging.basicConfig(level=logging.INFO)
12
  app.logger = logging.getLogger("CodeSearchAPI")
13
 
 
31
  signal.signal(signal.SIGTERM, handle_shutdown)
32
  signal.signal(signal.SIGINT, handle_shutdown)
33
 
34
+ # 初始化模型和预计算编码
35
  try:
36
+ app.logger.info("开始加载模型...")
37
  model = SentenceTransformer(
38
  "flax-sentence-embeddings/st-codesearch-distilroberta-base",
39
  cache_folder="/model-cache"
40
  )
41
+ # 预计算代码片段的编码(强制使用 CPU)
42
+ code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True, device="cpu")
 
 
 
 
43
  service_ready = True
44
  app.logger.info("服务初始化完成")
45
  except Exception as e:
46
+ app.logger.error("初始化失败: %s", str(e))
47
  raise
48
 
49
+ # Hugging Face 健康检查端点,必须响应根路径
50
  @app.route('/')
51
  def hf_health_check():
52
+ # 如果请求接受 HTML,则返回一个简单的 HTML 页面(包含测试链接)
53
+ if request.accept_mimetypes.accept_html:
54
+ html = """
55
+ <h2>CodeSearch API</h2>
56
+ <p>服务状态:{{ status }}</p>
57
+ <p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
58
+ """
59
+ status = "ready" if service_ready else "initializing"
60
+ return render_template_string(html, status=status)
61
+ # 否则返回 JSON 格式的健康检查
62
  if service_ready:
63
  return jsonify({"status": "ready"}), 200
64
  else:
65
  return jsonify({"status": "initializing"}), 503
66
 
67
+ # 搜索 API 端点,同时支持 GET POST 请求
68
  @app.route('/search', methods=['GET', 'POST'])
69
  def handle_search():
70
  if not service_ready:
71
+ app.logger.info("服务未就绪")
72
  return jsonify({"error": "服务正在初始化"}), 503
73
 
74
  try:
75
+ # 根据请求方法提取查询内容
76
  if request.method == 'GET':
77
  query = request.args.get('query', '').strip()
78
  else:
 
82
  if not query:
83
  app.logger.info("收到空的查询请求")
84
  return jsonify({"error": "查询不能为空"}), 400
85
+
86
  # 记录接收到的查询
87
  app.logger.info("收到查询请求: %s", query)
88
 
89
+ # 对查询进行编码,并进行语义搜索
90
+ query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
 
 
91
  hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
92
  best = hits[0]
93
 
 
95
  "code": CODE_SNIPPETS[best['corpus_id']],
96
  "score": round(float(best['score']), 4)
97
  }
98
+
99
  # 记录返回结果
100
  app.logger.info("返回结果: %s", result)
 
101
  return jsonify(result)
102
+
103
  except Exception as e:
104
  app.logger.error("请求处理失败: %s", str(e))
105
  return jsonify({"error": "服务器内部错误"}), 500
106
 
107
  if __name__ == "__main__":
108
+ # 本地测试用,Hugging Face Spaces 通常通过 gunicorn 启动
109
  app.run(host='0.0.0.0', port=7860)