Spaces:
Running
Running
File size: 2,902 Bytes
90f5392 892f484 90f5392 892f484 90f5392 892f484 aacc39b 0db0051 892f484 0db0051 892f484 0db0051 892f484 0db0051 892f484 aacc39b 892f484 90f5392 0db0051 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer, util
import logging
app = Flask(__name__)
# Configure logging: root logger at INFO level, and replace the Flask app's
# logger attribute with the named "CodeSearchAPI" standard-library logger.
logging.basicConfig(level=logging.INFO)
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets (verified data).
# NOTE(review): the function bodies inside the multi-line string literals carry
# no leading indentation — confirm whether that is intended for the corpus text.
CODE_SNIPPETS = [
"""def sort_list(x): return sorted(x)""",
"""def count_above_threshold(elements, threshold=0):
return sum(1 for e in elements if e > threshold)""",
"""def find_min_max(elements):
return min(elements), max(elements)"""
]
# Input data validation
def validate_snippets(snippets):
    """Normalise raw snippets into a list of non-empty strings.

    Entries that are not ``str`` are reported through ``app.logger`` as a
    warning and coerced with ``str()``. Every literal ``"..."`` sequence is
    removed from each entry, surrounding whitespace is stripped, and entries
    that end up empty are dropped. Input order is preserved.

    Args:
        snippets: Iterable of candidate snippets.

    Returns:
        A new list containing the cleaned, non-empty snippet strings.
    """
    usable = []
    for idx, candidate in enumerate(snippets):
        if isinstance(candidate, str):
            text = candidate
        else:
            app.logger.warning(f"索引 {idx} 类型错误,已转换为字符串")
            text = str(candidate)
        text = text.replace("...", "").strip()
        # Keep only entries that still contain something after cleaning.
        if text:
            usable.append(text)
    return usable
# Initialise the model and pre-compute the snippet embeddings at import time.
try:
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base"
    )
    # Only the cleaned snippets are encoded; search results index into this list.
    valid_snippets = validate_snippets(CODE_SNIPPETS)
    code_emb = model.encode(
        valid_snippets,
        convert_to_tensor=True,
    )
    app.logger.info(f"成功加载模型,编码 {len(valid_snippets)} 个有效代码片段")
except Exception as e:
    # Log the failure, then re-raise so the module import itself fails.
    app.logger.error(f"初始化失败: {str(e)}")
    raise
@app.route('/search', methods=['POST'])
def handle_search():
    """Handle ``POST /search`` and return the best-matching code snippet.

    Expects a JSON object body with a string ``query`` field. On success the
    response is ``{"code": <snippet>, "score": <float>}`` where ``score`` is
    rounded to 4 decimal places.

    Error responses are JSON objects of the form ``{"error": <message>}``:
        415: the request is not JSON.
        400: the body is malformed JSON, is not a JSON object, ``query`` is
            not a string, or ``query`` is empty after stripping whitespace.
        404: an ``IndexError`` occurred while selecting the top hit.
        500: encoding the query failed, or any other unexpected error.
    """
    try:
        # Request validation
        if not request.is_json:
            app.logger.warning("无效的 Content-Type")
            return jsonify({"error": "需要 application/json"}), 415

        # silent=True yields None for malformed JSON instead of raising, so a
        # bad body is reported as a client error (400) rather than being
        # swallowed by the catch-all handler below and turned into a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            app.logger.warning("请求体不是有效的 JSON 对象")
            return jsonify({"error": "请求体必须是 JSON 对象"}), 400

        query = data.get('query', '')
        # A non-string query (null, number, list, ...) has no .strip(); reject
        # it explicitly instead of failing with an AttributeError.
        if not isinstance(query, str):
            app.logger.warning("查询类型无效")
            return jsonify({"error": "query 必须是字符串"}), 400

        query = query.strip()
        if not query:
            app.logger.warning("收到空查询")
            return jsonify({"error": "查询不能为空"}), 400

        # Encode the query
        try:
            query_emb = model.encode(query, convert_to_tensor=True)
        except Exception as e:
            app.logger.error(f"编码失败: {str(e)}")
            return jsonify({"error": "查询处理失败"}), 500

        # Semantic search: only the single best hit is requested.
        try:
            hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
            best = hits[0]
            result = {
                "code": valid_snippets[best['corpus_id']],
                "score": round(float(best['score']), 4)
            }
            app.logger.info(f"成功处理查询: '{query}'")
            return jsonify(result)
        except IndexError:
            app.logger.error("无匹配结果")
            return jsonify({"error": "无可用匹配"}), 404
    except Exception as e:
        # Last-resort boundary: log with traceback and hide details from the client.
        app.logger.error(f"未知错误: {str(e)}", exc_info=True)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # When run as a script, serve on all interfaces on port 8080.
    app.run(host='0.0.0.0', port=8080)