# Scraped page header ("Spaces: Build error") - not part of the program.
import uvicorn | |
from fastapi import FastAPI, UploadFile, File, HTTPException | |
from fastapi.staticfiles import StaticFiles | |
import hashlib | |
import os | |
from enum import Enum | |
from paddleocr import PaddleOCR | |
from PIL import Image | |
import io | |
import numpy as np | |
from typing import Optional | |
# Directory for generated result files; created at import time if missing.
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

# The interactive API docs are configured to be served at the site root.
app = FastAPI(docs_url="/")
class LangEnum(str, Enum):
    """Language codes accepted by the ``lang`` parameter of the endpoints.

    The ``str`` mix-in makes every member compare equal to its plain code.
    """

    ch = "ch"
    en = "en"
    japan = "japan"
    korean = "korean"
    chinese_cht = "chinese_cht"
    fr = "fr"
    de = "de"
# Cache of constructed engines, keyed by configuration, so that each model
# set is built at most once per process.
ocr_cache = {}

# Languages that are pinned to the PP-OCRv5 server models below.
_SERVER_MODEL_LANGS = frozenset({"ch", "en", "japan", "chinese_cht"})


def _ocr_model_kwargs(lang: str) -> dict:
    """Return the model-selection keyword arguments for *lang*."""
    if lang in _SERVER_MODEL_LANGS:
        # Explicit names keep the high-accuracy server models for these codes.
        return {
            "text_detection_model_name": "PP-OCRv5_server_det",
            "text_recognition_model_name": "PP-OCRv5_server_rec",
        }
    # PaddleOCR 3.x is documented to ignore `lang`/`ocr_version` when explicit
    # model names are given, so for every other language the names are left
    # out and the library selects the matching PP-OCRv5 model itself.
    return {"ocr_version": "PP-OCRv5", "lang": lang}


def get_ocr_instance(lang: str = "ch", use_gpu: bool = False):
    """Return a cached PaddleOCR engine for *lang* on the requested device.

    Args:
        lang: Language code; an Enum member whose ``value`` is the code is
            accepted as well.
        use_gpu: Build the engine on ``"gpu"`` when true, otherwise ``"cpu"``.

    Returns:
        The engine stored in ``ocr_cache`` for this configuration; it is
        constructed on first use.
    """
    # Reduce Enum members to their bare code so that the cache key and the
    # value handed to PaddleOCR do not depend on how Enum formatting behaves.
    lang = getattr(lang, "value", lang)
    cache_key = f"v5_{lang}_{use_gpu}"
    if cache_key not in ocr_cache:
        ocr_cache[cache_key] = PaddleOCR(
            use_doc_orientation_classify=False,  # no document orientation step
            use_doc_unwarping=False,  # no document unwarping step
            use_textline_orientation=False,  # no text-line orientation step
            device="gpu" if use_gpu else "cpu",
            **_ocr_model_kwargs(lang),
        )
    return ocr_cache[cache_key]
def validate_image(file: UploadFile, max_size: int = 10 * 1024 * 1024):
    """Reject uploads that are not images or that exceed *max_size* bytes.

    Args:
        file: The uploaded file; only ``content_type`` and ``size`` are read.
        max_size: Largest accepted size in bytes (10 MB by default).

    Raises:
        HTTPException: Status 400 when the content type is missing or does
            not start with ``image/``, or when the reported size is larger
            than *max_size*.
    """
    content_type = file.content_type
    if not content_type or not content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="文件必须是图片格式")

    # The size attribute may be absent or None; the check is skipped then.
    size = getattr(file, "size", None)
    if size and size > max_size:
        limit_mb = max_size / (1024 * 1024)
        raise HTTPException(
            status_code=400,
            detail=f"图片文件大小不能超过{limit_mb:g}MB",
        )
