Spaces:
Sleeping
Sleeping
| import requests | |
| import pandas as pd | |
| import os | |
| import time | |
| import gradio as gr | |
| import json | |
| import google.generativeai as genai | |
| from dotenv import load_dotenv | |
# Load GOOGLE_API_KEY from a local .env file (if present) and configure
# the Gemini SDK with it.  Note: no error is raised here when the key is
# missing; genai calls will fail later instead.
load_dotenv()
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
| ############## Photos ############## | |
def download_file(url, save_path, timeout=30):
    """Download the file at *url* and write it to *save_path*.

    Failures (non-200 status, network errors) are printed rather than
    raised, matching this script's best-effort, keep-going style.

    Args:
        url: HTTP(S) URL of the file to fetch.
        save_path: Local filesystem path where the content is written.
        timeout: Seconds to wait for the server before giving up.
            Prevents one unresponsive server from hanging the whole
            extraction pipeline (requests has no default timeout).
    """
    try:
        # Send a GET request to the URL; bounded by `timeout`.
        response = requests.get(url, timeout=timeout)
        # Check if the request was successful (status code 200)
        if response.status_code == 200:
            # Binary-write mode: the payload is an image, not text.
            with open(save_path, 'wb') as file:
                file.write(response.content)
        else:
            print(f"Failed to download image. Status code: {response.status_code}")
    except Exception as e:
        print(f"An error occurred: {e}")
def upload_file(photo_path):
    """Upload a local photo to the Gemini Files API and return its handle."""
    return genai.upload_file(photo_path)
| ###### Data extraction | |
| ## Helper function to initialize model | |
# USD price per single token for each supported model, split into prompt
# ('input') and completion ('output') rates: published per-1M-token prices
# divided by 1e6.  Used by call_llm_gemini to estimate request cost.
price_token={'gemini-1.5-pro-002': {'input': 1.25 / 1000000, 'output': 5 / 1000000}
}
# Disable every Gemini safety filter category (threshold BLOCK_NONE) so
# that product-label transcriptions are never blocked by content filters.
gemini_safety_settings = [
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HARASSMENT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_HATE_SPEECH",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
        "threshold": "BLOCK_NONE",
    },
]
def load_gemini_model(model_name):
    """Build a Gemini model instance plus its two generation configs.

    Returns a 3-tuple ``(model, text_config, json_config)``; the two
    configs are identical except for the response MIME type.
    """
    def make_config(mime_type):
        # Deterministic output: one candidate, temperature 0, capped length.
        return genai.types.GenerationConfig(
            candidate_count=1,
            max_output_tokens=4000,
            temperature=0,
            response_mime_type=mime_type,
        )

    text_config = make_config("text/plain")
    json_config = make_config("application/json")
    instance = genai.GenerativeModel(
        model_name,
        system_instruction=["You are a helpful assistant."],
        safety_settings=gemini_safety_settings,
    )
    return instance, text_config, json_config
| ##### Call LLM | |
def call_llm_gemini(model_instance, model, messages, generation_config):
    """Query Gemini and return the answer plus token counts and cost.

    Args:
        model_instance: A configured ``genai.GenerativeModel``.
        model: Model name, used as the key into ``price_token``.
        messages: List of prompt parts (uploaded file handles and/or text).
        generation_config: ``GenerationConfig`` controlling the response.

    Returns:
        Tuple ``(response_text, nb_input_tokens, nb_output_tokens, price)``
        where ``response_text`` is ``'Failed'`` when the response has no
        readable text (e.g. blocked or empty candidates).
    """
    response = model_instance.generate_content(messages,
                                               generation_config=generation_config)
    try:
        response_content = response.text.strip()
    except Exception:
        # Narrowed from a bare except: `.text` raises when the response
        # was blocked/empty; keep the pipeline running with a sentinel.
        response_content = 'Failed'
    nb_input_tokens = model_instance.count_tokens(messages).total_tokens
    nb_output_tokens = model_instance.count_tokens(response_content).total_tokens
    price = nb_input_tokens * price_token[model]['input'] + nb_output_tokens * price_token[model]['output']
    print(f"input tokens: {nb_input_tokens}; output tokens: {nb_output_tokens}, price: {price}")
    return response_content, nb_input_tokens, nb_output_tokens, price
| ##### Prompts | |
def get_prompt_brand(language):
    """Prompt asking only for the brand name (language-independent)."""
    return "What is the brand of this product? Answer with the brand name and nothing else."
def get_prompt_product_name(language):
    """Prompt asking for the product name, answered in *language* only."""
    return (f"What is the {language} product name of this product? "
            f"Answer in {language} with the product name and nothing else.")
