from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
import os

app = Flask(__name__)

# Configure root logging and replace Flask's default logger with a named one.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets that form the search corpus. The matching snippet
# text is returned verbatim in /search responses.
CODE_SNIPPETS = [
    """def sort_list(x): return sorted(x)""",
    """def count_above_threshold(elements, threshold=0): return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements): return min(elements), max(elements)""",
]

# Readiness flag checked by both endpoints.
# NOTE(review): initialization failures are re-raised below, so once the app
# is serving this is always True; the "not ready" branches are defensive.
model_ready = False

try:
    # Load the embedding model. HF_HOME (if set) is passed as cache_folder so a
    # pre-downloaded cache can be reused; None falls back to the library default.
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder=os.getenv("HF_HOME"),
    )
    # Encode the corpus once at startup so each request only encodes its query.
    code_emb = model.encode(CODE_SNIPPETS, convert_to_tensor=True)
    model_ready = True
    app.logger.info("模型加载完成,服务就绪")
except Exception as e:
    # logger.exception records the traceback; re-raise so startup fails fast.
    app.logger.exception("模型初始化失败: %s", e)
    raise


@app.route('/health')
def health_check():
    """Health-check endpoint.

    Returns:
        200 with {"status": "ready"} when the model is loaded,
        otherwise 503 with {"status": "initializing"}.
    """
    if model_ready:
        return jsonify({"status": "ready"}), 200
    return jsonify({"status": "initializing"}), 503


@app.route('/search', methods=['POST'])
def handle_search():
    """Return the corpus snippet most similar to the submitted query.

    Expects a JSON object body of the form {"query": "<text>"}.

    Returns:
        200 with {"code": <snippet>, "score": <similarity rounded to 4 dp>};
        400 for malformed JSON, a non-object body, or a missing, empty or
        non-string query; 415 when the request is not JSON; 503 when the
        model is not ready; 500 on unexpected errors during encoding/search.
    """
    if not model_ready:
        return jsonify({"error": "服务正在初始化"}), 503

    # Request validation: client errors map to 4xx, never to 500.
    if not request.is_json:
        return jsonify({"error": "需要 application/json"}), 415

    # silent=True yields None on unparseable JSON instead of raising
    # BadRequest, which a generic handler would otherwise turn into a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "请求体必须是 JSON 对象"}), 400

    query = data.get('query', '')
    if not isinstance(query, str):
        return jsonify({"error": "query 必须是字符串"}), 400
    query = query.strip()
    if not query:
        return jsonify({"error": "查询不能为空"}), 400

    try:
        query_emb = model.encode(query, convert_to_tensor=True)
        # semantic_search returns one hit list per query; we sent exactly one.
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        return jsonify({
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        })
    except Exception as e:
        # Unexpected server-side failure: keep the traceback in the log and
        # return a generic message to the client.
        app.logger.exception("请求处理失败: %s", e)
        return jsonify({"error": "服务器内部错误"}), 500


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8080)