Spaces:
Sleeping
Sleeping
import requests | |
import pandas as pd | |
import os | |
import time | |
import gradio as gr | |
import json | |
import google.generativeai as genai | |
from dotenv import load_dotenv | |
# Populate the environment from a local .env file (if present), read the
# Gemini API key from it and hand that key to the SDK.
load_dotenv()
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)
############## Photos ############## | |
def download_file(url, save_path, timeout=30):
    """Download ``url`` and write the response body to ``save_path``.

    Failures (network error, non-200 status, unwritable destination) are
    reported with ``print`` instead of being raised, so a batch caller can
    keep going; the return value tells it whether the file now exists.

    Args:
        url: address of the file to fetch.
        save_path: local destination; missing parent directories are created.
        timeout: seconds ``requests`` may wait on the server before giving
            up. Without a timeout a stalled server blocks the call forever.

    Returns:
        True if the file was written, False otherwise.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as e:
        # Covers connection errors, timeouts and malformed URLs (e.g. a
        # missing photo that arrives as NaN/None from the DataFrame).
        print(f"An error occurred: {e}")
        return False

    if response.status_code != 200:
        print(f"Failed to download image. Status code: {response.status_code}")
        return False

    try:
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'wb') as file:
            file.write(response.content)
    except OSError as e:
        print(f"An error occurred: {e}")
        return False
    return True
def upload_file(photo_path):
    """Upload the local file at ``photo_path`` with ``genai.upload_file``.

    Returns the file handle produced by the SDK, which can be placed directly
    in a ``generate_content`` message list.
    """
    return genai.upload_file(photo_path)
###### Data extraction | |
## Model configuration: per-token prices, safety settings and the model loader
# Per-token prices keyed by model name (1.25 per million input tokens and
# 5 per million output tokens; presumably USD — confirm against current
# Gemini pricing). call_llm_gemini looks prices up here by model name.
price_token = {
    'gemini-1.5-pro-002': {'input': 1.25 / 1000000, 'output': 5 / 1000000},
}

# Turn off blocking for every harm category that Gemini models support.
# The legacy PaLM-era category "HARM_CATEGORY_DANGEROUS" is intentionally
# not listed: Gemini does not support it, and its Gemini counterpart
# "HARM_CATEGORY_DANGEROUS_CONTENT" is already covered below.
_GEMINI_HARM_CATEGORIES = (
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
)
gemini_safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in _GEMINI_HARM_CATEGORIES
]
def load_gemini_model(model_name):
    """Create a Gemini model plus its text and JSON generation configs.

    Both configs ask for a single candidate, at most 4000 output tokens and
    temperature 0. They differ only in ``response_mime_type``:
    ``"text/plain"`` vs ``"application/json"``. The model is created with a
    generic assistant system instruction and the module's
    ``gemini_safety_settings``.

    Returns:
        (gemini_model, generation_config, generation_config_json)
    """
    shared_options = dict(candidate_count=1, max_output_tokens=4000, temperature=0)

    def make_config(mime_type):
        # One candidate only for now; the output format is the only variable.
        return genai.types.GenerationConfig(response_mime_type=mime_type, **shared_options)

    text_config = make_config("text/plain")
    json_config = make_config("application/json")

    gemini_model = genai.GenerativeModel(
        model_name,
        system_instruction=["You are a helpful assistant."],
        safety_settings=gemini_safety_settings,
    )
    return gemini_model, text_config, json_config
##### Call LLM | |
def call_llm_gemini(model_instance, model, messages, generation_config):
    """Send ``messages`` to a Gemini model and report tokens used and cost.

    Args:
        model_instance: model object exposing ``generate_content`` and
            ``count_tokens`` (as built by ``load_gemini_model``).
        model: model name, used as the key into ``price_token``.
        messages: content for ``generate_content``, e.g. ``[photo, prompt]``.
        generation_config: config forwarded to ``generate_content``.

    Returns:
        ``(response_content, nb_input_tokens, nb_output_tokens, price)``.
        ``response_content`` is the stripped response text, or the sentinel
        string ``'Failed'`` when no text could be read from the response.

    Raises:
        KeyError: if ``model`` has no entry in ``price_token``.
        Errors raised by ``generate_content`` itself propagate unchanged.
    """
    response = model_instance.generate_content(messages, generation_config=generation_config)

    try:
        response_content = response.text.strip()
        has_text = True
    except Exception:
        # The SDK's ``text`` accessor raises when the response has no usable
        # text part (presumably e.g. a blocked answer) — keep the sentinel.
        response_content = 'Failed'
        has_text = False

    # Prefer the usage numbers the API already returned with the response:
    # they are the billed counts and avoid two extra count_tokens requests.
    usage = getattr(response, 'usage_metadata', None)
    if usage is not None and getattr(usage, 'prompt_token_count', 0):
        nb_input_tokens = usage.prompt_token_count
        nb_output_tokens = usage.candidates_token_count
    else:
        nb_input_tokens = model_instance.count_tokens(messages).total_tokens
        # Do not bill the 'Failed' sentinel as if the model had produced it.
        nb_output_tokens = (
            model_instance.count_tokens(response_content).total_tokens if has_text else 0
        )

    price = nb_input_tokens * price_token[model]['input'] + nb_output_tokens * price_token[model]['output']
    print(f"input tokens: {nb_input_tokens}; output tokens: {nb_output_tokens}, price: {price}")
    return response_content, nb_input_tokens, nb_output_tokens, price
##### Prompts | |
def get_prompt_brand(language):
    """Return the prompt asking for the product's brand name only.

    ``language`` is accepted so the signature matches the other prompt
    builders, but this prompt does not use it.
    """
    return ("What is the brand of this product? "
            "Answer with the brand name and nothing else.")
def get_prompt_product_name(language):
    """Return a prompt asking for the product name, answered in ``language``."""
    question = "What is the {0} product name of this product? ".format(language)
    instruction = "Answer in {0} with the product name and nothing else.".format(language)
    return question + instruction
def get_prompt_ingredients(language):
    """Return the prompt that extracts the ``language`` ingredients list.

    The prompt asks the model to transcribe only the ingredients section
    written in ``language``, without translating it, and to wrap the answer
    in ``<ingredients>`` tags. The returned text starts and ends with a
    newline.
    """
    lang = language
    lines = [
        "",
        f"You will be given an image of a product label or packaging. Your task is to extract the ingredients list from this image, focusing specifically on the {lang} language version. Here's how to approach this task:",
        "1. Analyze the provided image",
        "2. Locate the ingredients list on the product label or packaging.",
        f"3. Identify the {lang} language section of the ingredients list.",
        f"4. Extract only the {lang} ingredients list. Ignore any ingredients lists in other languages, even if they are present in the image.",
        f"5. If there are multiple {lang} ingredient lists (e.g., for different flavors or varieties), extract all of them and clearly separate them.",
        f"6. Do not include any additional information such as allergen warnings, nutritional information, or preparation instructions, even if they are in {lang}.",
        f"7. If you cannot find a {lang} ingredients list in the image, state that no {lang} ingredients list was found.",
        "8. If the image is unclear, state that the image quality is insufficient to extract the ingredients list accurately.",
        "Provide your output in the following format:",
        "<ingredients>",
        f"[Insert the extracted {lang} ingredients list here, exactly as it appears in the image]",
        "</ingredients>",
        f"Remember, include only the text of the {lang} ingredients list, nothing else. Do not translate or interpret the ingredients; simply transcribe them as they appear in {lang}.",
        "",
    ]
    return "\n".join(lines)
def get_prompt_nutritional_info():
    """Return the prompt asking for per-100g nutritional values as JSON.

    The prompt includes an example JSON object with nine keys (Energy kJ,
    Energy kcal, Fat, Saturated fat, Carbohydrates, Sugars, Fibers,
    Proteins, Salt). These are the same nine keys that
    ``convert_json_string_to_dict`` uses as its fallback.

    NOTE(review): the prompt asks the model to explain why when it cannot
    extract values, yet also demands JSON only. Under the JSON response
    MIME type such an explanation may not parse as the expected object.
    Confirm that this is intended.
    """
    lines = [
        "Extract the following nutritional information from the product image and present it **only** in JSON format, providing only the values per 100g: Energy kJ, Energy kcal, Fat, Saturated fat, Carbohydrates, Sugars, Fibers, Proteins, Salt.",
        "If you can't extract the nutritional information from the image, you need to say why it's the case.",
        "The response should contain **only** the following JSON:",
        "{",
        '"Energy kJ": 1500,',
        '"Energy kcal": 360,',
        '"Fat": 18,',
        '"Saturated fat": 7,',
        '"Carbohydrates": 40,',
        '"Sugars": 25,',
        '"Fibers": 3,',
        '"Proteins": 8,',
        '"Salt": 0.5',
        "}",
        "No additional text or explanation should be included.",
    ]
    return "\n".join(lines) + "\n"
##### Extract data functions | |
def extract_text_from_picture_baseline(OUTPUT_DIR,
                                       df_product_id,
                                       prompt,
                                       type_photo,
                                       generation_config,
                                       max_entry=None,
                                       progress=None
                                       ):
    """Run ``prompt`` against one photo per product and collect the answers.

    For each of the first ``max_entry`` rows, the function:

    1. downloads the URL in column ``type_photo`` to
       ``{OUTPUT_DIR}/photos/{ID}_{type_photo}.jpg``;
    2. uploads that file;
    3. sends ``[photo, prompt]`` to the module-level ``gemini_model``, priced
       under the module-level ``model`` name.

    Args:
        OUTPUT_DIR: base output directory.
        df_product_id: DataFrame with an ``ID`` column and a ``type_photo``
            column of photo URLs. Rows are taken by position, so any index
            labels work.
        prompt: text sent together with the photo.
        type_photo: name of the photo-URL column; also part of the saved
            file name.
        generation_config: generation config passed to the model.
        max_entry: maximum number of rows to process (all when None);
            values beyond the number of rows are clamped.
        progress: optional progress tracker whose ``tqdm`` wraps the loop.

    Returns:
        DataFrame with columns ``ID``, ``Extracted_Text``, ``Price``,
        ``Processing time``. Products whose processing raised an error are
        reported with ``print`` and left out.
    """
    columns = ['ID', 'Extracted_Text', 'Price', 'Processing time']
    n_rows = len(df_product_id)
    n_entries = n_rows if max_entry is None else max(0, min(max_entry, n_rows))

    indices = range(n_entries)
    if progress is not None:
        indices = progress.tqdm(indices)

    outputs = []
    for i in indices:
        start_time = time.time()
        # Positional access: .loc[i] would assume a 0..n-1 RangeIndex.
        product = df_product_id.iloc[i]
        product_id = product['ID']
        photo_path = f'{OUTPUT_DIR}/photos/{product_id}_{type_photo}.jpg'
        try:
            # Download and upload are inside the try as well, so one missing
            # or broken photo cannot abort the whole batch.
            download_file(url=product[type_photo], save_path=photo_path)
            photo = upload_file(photo_path)
            response_content, _, _, price = call_llm_gemini(
                gemini_model, model, [photo, prompt], generation_config)
        except Exception as e:
            print(f"Error for ID: {product_id}: {e}")
            continue
        print(response_content)
        processing_time = time.time() - start_time
        outputs.append([product_id, response_content, round(price, 4), round(processing_time, 2)])

    return pd.DataFrame(outputs, columns=columns)
def extract_brand(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract each product's brand from its front photo.

    The results are written to ``{OUTPUT_DIR}/data_extraction/brand.csv``
    and also returned. The ``gr.Progress()`` default is the Gradio
    convention for letting the UI track progress.
    """
    brand_prompt = get_prompt_brand(language)
    results = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        brand_prompt,
        type_photo="Front photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    csv_path = f'{OUTPUT_DIR}/data_extraction/brand.csv'
    results.to_csv(csv_path, index=False)
    return results
