# Spaces: Sleeping
import os | |
import sys | |
import subprocess | |
import streamlit as st | |
import io | |
import pypdfium2 | |
from PIL import Image, ExifTags | |
import logging | |
# Module-wide logger; getLogger(__name__) returns the same object on every call.
logger = logging.getLogger(__name__)
# Root logging configuration: only records at ERROR or above are emitted.
# Switch the level to logging.INFO for more verbose output while debugging.
logging.basicConfig(level=logging.ERROR)
def _jpeg_bytes(image, **save_kwargs):
    """Encode ``image`` as JPEG and return the encoded bytes.

    Modes other than L/RGB/CMYK (e.g. RGBA or P, which the JPEG encoder
    rejects) are converted to RGB for the encoding only; ``image`` itself is
    not modified. Extra keyword arguments (e.g. ``quality``) go to ``save``.
    """
    if image.mode not in ("L", "RGB", "CMYK"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG", **save_kwargs)
    return buffer.getvalue()


def resize_image_if_needed(pil_image, max_size_mb=1, max_edge_length=1024):
    """
    Shrink a PIL image when its JPEG encoding or its long edge is too large.

    If the long edge exceeds ``max_edge_length``, the image is downscaled
    with LANCZOS resampling so the long edge equals ``max_edge_length``
    (aspect ratio preserved). If the JPEG encoding is then still larger than
    ``max_size_mb``, the image is re-encoded at decreasing JPEG quality
    (70, 65, ... 15) until it fits. The decoded result of the last encoding
    tried is returned. This step is best effort: if even quality 15 is too
    large, that version is returned anyway.

    Args:
        pil_image (PIL.Image.Image): The input PIL image.
        max_size_mb (float): Maximum allowed JPEG-encoded size in megabytes.
        max_edge_length (int): Maximum allowed length of the long edge in pixels.

    Returns:
        PIL.Image.Image: ``pil_image`` itself when it is within both limits,
        otherwise a downscaled and/or JPEG-recompressed image. A recompressed
        image is decoded from JPEG, so an alpha channel is not preserved.
    """
    log = logging.getLogger(__name__)
    bytes_per_mb = 1024 * 1024

    encoded = _jpeg_bytes(pil_image)
    size_mb = len(encoded) / bytes_per_mb
    log.debug("Image size: %s MB", size_mb)

    width, height = pil_image.size
    if size_mb <= max_size_mb and max(width, height) <= max_edge_length:
        return pil_image

    if max(width, height) > max_edge_length:
        # Cap the long edge while keeping the aspect ratio. max(1, ...) keeps
        # the short edge from truncating to 0 for extremely elongated images,
        # which Image.resize() would reject.
        aspect_ratio = width / height
        if width > height:
            new_width = max_edge_length
            new_height = max(1, int(new_width / aspect_ratio))
        else:
            new_height = max_edge_length
            new_width = max(1, int(new_height * aspect_ratio))
        pil_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
        encoded = _jpeg_bytes(pil_image)
        size_mb = len(encoded) / bytes_per_mb

    if size_mb > max_size_mb:
        # Pillow's default JPEG quality (75) was just measured and is too
        # large, so step down from 70.
        for quality in range(70, 10, -5):
            encoded = _jpeg_bytes(pil_image, quality=quality)
            size_mb = len(encoded) / bytes_per_mb
            if size_mb <= max_size_mb:
                break
        # Decode the chosen encoding so the returned image actually carries
        # the reduced quality (measuring alone would change nothing).
        pil_image = Image.open(io.BytesIO(encoded))
        pil_image.load()
        log.debug("Recompressed image to %s MB", size_mb)

    return pil_image
def correct_image_orientation(pil_image):
    """
    Return ``pil_image`` turned upright according to its EXIF Orientation tag.

    Delegates to ``PIL.ImageOps.exif_transpose``. That function reads the tag
    through the public ``Image.getexif()`` API and handles all eight
    orientation values, including the mirrored ones (2, 4, 5, 7).
    ``getexif()`` exists on every ``Image``, unlike the private ``_getexif()``,
    which only format-specific classes such as the JPEG plugin define. It
    therefore also works on images produced by ``convert()``, which keep the
    source image's ``info`` metadata.

    Correction is best effort: if the metadata cannot be read, the input is
    returned unchanged.

    :param pil_image: input PIL Image object
    :return: the upright image (``exif_transpose`` returns a copy even when no
             transposition is needed), or ``pil_image`` itself on failure
    """
    from PIL import ImageOps  # file header only imports Image and ExifTags from PIL

    try:
        corrected = ImageOps.exif_transpose(pil_image)
    except Exception:
        # Malformed metadata must not block the upload; log it and move on.
        logging.getLogger(__name__).warning(
            "Could not apply EXIF orientation; using image as-is.", exc_info=True
        )
        return pil_image
    return pil_image if corrected is None else corrected
def clone_repo(repo_slug="mamba-ai/invoice_agent", repo_dir="invoice_agent"):
    """
    Make the private agent repository importable, cloning it if necessary.

    If ``repo_dir`` does not exist, the function runs ``git clone`` on
    ``https://github.com/<repo_slug>.git`` into ``repo_dir``. It authenticates
    with the token from the ``GH_TOKEN`` environment variable. Whether the
    repository was already present or freshly cloned, the absolute path of
    ``repo_dir`` is appended to ``sys.path`` once, so modules inside the
    repository can import each other.

    Args:
        repo_slug (str): ``owner/name`` of the GitHub repository.
        repo_dir (str): Local directory to clone into / look for.

    Returns:
        bool: True when the repository is available.
        False when it has to be cloned but the token is missing or empty,
        git cannot be run, or the clone fails.
    """
    # Same object as the module-level logger (getLogger caches per name).
    log = logging.getLogger(__name__)

    if os.path.exists(repo_dir):
        log.info("Repository already exists.")
    else:
        github_token = os.getenv("GH_TOKEN")
        if not github_token:
            log.error("GitHub token is not set. Please set the GH_TOKEN secret in your Space settings.")
            return False
        # NOTE(review): git records this URL, token included, as the clone's
        # `origin` remote in .git/config; consider resetting it after cloning.
        clone_url = f"https://{github_token}@github.com/{repo_slug}.git"
        log.info("Cloning repository...")
        try:
            # Argument list (no shell): the token is never parsed by a shell.
            result = subprocess.run(
                ["git", "clone", clone_url, repo_dir],
                capture_output=True,
                text=True,
                check=False,
            )
        except OSError as exc:  # e.g. the git executable is not installed
            log.error("Failed to run git clone: %s", exc)
            return False
        if result.returncode != 0:
            # git may echo the remote URL (which embeds the token) on failure.
            stderr = result.stderr.replace(github_token, "***")
            log.error("Failed to clone repository: %s", stderr)
            return False
        log.info("Repository cloned successfully.")

    repo_path = os.path.abspath(repo_dir)
    if repo_path not in sys.path:
        log.info("Adding %s to sys.path", repo_path)
        sys.path.append(repo_path)
    return True
# The recognition agent lives in a private repository fetched at start-up.
# `ia` stays None when the clone or the import fails, so the UI can report the
# problem instead of raising NameError when recognition is requested.
ia = None
if clone_repo():
    try:
        import invoice_agent.agent as ia
    except ImportError:
        logger.exception("Failed to import invoice_agent.agent")
else:
    logger.error("invoice_agent repository is unavailable; recognition is disabled.")
def open_pdf(pdf_file):
    """
    Load an uploaded PDF into a ``pypdfium2.PdfDocument``.

    ``pdf_file`` must provide ``getvalue()`` returning the raw PDF bytes. The
    bytes are wrapped in an in-memory stream for pypdfium2. The caller owns
    the returned document.
    """
    raw_pdf = pdf_file.getvalue()
    return pypdfium2.PdfDocument(io.BytesIO(raw_pdf))
def get_page_image(pdf_file, page_num, dpi=96):
    """
    Render a single page of an uploaded PDF to an RGB PIL image.

    Renders just the requested page in-process, rather than going through the
    multi-page ``PdfDocument.render`` helper. The document is closed afterwards,
    so repeated Streamlit reruns do not leak pdfium handles.

    Args:
        pdf_file: Uploaded file object accepted by ``open_pdf``.
        page_num (int): 1-based page number.
        dpi (float): Render resolution. PDF user space has 72 units per inch,
            hence the ``dpi / 72`` scale factor.

    Returns:
        PIL.Image.Image: The rendered page in RGB mode.
    """
    doc = open_pdf(pdf_file)
    try:
        page = doc[page_num - 1]
        bitmap = page.render(scale=dpi / 72)
        # to_pil() may share memory with the pdfium bitmap; convert("RGB")
        # always yields an independent image before the document is closed.
        return bitmap.to_pil().convert("RGB")
    finally:
        doc.close()
def page_count(pdf_file):
    """
    Return the number of pages in an uploaded PDF.

    The temporary ``PdfDocument`` is closed before returning, so no pdfium
    handle is left open between Streamlit reruns.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        doc.close()
st.set_page_config(layout="wide")
st.title("""
受領した請求書を自動で電子化 (Demo)
""")
col1, _, col2 = st.columns([.45, 0.1, .45])

in_file = st.sidebar.file_uploader(
    "PDFファイルまたは画像:",
    type=["pdf", "png", "jpg", "jpeg", "gif", "webp"],
)
if in_file is None:
    st.stop()

if "pdf" in in_file.type:
    # Separate name so the page_count() helper is not shadowed by an int.
    n_pages = page_count(in_file)
    page_number = st.sidebar.number_input(
        f"ページ番号 {n_pages}:", min_value=1, value=1, max_value=n_pages
    )
    pil_image = get_page_image(in_file, page_number)
else:
    # Fix orientation on the freshly opened image, before convert(), while
    # the format-specific EXIF accessors are still available.
    pil_image = correct_image_orientation(Image.open(in_file)).convert("RGB")
pil_image = resize_image_if_needed(pil_image)

text_rec = st.sidebar.button("認識開始")

with col1:
    st.write("## アップロードされたファイル")
    st.image(pil_image, caption="アップロードされたファイル", use_column_width=True)

if text_rec:
    with col2:
        st.write("## 結果")
        # Placeholder for the status message shown once recognition finishes.
        status_placeholder = st.empty()
        # `ia` is bound only when the invoice_agent import succeeded at start-up.
        agent = globals().get("ia")
        if agent is None:
            st.error("認識モジュールを読み込めませんでした。")
            st.stop()
        with st.spinner('現在ファイルを解析中です'):
            # Run the recognition model on the (oriented, resized) image.
            json_predictions = agent.get_json_result_v2(pil_image, None)
            # NOTE(review): logged at ERROR, presumably so it passes the ERROR
            # threshold configured by basicConfig; consider a dedicated level.
            logger.error(json_predictions)
            st.session_state.json_predictions = json_predictions
        status_placeholder.success('ファイルの解析が完了しました!')
        # TODO: Excel export (ia.json_to_excel_with_links + st.download_button)
        # is currently disabled; re-add it here when needed.
        st.write("解析後の内容:")
        st.json(json_predictions)