def get_prompt_ingredients(language):
    """Detailed prompt for transcribing only the *language* ingredients list."""
    return f"""
You will be given an image of a product label or packaging. Your task is to extract the ingredients list from this image, focusing specifically on the {language} language version. Here's how to approach this task:
1. Analyze the provided image
2. Locate the ingredients list on the product label or packaging.
3. Identify the {language} language section of the ingredients list.
4. Extract only the {language} ingredients list. Ignore any ingredients lists in other languages, even if they are present in the image.
5. If there are multiple {language} ingredient lists (e.g., for different flavors or varieties), extract all of them and clearly separate them.
6. Do not include any additional information such as allergen warnings, nutritional information, or preparation instructions, even if they are in {language}.
7. If you cannot find a {language} ingredients list in the image, state that no {language} ingredients list was found.
8. If the image is unclear, state that the image quality is insufficient to extract the ingredients list accurately.
Provide your output in the following format:
<ingredients>
[Insert the extracted {language} ingredients list here, exactly as it appears in the image]
</ingredients>
Remember, include only the text of the {language} ingredients list, nothing else. Do not translate or interpret the ingredients; simply transcribe them as they appear in {language}.
"""
def get_prompt_nutritional_info():
    """Prompt requesting per-100g nutritional values as a strict JSON object."""
    return """Extract the following nutritional information from the product image and present it **only** in JSON format, providing only the values per 100g: Energy kJ, Energy kcal, Fat, Saturated fat, Carbohydrates, Sugars, Fibers, Proteins, Salt.
If you can't extract the nutritional information from the image, you need to say why it's the case.
The response should contain **only** the following JSON:
{
"Energy kJ": 1500,
"Energy kcal": 360,
"Fat": 18,
"Saturated fat": 7,
"Carbohydrates": 40,
"Sugars": 25,
"Fibers": 3,
"Proteins": 8,
"Salt": 0.5
}
No additional text or explanation should be included.
"""
| ##### Extract data functions | |
def extract_text_from_picture_baseline(OUTPUT_DIR,
                                       df_product_id,
                                       prompt,
                                       type_photo,
                                       generation_config,
                                       max_entry=None,
                                       progress=None
                                       ):
    """Run *prompt* against one photo per product and collect the answers.

    For each row of *df_product_id*: download the photo whose URL is in
    the *type_photo* column, upload it to Gemini, query the model, and
    record the response together with its estimated cost and wall-clock
    processing time.  Rows that error out are skipped (and reported).

    Args:
        OUTPUT_DIR: Base directory; photos are written under ``photos/``.
        df_product_id: DataFrame with an 'ID' column and a *type_photo*
            column holding each product's photo URL.
        prompt: Text prompt sent alongside each photo.
        type_photo: Name of the URL column (e.g. "Front photo").
        generation_config: Gemini generation config (text or JSON).
        max_entry: Process only the first *max_entry* rows (all if None).
        progress: Optional gradio Progress object for UI feedback.

    Returns:
        DataFrame with columns ['ID', 'Extracted_Text', 'Price',
        'Processing time'], one row per successfully processed product.

    NOTE: relies on the module-level ``gemini_model`` and ``model``
    globals initialised at the bottom of this file.
    """
    outputs = []
    if max_entry is None:
        max_entry = len(df_product_id)
    iterator = progress.tqdm(range(max_entry)) if progress is not None else range(max_entry)
    for i in iterator:
        start_time = time.time()
        product = df_product_id.loc[i]
        product_id = product['ID']
        photo_path = f'{OUTPUT_DIR}/photos/{product_id}_{type_photo}.jpg'
        download_file(url=product[type_photo], save_path=photo_path)
        photo = upload_file(photo_path)
        messages = [photo, prompt]
        try:
            response_content, _, _, price = call_llm_gemini(gemini_model, model, messages, generation_config)
            print(response_content)
            processing_time = time.time() - start_time
            output = [product_id, response_content, round(price, 4), round(processing_time, 2)]
            outputs.append(output)
        except Exception as e:
            # Narrowed from a bare except: still best-effort per row, but
            # report what actually went wrong instead of hiding it.
            print(f"Error for ID: {product_id} - {e}")
    df_output = pd.DataFrame(outputs, columns=['ID', 'Extracted_Text', 'Price', 'Processing time'])
    return df_output