async def ocr_recognition(
    file: UploadFile = File(...),
    lang: LangEnum = LangEnum.ch,
    use_gpu: bool = False
):
    """Run PP-OCRv5 text recognition on an uploaded image.

    Args:
        file: The uploaded image.
        lang: Language code used to select the OCR engine.
        use_gpu: Run the engine on the GPU instead of the CPU.

    Returns:
        A dict with ``success``, ``model_version``, ``language``, ``count``
        and ``results``; each result holds ``id``, ``text``, ``confidence``
        and ``bbox``.

    Raises:
        HTTPException: Status 400 for an invalid or empty upload, status 500
            when recognition itself fails.
    """
    # NOTE(review): no route decorator is visible above this handler in this
    # view - confirm that it is registered on `app`.
    try:
        validate_image(file)
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")

        # Decode the upload and normalise it to three-channel RGB.
        image = Image.open(io.BytesIO(contents))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Hand the plain language code (not the Enum member) to the engine.
        lang_code = getattr(lang, "value", lang)
        ocr = get_ocr_instance(lang=lang_code, use_gpu=use_gpu)
        results = ocr.predict(np.array(image))

        if not results:
            return {
                "success": True,
                "message": "未检测到文字",
                "model_version": "PP-OCRv5",
                "language": lang_code,
                "count": 0,
                "results": []
            }

        # Only the first result is reported.
        ocr_results = _extract_ocr_items(results[0])
        return {
            "success": True,
            "model_version": "PP-OCRv5",
            "language": lang_code,
            "count": len(ocr_results),
            "results": ocr_results
        }
    except HTTPException:
        # Keep the 400 responses raised above instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR识别失败: {str(e)}") from e


def _extract_ocr_items(result) -> list:
    """Flatten one prediction result into a list of JSON-friendly dicts."""
    payload = getattr(result, "json", None)
    if not payload or not isinstance(payload, dict):
        return []

    # PaddleOCR 3.x result objects are documented to nest their fields under
    # a top-level "res" key; a flat layout is still accepted.
    nested = payload.get("res")
    if isinstance(nested, dict):
        payload = nested

    texts = payload.get("rec_texts", [])
    scores = payload.get("rec_scores", [])
    polys = payload.get("dt_polys", [])
    return [
        {
            "id": idx,
            "text": text,
            "confidence": round(float(score), 4),
            # Arrays are converted to plain lists so they can be serialised.
            "bbox": poly.tolist() if hasattr(poly, "tolist") else poly,
        }
        for idx, (text, score, poly) in enumerate(zip(texts, scores, polys))
    ]
