Forrest99 commited on
Commit
aacc39b
·
verified ·
1 Parent(s): 06ff106

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -1,29 +1,35 @@
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
 
3
 
4
  app = Flask(__name__)
5
 
6
  # 预定义代码片段
7
  CODE = [
8
  """def sort_list(x): return sorted(x)""",
9
- """def count_above_threshold(elements, threshold=0):...""",
10
- """def find_min_max(elements):..."""
11
  ]
12
 
13
- # 初始化模型(启动时自动下载)
14
- model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
 
 
 
15
  code_emb = model.encode(CODE, convert_to_tensor=True)
16
 
17
  @app.route('/search', methods=['POST'])
18
  def search():
19
- query = request.json.get('query', '')
20
- query_emb = model.encode(query, convert_to_tensor=True)
21
- hits = util.semantic_search(query_emb, code_emb)[0]
22
- best = hits[0]
23
- return jsonify({
24
- 'code': CODE[best['corpus_id']],
25
- 'score': float(best['score'])
26
- })
 
 
 
27
 
28
  if __name__ == '__main__':
29
  app.run(host='0.0.0.0', port=5000)
 
1
  from flask import Flask, request, jsonify
2
  from sentence_transformers import SentenceTransformer, util
3
+ import os
4
 
5
  app = Flask(__name__)
6
 
7
  # 预定义代码片段
8
  CODE = [
9
  """def sort_list(x): return sorted(x)""",
10
+ # 其他代码片段...
 
11
  ]
12
 
13
+ # 初始化模型(显示指定缓存路径)
14
+ model = SentenceTransformer(
15
+ "flax-sentence-embeddings/st-codesearch-distilroberta-base",
16
+ cache_folder=os.getenv("TRANSFORMERS_CACHE")
17
+ )
18
  code_emb = model.encode(CODE, convert_to_tensor=True)
19
 
20
  @app.route('/search', methods=['POST'])
21
  def search():
22
+ try:
23
+ query = request.json['query']
24
+ query_emb = model.encode(query, convert_to_tensor=True)
25
+ hits = util.semantic_search(query_emb, code_emb)[0]
26
+ best = hits[0]
27
+ return jsonify({
28
+ 'code': CODE[best['corpus_id']],
29
+ 'score': float(best['score'])
30
+ })
31
+ except Exception as e:
32
+ return jsonify({"error": str(e)}), 500
33
 
34
  if __name__ == '__main__':
35
  app.run(host='0.0.0.0', port=5000)