Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import subprocess | |
| import streamlit as st | |
| import io | |
| import pypdfium2 | |
| from PIL import Image, ExifTags | |
| import logging | |
# Logging setup: configure the root logger at ERROR level (switch the level
# to "INFO" for verbose local debugging) and create this module's logger.
logging.basicConfig(level="ERROR")
logger = logging.getLogger(__name__)
def resize_image_if_needed(pil_image, max_size_mb=1, max_edge_length=1024):
    """
    Shrink a PIL image whose JPEG encoding exceeds ``max_size_mb`` or whose
    long edge exceeds ``max_edge_length`` pixels.

    First the image is downscaled (keeping its aspect ratio) so that the long
    edge is at most ``max_edge_length``. If its JPEG encoding is still larger
    than ``max_size_mb``, the JPEG quality is lowered in steps of 5 (down to
    15). The returned image is then decoded from that recompressed JPEG, so
    the reduction actually reaches the caller.

    Args:
        pil_image (PIL.Image.Image): The input image. It must be in a mode JPEG
            can store (e.g. "RGB" or "L"); the callers in this app pass RGB.
        max_size_mb (float): Maximum allowed JPEG-encoded size in megabytes.
        max_edge_length (int): Maximum allowed length of the long edge in pixels.

    Returns:
        PIL.Image.Image: The input image unchanged if no reduction was needed,
        otherwise the reduced image.
    """
    default_quality = 75  # Pillow's JPEG quality when none is specified
    min_quality = 15

    def _jpeg_bytes(img, quality=default_quality):
        # Encode to an in-memory JPEG; only the byte count/content is needed.
        buffer = io.BytesIO()
        img.save(buffer, format='JPEG', quality=quality)
        return buffer.getvalue()

    def _size_mb(data):
        return len(data) / (1024 * 1024)

    encoded = _jpeg_bytes(pil_image)
    logger.info("Image size: %s MB", _size_mb(encoded))

    if _size_mb(encoded) <= max_size_mb and max(pil_image.size) <= max_edge_length:
        return pil_image

    width, height = pil_image.size
    if max(width, height) > max_edge_length:
        # Integer arithmetic avoids float rounding errors; max(1, ...) keeps the
        # short edge non-zero for extreme aspect ratios.
        if width >= height:
            new_size = (max_edge_length, max(1, height * max_edge_length // width))
        else:
            new_size = (max(1, width * max_edge_length // height), max_edge_length)
        pil_image = pil_image.resize(new_size, Image.LANCZOS)
        encoded = _jpeg_bytes(pil_image)

    # Still too large: lower the JPEG quality below the default that was just
    # measured, step by step, until the encoding fits or the floor is reached.
    quality = default_quality
    while _size_mb(encoded) > max_size_mb and quality > min_quality:
        quality -= 5
        encoded = _jpeg_bytes(pil_image, quality)

    if quality != default_quality:
        # Return the recompressed pixels; otherwise the size reduction would be
        # discarded together with the temporary buffer.
        pil_image = Image.open(io.BytesIO(encoded))
        pil_image.load()

    return pil_image
def correct_image_orientation(pil_image):
    """
    Rotate/flip a PIL image so it appears upright according to its EXIF
    Orientation tag.

    Delegates to ``PIL.ImageOps.exif_transpose``, which reads EXIF through the
    public ``getexif()`` API. That API also works on images produced by
    ``convert()``/``copy()``, which keep the EXIF bytes in ``info`` but lack
    the JPEG-only private ``_getexif`` method. ``exif_transpose`` handles all
    eight orientation values, including the mirrored ones.

    :param pil_image: input PIL Image object
    :return: the orientation-corrected image, or the input unchanged if the
        correction could not be applied
    """
    from PIL import ImageOps  # local import: the module only imports Image/ExifTags

    try:
        return ImageOps.exif_transpose(pil_image)
    except Exception as exc:
        # Best effort: malformed EXIF data in an uploaded file must not break
        # the app, so keep the image as-is and record why.
        logger.warning("Could not apply EXIF orientation: %s", exc)
        return pil_image
def clone_repo():
    """
    Make the private ``mamba-ai/invoice_agent`` repository importable.

    Clones the repository into ``./invoice_agent`` unless that directory
    already exists. Authentication uses the GitHub token from the ``GH_TOKEN``
    environment variable. In both cases, the checkout directory is then added
    to ``sys.path`` once.

    Returns:
        bool: True if the repository is available, False if no token is set
        or the clone failed.
    """
    # GitHub token from the environment (a Space secret in deployment).
    github_token = os.getenv('GH_TOKEN')
    if not github_token:
        logger.error("GitHub token is not set. Please set the GH_TOKEN secret in your Space settings.")
        return False

    repo_dir = 'invoice_agent'
    repo_path = os.path.abspath(repo_dir)
    public_url = 'https://github.com/mamba-ai/invoice_agent.git'

    if os.path.exists(repo_dir):
        logger.warning("Repository already exists.")
    else:
        logger.info("Cloning repository...")
        auth_url = f'https://{github_token}@github.com/mamba-ai/invoice_agent.git'
        try:
            # Argument list without a shell: the token is never parsed by a shell.
            result = subprocess.run(
                ['git', 'clone', auth_url, repo_dir],
                capture_output=True,
                text=True,
            )
        except OSError as exc:  # e.g. the git executable is not installed
            logger.error("Failed to clone repository: %s", exc)
            return False
        if result.returncode != 0:
            # Redact the token in case git echoes the authenticated URL.
            logger.error(
                "Failed to clone repository: %s",
                result.stderr.replace(github_token, '***'),
            )
            return False
        logger.warning("Repository cloned successfully.")
        # git stores the clone URL (token included) as the origin remote in
        # .git/config; replace it with the token-free URL. Best effort only.
        scrub = subprocess.run(
            ['git', '-C', repo_dir, 'remote', 'set-url', 'origin', public_url],
            capture_output=True,
            text=True,
        )
        if scrub.returncode != 0:
            logger.warning(
                "Could not remove the token from the remote URL: %s",
                scrub.stderr.replace(github_token, '***'),
            )

    # Put the checkout directory on the module search path in both branches.
    # The membership check prevents duplicate entries across Streamlit reruns.
    if repo_path not in sys.path:
        sys.path.append(repo_path)
        logger.warning("Adding %s to sys.path", repo_path)
    return True
if clone_repo():
    # Import only once the repository is available on disk.
    import invoice_agent.agent as ia
else:
    # Without the agent package the app cannot work. Fail loudly here instead
    # of hitting a confusing NameError on `ia` further down the script.
    raise RuntimeError(
        "Could not obtain the invoice_agent repository; "
        "check the GH_TOKEN secret and the logs."
    )
def open_pdf(pdf_file):
    """Open an uploaded PDF as a ``pypdfium2.PdfDocument``.

    ``pdf_file`` is any object exposing ``getvalue()`` (e.g. a Streamlit
    upload). Its full contents are wrapped in a fresh in-memory buffer before
    being handed to pdfium.
    """
    pdf_bytes = pdf_file.getvalue()
    buffer = io.BytesIO(pdf_bytes)
    return pypdfium2.PdfDocument(buffer)
def get_page_image(pdf_file, page_num, dpi=96):
    """Render a single page of an uploaded PDF to an RGB PIL image.

    Args:
        pdf_file: Uploaded PDF object exposing ``getvalue()`` (see ``open_pdf``).
        page_num (int): 1-based page number.
        dpi (int | float): Rendering resolution. PDF user space has 72 units
            per inch, hence the ``dpi / 72`` scale factor.

    Returns:
        PIL.Image.Image: The rendered page converted to RGB.
    """
    doc = open_pdf(pdf_file)
    try:
        renderer = doc.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[page_num - 1],  # pypdfium2 page indices are 0-based
            scale=dpi / 72,
        )
        # Exactly one page index was requested, so exactly one image comes back.
        page_image = list(renderer)[0]
        # convert() returns a new image, so the result stays valid after close().
        return page_image.convert("RGB")
    finally:
        # Release the pdfium document handle now instead of waiting for GC.
        doc.close()
def page_count(pdf_file):
    """Return the number of pages in an uploaded PDF (see ``open_pdf``)."""
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the pdfium document handle now instead of waiting for GC.
        doc.close()
st.set_page_config(layout="wide")


@st.cache_resource
def _load_models():
    """Load the agent models once per server process instead of on every rerun."""
    return ia.load_models()


models = _load_models()
st.title("""
受領した請求書を自動で電子化 (Demo)
""")
col1, _, col2 = st.columns([.45, 0.1, .45])
in_file = st.sidebar.file_uploader(
    "PDFファイルまたは画像:",
    type=["pdf", "png", "jpg", "jpeg", "gif", "webp"],
)
if in_file is None:
    st.stop()
filetype = in_file.type
if "pdf" in filetype:
    # Use a distinct name so the page_count() helper is not shadowed.
    num_pages = page_count(in_file)
    page_number = st.sidebar.number_input(f"ページ番号 {num_pages}:", min_value=1, value=1, max_value=num_pages)
    pil_image = get_page_image(in_file, page_number)
else:
    pil_image = Image.open(in_file).convert("RGB")
# Normalize orientation first, then cap the image size for display and analysis.
pil_image = correct_image_orientation(pil_image)
pil_image = resize_image_if_needed(pil_image)
text_rec = st.sidebar.button("認識開始")
if pil_image is None:
    st.stop()
with col1:
    st.write("## アップロードされたファイル")
    st.image(pil_image, caption="アップロードされたファイル", use_column_width=True)

# TODO: Excel export (ia.json_to_excel_with_links + st.download_button) and
# re-showing the previous result from st.session_state are currently disabled;
# restore them from version control when needed.

if text_rec:
    with col2:
        st.write("## 結果")
        # Placeholder for the status message shown once analysis finishes.
        status_placeholder = st.empty()
        try:
            with st.spinner('現在ファイルを解析中です'):
                json_predictions = ia.get_json_result_v2(pil_image, models)
        except Exception:
            # Top-level UI boundary: keep the traceback in the logs and show the
            # user a readable message instead of a raw stack trace.
            logger.exception("Invoice analysis failed")
            status_placeholder.error('ファイルの解析に失敗しました。')
            st.stop()
        # NOTE(review): logged at ERROR only so it is visible under the
        # ERROR-level logging configuration; this writes invoice contents to logs.
        logger.error(json_predictions)
        st.session_state.json_predictions = json_predictions
        status_placeholder.success('ファイルの解析が完了しました!')
        # Display the result
        st.write("解析後の内容:")
        st.json(json_predictions)