Forrest99 committed on
Commit
7c08c7b
·
verified ·
1 Parent(s): 890dfaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -24
app.py CHANGED
@@ -1,8 +1,10 @@
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
4
- import os
 
5
 
 
6
  app = Flask(__name__)
7
 
8
  # 配置日志
@@ -18,51 +20,61 @@ CODE_SNIPPETS = [
18
  return min(elements), max(elements)"""
19
  ]
20
 
21
- # 初始化标记
22
- model_ready = False
23
 
 
 
 
 
 
 
 
 
 
24
  try:
25
- # 初始化模型(使用预下载的缓存)
26
  model = SentenceTransformer(
27
  "flax-sentence-embeddings/st-codesearch-distilroberta-base",
28
- cache_folder=os.getenv("HF_HOME")
29
  )
30
 
31
- # 预计算编码
32
- code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
33
- model_ready = True
34
- app.logger.info("模型加载完成,服务就绪")
 
 
 
35
  except Exception as e:
36
- app.logger.error(f"模型初始化失败: {str(e)}")
37
  raise
38
 
39
- @app.route('/health')
40
- def health_check():
41
- """健康检查端点"""
42
- if model_ready:
 
43
  return jsonify({"status": "ready"}), 200
44
  else:
45
  return jsonify({"status": "initializing"}), 503
46
 
 
47
  @app.route('/search', methods=['POST'])
48
  def handle_search():
49
- """搜索请求处理"""
50
- if not model_ready:
51
  return jsonify({"error": "服务正在初始化"}), 503
52
-
53
  try:
54
- # 请求验证
55
- if not request.is_json:
56
- return jsonify({"error": "需要 application/json"}), 415
57
-
58
  data = request.get_json()
59
  query = data.get('query', '').strip()
60
 
61
  if not query:
62
  return jsonify({"error": "查询不能为空"}), 400
63
 
64
- # 处理查询
65
- query_emb = model.encode(query, convert_to_tensor=True)
 
66
  hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
67
  best = hits[0]
68
 
@@ -76,4 +88,5 @@ def handle_search():
76
  return jsonify({"error": "服务器内部错误"}), 500
77
 
78
  if __name__ == "__main__":
79
- app.run(host='0.0.0.0', port=8080)
 
 
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
  import logging
4
+ import sys
5
+ import signal
6
 
7
+ # 初始化Flask应用
8
  app = Flask(__name__)
9
 
10
  # 配置日志
 
20
  return min(elements), max(elements)"""
21
  ]
22
 
23
+ # 全局服务状态
24
+ service_ready = False
25
 
26
+ # 优雅关闭处理
27
+ def handle_shutdown(signum, frame):
28
+ app.logger.info("收到终止信号,开始关闭...")
29
+ sys.exit(0)
30
+
31
+ signal.signal(signal.SIGTERM, handle_shutdown)
32
+ signal.signal(signal.SIGINT, handle_shutdown)
33
+
34
+ # 初始化模型和编码
35
  try:
36
+ # Hugging Face Spaces专用缓存路径
37
  model = SentenceTransformer(
38
  "flax-sentence-embeddings/st-codesearch-distilroberta-base",
39
+ cache_folder="/model-cache"
40
  )
41
 
42
+ # 预计算编码(强制使用CPU)
43
+ code_emb = model.encode(CODE_SNIPPETS,
44
+ convert_to_tensor=True,
45
+ device="cpu")
46
+
47
+ service_ready = True
48
+ app.logger.info("服务初始化完成")
49
  except Exception as e:
50
+ app.logger.error(f"初始化失败: {str(e)}")
51
  raise
52
 
53
+ # Hugging Face健康检查端点
54
+ @app.route('/')
55
+ def hf_health_check():
56
+ """必须响应根路径的健康检查"""
57
+ if service_ready:
58
  return jsonify({"status": "ready"}), 200
59
  else:
60
  return jsonify({"status": "initializing"}), 503
61
 
62
+ # 搜索API端点
63
  @app.route('/search', methods=['POST'])
64
  def handle_search():
65
+ if not service_ready:
 
66
  return jsonify({"error": "服务正在初始化"}), 503
67
+
68
  try:
 
 
 
 
69
  data = request.get_json()
70
  query = data.get('query', '').strip()
71
 
72
  if not query:
73
  return jsonify({"error": "查询不能为空"}), 400
74
 
75
+ query_emb = model.encode(query,
76
+ convert_to_tensor=True,
77
+ device="cpu")
78
  hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
79
  best = hits[0]
80
 
 
88
  return jsonify({"error": "服务器内部错误"}), 500
89
 
90
  if __name__ == "__main__":
91
+ # Hugging Face Spaces会通过gunicorn启动,此处仅为本地测试保留
92
+ app.run(host='0.0.0.0', port=7860)