"""Small Flask API that returns the stored code snippet best matching a query."""

import logging

from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util

app = Flask(__name__)

# Logging configuration: route everything through a named logger at INFO level.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets that form the search corpus.
CODE_SNIPPETS = [
    """def sort_list(x): return sorted(x)""",
    """def count_above_threshold(elements, threshold=0): return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements): return min(elements), max(elements)""",
]


def validate_snippets(snippets):
    """Normalise a sequence of snippets into a list of non-empty strings.

    Non-string entries are converted with ``str()`` (a warning is logged with
    the entry's index). Every occurrence of ``"..."`` is removed and the
    result is stripped of surrounding whitespace; entries that end up empty
    are dropped.

    Args:
        snippets: Iterable of candidate snippets.

    Returns:
        A new list of cleaned, non-empty snippet strings, in input order.
    """
    cleaned = []
    for idx, snippet in enumerate(snippets):
        if not isinstance(snippet, str):
            app.logger.warning("索引 %s 类型错误,已转换为字符串", idx)
            snippet = str(snippet)
        snippet = snippet.replace("...", "").strip()
        if snippet:
            cleaned.append(snippet)
    return cleaned


# Load the model and pre-encode the corpus once at import time so requests
# only have to encode the query. Any failure is logged and re-raised.
try:
    model = SentenceTransformer("flax-sentence-embeddings/st-codesearch-distilroberta-base")
    valid_snippets = validate_snippets(CODE_SNIPPETS)
    code_emb = model.encode(valid_snippets, convert_to_tensor=True)
    app.logger.info("成功加载模型,编码 %s 个有效代码片段", len(valid_snippets))
except Exception as e:
    app.logger.error("初始化失败: %s", e)
    raise


def _extract_query(payload):
    """Pull the stripped ``query`` string out of a decoded JSON payload.

    Returns:
        ``(query, None)`` on success (``query`` may be empty), or
        ``(None, message)`` when the payload is not a JSON object or its
        ``query`` value is not a string.
    """
    if not isinstance(payload, dict):
        return None, "请求体必须是 JSON 对象"
    query = payload.get('query', '')
    if not isinstance(query, str):
        return None, "query 必须是字符串"
    return query.strip(), None


@app.route('/search', methods=['POST'])
def handle_search():
    """Handle ``POST /search``: return the best-matching snippet for a query.

    Expects a JSON object body with a string ``query`` field.

    Returns:
        200 with ``{"code": ..., "score": ...}`` on success;
        415 if the request is not JSON;
        400 if the body is malformed, not an object, or the query is
        missing, not a string, or empty;
        404 if the search yields no hit;
        500 on encoding failures or any unexpected error.
    """
    try:
        # Request validation.
        if not request.is_json:
            app.logger.warning("无效的 Content-Type")
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True: malformed JSON yields None instead of raising, so it is
        # reported as a client error below rather than falling into the
        # catch-all handler and becoming a 500.
        data = request.get_json(silent=True)
        query, problem = _extract_query(data)
        if problem is not None:
            app.logger.warning("%s", problem)
            return jsonify({"error": problem}), 400
        if not query:
            app.logger.warning("收到空查询")
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query.
        try:
            query_emb = model.encode(query, convert_to_tensor=True)
        except Exception as e:
            app.logger.error("编码失败: %s", e)
            return jsonify({"error": "查询处理失败"}), 500

        # Semantic search: keep only the single best hit.
        try:
            hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
            best = hits[0]
            result = {
                "code": valid_snippets[best['corpus_id']],
                "score": round(float(best['score']), 4)
            }
            app.logger.info("成功处理查询: '%s'", query)
            return jsonify(result)
        except IndexError:
            # No hit available (e.g. the corpus is empty).
            app.logger.error("无匹配结果")
            return jsonify({"error": "无可用匹配"}), 404

    except Exception as e:
        # Top-level boundary: log with traceback, hide details from the client.
        app.logger.error("未知错误: %s", e, exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8080)