def extract_product_name(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract each product's name, in ``language``, from its front photo.

    The results are saved to ``{OUTPUT_DIR}/data_extraction/product_name.csv``
    and also returned.
    """
    name_prompt = get_prompt_product_name(language)
    extracted = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        name_prompt,
        type_photo="Front photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    target = f'{OUTPUT_DIR}/data_extraction/product_name.csv'
    extracted.to_csv(target, index=False)
    return extracted
def extract_ingredients(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Transcribe each product's ``language`` ingredients list.

    Reads the "Ingredients photo" column, saves the results to
    ``{OUTPUT_DIR}/data_extraction/ingredients.csv`` and returns them.
    """
    ingredients_prompt = get_prompt_ingredients(language)
    df_ingredients = extract_text_from_picture_baseline(
        OUTPUT_DIR,
        df_product_id,
        ingredients_prompt,
        type_photo="Ingredients photo",
        generation_config=generation_config,
        max_entry=None,
        progress=progress,
    )
    df_ingredients.to_csv(f'{OUTPUT_DIR}/data_extraction/ingredients.csv', index=False)
    return df_ingredients
def convert_json_string_to_dict(json_string, record_id):
    """Parse one model answer into a dict of nutritional values.

    Args:
        json_string: raw model output. It is expected to be a JSON object
            but may be empty, a non-string (e.g. NaN read back from a CSV),
            the ``'Failed'`` sentinel, or JSON wrapped in a Markdown code
            fence.
        record_id: product ID, used only in the diagnostic messages.

    Returns:
        The parsed dict. If the input cannot be turned into a JSON object,
        a dict mapping each of the nine default nutrient keys to ``-1``.
        A one-element list holding an object is unwrapped to that object.
    """
    default_keys = ['Energy kJ', 'Energy kcal', 'Fat', 'Saturated fat', 'Carbohydrates', 'Sugars', 'Fibers', 'Proteins',
                    'Salt']
    defaults = {key: -1 for key in default_keys}

    # Non-strings (None, NaN, ...) would make json.loads raise TypeError.
    if not isinstance(json_string, str) or not json_string.strip():
        print(f"ID: {record_id} - La chaîne est vide ou invalide : '{json_string}'")
        return defaults

    clean_string = json_string.strip()
    # Strip a Markdown code fence such as ```json ... ``` if the model added one.
    if clean_string.startswith("```"):
        clean_string = clean_string.strip("`")
        if clean_string[:4].lower() == "json":
            clean_string = clean_string[4:]

    try:
        parsed = json.loads(clean_string)
    except json.JSONDecodeError:
        print(f"ID: {record_id} - Erreur lors du décodage du JSON : '{json_string}'")
        return defaults

    if isinstance(parsed, list) and len(parsed) == 1 and isinstance(parsed[0], dict):
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        # Valid JSON but not an object: callers need dict access (.get/.keys).
        print(f"ID: {record_id} - La chaîne est vide ou invalide : '{json_string}'")
        return defaults
    return parsed
def extract_nutritional_values(OUTPUT_DIR, df_product_id, language, progress=gr.Progress()):
    """Extract per-100g nutritional values from each product's nutrition photo.

    Outputs:
    - the raw model answers, saved to
      ``{OUTPUT_DIR}/data_extraction/nutritional_values.csv``;
    - one CSV per nutrient (e.g. ``energy_kj.csv``, ``saturated_fat.csv``)
      with columns ``ID``, ``Price``, ``Processing time`` and
      ``Extracted_Text``. ``Extracted_Text`` holds that nutrient's value,
      or None when the answer lacks it.

    ``language`` is accepted for signature parity with the other extractors;
    the nutrition prompt does not use it.

    Returns:
        DataFrame with columns ``ID``, ``Extracted_Text``, ``Price``,
        ``Processing time``.
    """
    # Fixed schema (the keys the prompt asks for), rather than the first
    # answer's keys: that answer may be empty, partial or an error object.
    nutrient_keys = ['Energy kJ', 'Energy kcal', 'Fat', 'Saturated fat', 'Carbohydrates',
                     'Sugars', 'Fibers', 'Proteins', 'Salt']

    df_output = extract_text_from_picture_baseline(OUTPUT_DIR, df_product_id,
                                                   get_prompt_nutritional_info(),
                                                   type_photo="Nutritionals photo",
                                                   generation_config=generation_config_json,
                                                   max_entry=None,
                                                   progress=progress)
    df_output.to_csv(f'{OUTPUT_DIR}/data_extraction/nutritional_values.csv', index=False)

    # Plain list instead of DataFrame.apply(axis=1), which returns a
    # DataFrame (not assignable to one column) when there are no rows.
    parsed = [convert_json_string_to_dict(text, record_id)
              for text, record_id in zip(df_output['Extracted_Text'], df_output['ID'])]

    for key in nutrient_keys:
        df_key = df_output[['ID', 'Price', 'Processing time']].copy()
        df_key['Extracted_Text'] = [values.get(key) if isinstance(values, dict) else None
                                    for values in parsed]
        df_key.to_csv(f"{OUTPUT_DIR}/data_extraction/{key.replace(' ', '_').lower()}.csv", index=False)

    return df_output[['ID', 'Extracted_Text', 'Price', 'Processing time']]
# Gemini model used by the extraction functions, loaded once when the module
# is imported. The model name doubles as the key into price_token.
model = "gemini-1.5-pro-002"
gemini_model, generation_config, generation_config_json = load_gemini_model(
    model_name=model
)