"""Extract the underlying data table from a chart image with the DePlot model.

Running this module as a script analyzes one image, prints the recovered table,
and also writes it to a timestamped ``results_*.log`` file in the current
working directory.
"""

import os
import sys
import logging
from datetime import datetime

import torch
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from PIL import Image

# Configure logging: records go both to stdout and to app.log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('app.log', encoding='utf-8'),
    ],
)
logger = logging.getLogger(__name__)

# Prompt used when process_image() is called without an explicit prompt.
DEFAULT_PROMPT = "Generate underlying data table of the figure below:"
# Sample image analyzed by main() when no path is given on the command line
# (e.g. the uploaded image when running inside a Space).
DEFAULT_IMAGE_PATH = '05e57f1c9acff69f1eb6fa72d4805d0.jpg'


class ChartAnalyzer:
    """Turn chart images into data tables using a Pix2Struct (DePlot) model."""

    def __init__(self, model_name="google/deplot", device=None):
        """Load the model and processor.

        Args:
            model_name: Hugging Face model id or local path passed to
                ``from_pretrained`` for both the model and the processor.
            device: Optional torch device (e.g. ``"cuda"``) to move the model
                to. ``None`` keeps the model wherever ``from_pretrained`` put it.

        Raises:
            Exception: Any error raised while loading is logged and re-raised.
        """
        try:
            logger.info("Initializing model and processor...")
            self.model = Pix2StructForConditionalGeneration.from_pretrained(model_name)
            self.processor = AutoProcessor.from_pretrained(model_name)
            if device is not None:
                self.model.to(device)
            logger.info("Model and processor initialized successfully")
        except Exception as e:
            logger.error("Error initializing model: %s", e)
            raise

    @staticmethod
    def _parse_table(raw_output):
        """Split the model's linearized table text into rows of stripped cells."""
        # The decoded output marks row breaks with the literal text "<0x0A>"
        # (the newline byte); real line breaks are accepted too so that no
        # newline ever ends up embedded inside a cell.
        text = raw_output.replace("<0x0A>", "\n")
        return [
            [cell.strip() for cell in line.split("|")]
            for line in text.splitlines()
            if line.strip()  # skip blank rows
        ]

    @staticmethod
    def _save_results(rows):
        """Write rows as ``a | b | c`` lines to a timestamped file; return its name."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f'results_{timestamp}.log'
        with open(output_file, mode='w', encoding='utf-8') as file:
            file.writelines(" | ".join(row) + "\n" for row in rows)
        logger.info("Results saved to %s", output_file)
        return output_file

    def process_image(self, image_path, prompt=None, *, max_new_tokens=512, num_beams=4):
        """Generate the data table underlying a chart image.

        Args:
            image_path: Path to the chart image file.
            prompt: Text prompt for the model; defaults to DEFAULT_PROMPT.
            max_new_tokens: Upper bound on generated tokens.
            num_beams: Beam-search width used during generation.

        Returns:
            A list of rows, each a list of cell strings. The same rows are also
            written to a ``results_<timestamp>.log`` file.

        Raises:
            FileNotFoundError: If ``image_path`` is not an existing regular file.
            Exception: Any other processing error is logged and re-raised.
        """
        try:
            if not os.path.isfile(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            logger.info("Processing image: %s", image_path)
            if prompt is None:
                prompt = DEFAULT_PROMPT

            # Use a context manager so the image's file handle is released once
            # the processor has converted the pixels into tensors.
            with Image.open(image_path) as image:
                inputs = self.processor(
                    images=image,
                    text=prompt,
                    return_tensors="pt",
                )
            # Keep inputs on the same device as the model (a no-op on CPU).
            inputs = inputs.to(self.model.device)

            logger.info("Generating predictions...")
            # Inference only: disabling autograd saves memory and time.
            with torch.no_grad():
                predictions = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    num_beams=num_beams,
                    length_penalty=1.0,
                )

            raw_output = self.processor.decode(predictions[0], skip_special_tokens=True)
            result_array = self._parse_table(raw_output)
            self._save_results(result_array)
            return result_array
        except Exception as e:
            logger.error("Error processing image: %s", e)
            raise


def main(argv=None):
    """Analyze one chart image and print the recovered table.

    Args:
        argv: Optional argument list (without the program name). When ``None``,
            ``sys.argv[1:]`` is used. The first argument, if present, is the
            image path; otherwise DEFAULT_IMAGE_PATH is analyzed.
    """
    args = sys.argv[1:] if argv is None else argv
    image_path = args[0] if args else DEFAULT_IMAGE_PATH
    try:
        analyzer = ChartAnalyzer()
        results = analyzer.process_image(image_path)

        print("\nAnalysis Results:")
        for row in results:
            print(" | ".join(row))
    except Exception as e:
        # logger.exception records the traceback in app.log as well.
        logger.exception("Application error: %s", e)
        raise


if __name__ == "__main__":
    main()