async def table_recognition(
    file: UploadFile = File(...),
    lang: LangEnum = LangEnum.ch,
    use_gpu: bool = False
):
    """Detect tables, figures and text regions in an uploaded image.

    Args:
        file: The uploaded image.
        lang: Language code passed to the structure engine.
        use_gpu: Run the engine on the GPU instead of the CPU.

    Returns:
        A dict with ``success``, ``model_version``, ``language``, ``hash``,
        a ``summary`` of element counts, and the ``tables``, ``images`` and
        ``texts`` lists.

    Raises:
        HTTPException: Status 400 for an invalid or empty upload, status 500
            when the analysis itself fails.
    """
    # NOTE(review): no route decorator is visible above this handler in this
    # view - confirm that it is registered on `app`.
    try:
        validate_image(file)
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="文件内容为空")

        # Short content hash: names the saved results and is echoed back.
        file_hash = hashlib.sha256(contents).hexdigest()[:12]

        # Decode the upload and normalise it to three-channel RGB.
        image = Image.open(io.BytesIO(contents))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # NOTE(review): PPStructure and save_structure_res look like the
        # PaddleOCR 2.x API, while the OCR code in this file uses the 3.x API
        # (`predict`, `device=`) - confirm that both are importable from the
        # installed version and that `device` is honoured here.
        from paddleocr import PPStructure

        # Use the plain language code for both the cache key and the engine.
        lang_code = getattr(lang, "value", lang)
        table_key = f"table_v3_{lang_code}_{use_gpu}"
        if table_key not in ocr_cache:
            ocr_cache[table_key] = PPStructure(
                table=True,
                lang=lang_code,
                device="gpu" if use_gpu else "cpu",
                show_log=True
            )
        table_engine = ocr_cache[table_key]

        result = table_engine(np.array(image))

        # Saving the result files is best effort; a failure is only reported.
        try:
            from paddleocr import save_structure_res
            save_structure_res(result, output_dir, file_hash)
        except Exception as save_error:
            print(f"保存结果文件失败: {save_error}")

        # Sort the detected elements into tables, figures and everything else.
        tables = []
        images = []
        texts = []
        for item in result:
            item_type = item.get('type', '')
            bbox = item.get('bbox', [])
            res = item.get('res', {})
            res_is_dict = isinstance(res, dict)

            if item_type == 'table':
                tables.append({
                    "type": item_type,
                    "bbox": bbox,
                    # Guarded so a non-dict `res` cannot break the response.
                    "html": res.get('html', '') if res_is_dict else '',
                    "confidence": res.get('confidence', 0.0) if res_is_dict else 0.0
                })
            elif item_type == 'figure':
                images.append({
                    "type": item_type,
                    "bbox": bbox
                })
            else:
                texts.append({
                    "type": item_type,
                    "bbox": bbox,
                    "text": res.get('text', '') if res_is_dict else str(res)
                })

        return {
            "success": True,
            "model_version": "PP-StructureV3",
            "language": lang_code,
            "hash": file_hash,
            "summary": {
                "total_elements": len(result),
                "tables": len(tables),
                "images": len(images),
                "texts": len(texts)
            },
            "tables": tables,
            "images": images,
            "texts": texts
        }
    except HTTPException:
        # Keep the 400 responses raised above instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"表格识别失败: {str(e)}") from e
async def health_check():
    """Report service status, model versions, cache size and languages."""
    # NOTE(review): no route decorator is visible above this handler in this
    # view - confirm that it is registered on `app`.
    languages = [member.value for member in LangEnum]
    return dict(
        status="healthy",
        ocr_version="PP-OCRv5",
        structure_version="PP-StructureV3",
        cache_instances=len(ocr_cache),  # engines built so far
        supported_languages=languages,
    )
async def get_model_info():
    """Return static descriptions of the models and features of the service."""
    # NOTE(review): no route decorator is visible above this handler in this
    # view - confirm that it is registered on `app`.
    ocr_models = {
        "PP-OCRv5_server_det": "高精度文本检测模型",
        "PP-OCRv5_server_rec": "高精度文本识别模型 - 支持中英日韩繁5种文字类型",
    }
    structure_models = {
        "PP-StructureV3": "通用文档解析方案 - 支持表格、图像、文本混合识别",
    }
    features = {
        "multi_language": "单模型支持5种文字类型",
        "handwriting": "显著提升手写体识别能力",
        "accuracy_improvement": "相比PP-OCRv4提升13个百分点",
    }
    return {
        "ocr_models": ocr_models,
        "structure_models": structure_models,
        "features": features,
    }
async def root():
    """Return a short banner describing the running service."""
    # NOTE(review): no route decorator is visible above this handler in this
    # view - confirm that it is registered on `app`, and on which path, since
    # the app serves its interactive docs at "/".
    return {
        "message": "PP-OCRv5 OCR API 服务正常运行",
        "version": "3.0",
        "models": "PP-OCRv5 + PP-StructureV3",
        # The app is created with docs_url='/', so the docs are served at the
        # site root; "/docs" is not a registered path.
        "docs": "/"
    }
# Expose the files in the output directory under the /output path.
app.mount(
    "/output",
    StaticFiles(directory=output_dir, follow_symlink=True, html=True),
    name="output",
)

# Start the development server only when the file is run as a script.
if __name__ == "__main__":
    uvicorn.run(app=app, host="0.0.0.0", port=7860)