def extract_brand(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract each product's brand from its front photo; save brand.csv."""
    result = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        get_prompt_brand(language),
        type_photo="Front photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    result.to_csv(f'{OUTPUT_DIR}/data_extraction/brand.csv', index=False)
    return result
def extract_product_name(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract each product's name (in *language*) from its front photo; save product_name.csv."""
    result = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        get_prompt_product_name(language),
        type_photo="Front photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    result.to_csv(f'{OUTPUT_DIR}/data_extraction/product_name.csv', index=False)
    return result
def extract_ingredients(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract the *language* ingredients list from each ingredients photo; save ingredients.csv."""
    result = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        get_prompt_ingredients(language),
        type_photo="Ingredients photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    result.to_csv(f'{OUTPUT_DIR}/data_extraction/ingredients.csv', index=False)
    return result
def convert_json_string_to_dict(json_string, record_id):
    """Parse the model's JSON answer, falling back to -1 placeholders.

    Args:
        json_string: Raw text returned by the model — ideally pure JSON,
            possibly wrapped in markdown code fences; may also be a
            non-string (e.g. NaN from pandas) or the 'Failed' sentinel.
        record_id: Product ID, used only in diagnostic messages.

    Returns:
        The parsed dict, or a dict mapping every expected nutrient key
        to -1 when the input is empty, not a string, unparsable, or not
        a JSON object.
    """
    default_keys = ['Energy kJ', 'Energy kcal', 'Fat', 'Saturated fat', 'Carbohydrates', 'Sugars', 'Fibers', 'Proteins',
                    'Salt']
    fallback = {key: -1 for key in default_keys}
    # Non-string input (e.g. NaN) or an empty/blank string: nothing to parse.
    if not isinstance(json_string, str) or not json_string.strip():
        print(f"ID: {record_id} - La chaîne est vide ou invalide : '{json_string}'")
        return fallback
    clean_string = json_string.strip()
    # Strip markdown code fences the model sometimes adds around JSON.
    if clean_string.startswith('```'):
        clean_string = clean_string.strip('`')
        if clean_string.startswith('json'):
            clean_string = clean_string[4:]
        clean_string = clean_string.strip()
    try:
        parsed = json.loads(clean_string)
    except json.JSONDecodeError:
        print(f"ID: {record_id} - Erreur lors du décodage du JSON : '{json_string}'")
        return fallback
    if not isinstance(parsed, dict):
        # A bare number/list would break downstream `.get(...)` calls.
        print(f"ID: {record_id} - Erreur lors du décodage du JSON : '{json_string}'")
        return fallback
    return parsed
def extract_nutritional_values(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract per-100g nutritional values from each nutritionals photo.

    Saves the raw model output to ``nutritional_values.csv``, then one CSV
    per nutrient key (e.g. ``energy_kj.csv``) with the parsed value.

    Args:
        OUTPUT_DIR: Base output directory (must contain ``data_extraction/``).
        df_product_id: DataFrame with 'ID' and 'Nutritionals photo' columns.
        language: Unused here (nutritional prompt is language-independent);
            kept for a signature parallel to the other extract_* functions.
        progress: gradio Progress object for UI feedback.

    Returns:
        DataFrame with columns ['ID', 'Extracted_Text', 'Price',
        'Processing time'].
    """
    df_output = extract_text_from_picture_baseline(OUTPUT_DIR, df_product_id,
                                                   get_prompt_nutritional_info(),
                                                   type_photo="Nutritionals photo",
                                                   generation_config=generation_config_json,
                                                   max_entry=None,
                                                   progress=progress)
    df_output.to_csv(f'{OUTPUT_DIR}/data_extraction/nutritional_values.csv', index=False)
    # Guard: if every row failed (empty frame), `.iloc[0]` below would
    # raise IndexError — just return the empty result.
    if df_output.empty:
        return df_output
    df_output['Extracted_Text_Json'] = df_output.apply(
        lambda row: convert_json_string_to_dict(row['Extracted_Text'], row['ID']), axis=1)
    # Use the keys of the first parsed dict as the reference schema.
    keys = list(df_output['Extracted_Text_Json'].iloc[0].keys())
    for key in keys:
        df_key = df_output[['ID', 'Price', 'Processing time']].copy()
        df_key['Extracted_Text'] = df_output['Extracted_Text_Json'].apply(lambda x: x.get(key, None))
        df_key.to_csv(f"{OUTPUT_DIR}/data_extraction/{key.replace(' ', '_').lower()}.csv", index=False)
    df_output = df_output[['ID', 'Extracted_Text', 'Price', 'Processing time']]
    return df_output
# Module-level model setup shared (as globals) by all extract_* functions
# above; runs at import time.
model = 'gemini-1.5-pro-002'
gemini_model, generation_config, generation_config_json = load_gemini_model(model)