# codesearchBase / app.py
# Author: Forrest99 — last commit 79ce9cc ("Update app.py", verified), 2.33 kB
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import os
# Configure the root logger (stderr handler, INFO threshold) if it has no
# handlers yet.
logging.basicConfig(level="INFO")

app = Flask(import_name=__name__)
# Route the application's log calls through a logger named "CodeSearchAPI"
# instead of Flask's default app logger.
app.logger = logging.getLogger(name="CodeSearchAPI")
# Predefined code snippets forming the search corpus. Their embeddings are
# computed once at startup, and the best-matching snippet text is returned
# verbatim to the client, so each entry must be valid, properly indented Python.
CODE_SNIPPETS = [
    """def sort_list(x): return sorted(x)""",
    """def count_above_threshold(elements, threshold=0):
    return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements):
    return min(elements), max(elements)""",
]
# Readiness flag consulted by the /health and /search endpoints.
# NOTE(review): because a load failure below re-raises, the module never
# finishes importing with this still False; the "initializing" responses are
# effectively unreachable unless loading is moved off the import path.
model_ready = False
try:
    # Load the model. cache_folder comes from HF_HOME so a pre-downloaded cache
    # can be reused (when unset this passes None, presumably selecting the
    # library's default cache location — confirm for the pinned version).
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.getenv("HF_HOME"),
    )
    # Pre-compute the corpus embeddings once so each search only has to
    # encode the incoming query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
except Exception as e:
    # logger.exception also records the traceback; re-raise so startup fails
    # loudly instead of serving without a model.
    app.logger.exception("模型初始化失败: %s", e)
    raise

model_ready = True
app.logger.info("模型加载完成,服务就绪")
@app.route('/health')
def health_check():
    """Health-check endpoint.

    Returns ``{"status": "ready"}`` with HTTP 200 once the model has been
    loaded, otherwise ``{"status": "initializing"}`` with HTTP 503.
    """
    status, http_code = ("ready", 200) if model_ready else ("initializing", 503)
    return jsonify({"status": status}), http_code
@app.route('/search', methods=['POST'])
def handle_search():
    """Return the predefined snippet most similar to the query.

    Expects a JSON object body with a string ``query`` field.

    Responses:
        200: ``{"code": <snippet>, "score": <cosine score rounded to 4 places>}``
        400: body is not a JSON object, ``query`` is not a string, or it is
             empty after stripping whitespace
        415: request content type is not JSON
        500: unexpected failure while encoding or searching
        503: model not ready
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Request validation — all client errors are answered before the try
    # block so they can never be reported as a 500.
    if not request.is_json:
        return jsonify({"error": "需要 application/json"}), 415
    # silent=True: malformed JSON yields None instead of raising, so it is
    # reported as a 400 like any other invalid body.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "请求体必须是 JSON 对象"}), 400
    query = data.get('query', '')
    if not isinstance(query, str):
        return jsonify({"error": "query 必须是字符串"}), 400
    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    try:
        query_emb = model.encode(query, convert_to_tensor=True)
        # One query -> one hit list; top_k=1 keeps only the best match.
        best = util.semantic_search(query_emb, code_emb, top_k=1)[0][0]
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500

    return jsonify({
        "code": CODE_SNIPPETS[best['corpus_id']],
        "score": round(float(best['score']), 4),
    })
if __name__ == "__main__":
    # Development-server entry point; "0.0.0.0" listens on every network
    # interface rather than loopback only.
    bind_host, bind_port = "0.0.0.0", 8080
    app.run(host=bind_host, port=bind_port)