Forrest99 committed on
Commit
0db0051
·
verified ·
1 Parent(s): 45f0945

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -23
app.py CHANGED
@@ -1,35 +1,41 @@
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
- import os
4
 
5
  app = Flask(__name__)
6
 
7
- # 预定义代码片段
8
- CODE = [
9
- """def sort_list(x): return sorted(x)""",
10
- # 其他代码片段...
11
- ]
12
 
13
- # 初始化模型(显示指定缓存路径)
14
- model = SentenceTransformer(
15
- "flax-sentence-embeddings/st-codesearch-distilroberta-base",
16
- cache_folder=os.getenv("TRANSFORMERS_CACHE")
17
- )
18
- code_emb = model.encode(CODE, convert_to_tensor=True)
19
-
20
- @app.route('/search', methods=['POST'])
21
- def search():
22
  try:
23
- query = request.json['query']
 
 
 
 
 
 
 
 
 
24
  query_emb = model.encode(query, convert_to_tensor=True)
25
- hits = util.semantic_search(query_emb, code_emb)[0]
26
- best = hits[0]
 
27
  return jsonify({
28
- 'code': CODE[best['corpus_id']],
29
- 'score': float(best['score'])
 
 
30
  })
 
31
  except Exception as e:
32
- return jsonify({"error": str(e)}), 500
 
33
 
34
- if __name__ == '__main__':
35
- app.run(host='0.0.0.0', port=5000)
 
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
 
3
 
4
  app = Flask(__name__)
5
 
6
+ # 预加载模型(线上部署关键)
7
+ model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
8
+ code_snippets = [...] # 你的代码片段
9
+ code_embeddings = model.encode(code_snippets, convert_to_tensor=True)
 
10
 
11
+ @app.route('/v1/search', methods=['POST'])
12
+ def api_handler():
13
+ """生产级API端点"""
 
 
 
 
 
 
14
  try:
15
+ # 请求验证
16
+ if not request.is_json:
17
+ return jsonify({"error": "Invalid Content-Type"}), 415
18
+
19
+ data = request.get_json()
20
+ if 'query' not in data:
21
+ return jsonify({"error": "Missing query parameter"}), 400
22
+
23
+ # 执行语义搜索
24
+ query = data['query']
25
  query_emb = model.encode(query, convert_to_tensor=True)
26
+ results = util.semantic_search(query_emb, code_embeddings, top_k=1)[0]
27
+
28
+ # 构建响应
29
  return jsonify({
30
+ "data": {
31
+ "best_match": code_snippets[results[0]['corpus_id']],
32
+ "similarity": float(results[0]['score'])
33
+ }
34
  })
35
+
36
  except Exception as e:
37
+ app.logger.error(f"API Error: {str(e)}")
38
+ return jsonify({"error": "Internal Server Error"}), 500
39
 
40
+ if __name__ == "__main__":
41
+ app.run(host='0.0.0.0', port=8080)