import numpy as np
import torch
import math
import easyocr
import cv2
import os
import base64
import json
import requests
from llama_cpp import Llama
from PIL import Image
from dotenv import load_dotenv
from utils import *  # NeuralNet, DEVICE, IMAGE_SIZE, TRANSFORM, MARGIN, relative_path, decrease_size, prompts and allergen word lists, among others
load_dotenv()
SCALE_FACTOR = 4  # upscaling factor for the rectified (warped) label image
MAX_SIZE = 5_000_000  # maximum file size in bytes passed to decrease_size()
MAX_SIDE = 8_000  # maximum side length in pixels passed to decrease_size()
# ENGINE = ['easyocr']
# ENGINE = ['anthropic', 'claude-3-5-sonnet-20240620']
ENGINE = ['llama_cpp/v2/vision', 'qwen-vl-next_b2583']  # [backend] or [backend, model]


def main() -> None:
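    """Locate the label in an input image, rectify it with a perspective warp, extract the
    ingredient text with the configured ENGINE, normalize it into JSON if needed and print
    an allergen summary."""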
    model_weights = torch.load(relative_path('vision_model.pt'))
    model = NeuralNet()
    model.load_state_dict(model_weights)
    model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        file_path = input('Enter file path: ')
        with Image.open(file_path) as image:
            image_size = image.size
            image = image.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
            image = TRANSFORM(image).to(DEVICE)
            output = model(image).tolist()[0]
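    # The model predicts twelve normalized coordinates (four corners plus two curvature
    # points); scale them back to the original image dimensions.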
    data = {
        'top': {
            'left': {
                'x': output[0] * image_size[0],
                'y': output[1] * image_size[1],
            },
            'right': {
                'x': output[2] * image_size[0],
                'y': output[3] * image_size[1],
            },
        },
        'bottom': {
            'left': {
                'x': output[4] * image_size[0],
                'y': output[5] * image_size[1],
            },
            'right': {
                'x': output[6] * image_size[0],
                'y': output[7] * image_size[1],
            },
        },
        'curvature': {
            'top': {
                'x': output[8] * image_size[0],
                'y': output[9] * image_size[1],
            },
            'bottom': {
                'x': output[10] * image_size[0],
                'y': output[11] * image_size[1],
            },
        },
    }
    print(f"{data=}")
    image = cv2.imread(file_path)
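    # Average the top and bottom edge lengths to estimate the label size, then expand the
    # predicted points by MARGIN in each direction, clamped to the image bounds.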
    size_x = ((data['top']['right']['x'] - data['top']['left']['x']) +
              (data['bottom']['right']['x'] - data['bottom']['left']['x'])) / 2
    size_y = ((data['top']['right']['y'] - data['top']['left']['y']) +
              (data['bottom']['right']['y'] - data['bottom']['left']['y'])) / 2
    margin_x = size_x * MARGIN
    margin_y = size_y * MARGIN
    points = np.array([
        (max(data['top']['left']['x'] - margin_x, 0),
         max(data['top']['left']['y'] - margin_y, 0)),
        (min(data['top']['right']['x'] + margin_x, image_size[0]),
         max(data['top']['right']['y'] - margin_y, 0)),
        (min(data['bottom']['right']['x'] + margin_x, image_size[0]),
         min(data['bottom']['right']['y'] + margin_y, image_size[1])),
        (max(data['bottom']['left']['x'] - margin_x, 0),
         min(data['bottom']['left']['y'] + margin_y, image_size[1])),
        (data['curvature']['top']['x'],
         max(data['curvature']['top']['y'] - margin_y, 0)),
        (data['curvature']['bottom']['x'],
         min(data['curvature']['bottom']['y'] + margin_y, image_size[1])),
    ], dtype=np.float32)
    points_float: list[list[float]] = points.tolist()
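    # Size of the rectified output image, upscaled by SCALE_FACTOR.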
    max_height = int(max([  # y: top left - bottom left, top right - bottom right, curvature top - curvature bottom
        abs(points_float[0][1] - points_float[3][1]),
        abs(points_float[1][1] - points_float[2][1]),
        abs(points_float[4][1] - points_float[5][1]),
    ])) * SCALE_FACTOR
    max_width = int(max([  # x: top left - top right, bottom left - bottom right
        abs(points_float[0][0] - points_float[1][0]),
        abs(points_float[3][0] - points_float[2][0]),
    ])) * SCALE_FACTOR
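    # Target points in the rectified image: the four corners, plus the two curvature
    # points mapped to the middle of the top and bottom edges.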
    destination_points = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1],
        [max_width // 2, 0],
        [max_width // 2, max_height - 1],
    ], dtype=np.float32)
    homography, _ = cv2.findHomography(points, destination_points)
    warped_image = cv2.warpPerspective(image, homography, (max_width, max_height))
    cv2.imwrite('_warped_image.png', warped_image)
    del data
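    # Extract the label text (or, in the Claude case, the parsed ingredient JSON)
    # with the configured engine.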
    if ENGINE[0] == 'easyocr':
        reader = easyocr.Reader(['de', 'fr', 'en'], gpu=True)
        result = reader.readtext('_warped_image.png')
        os.remove('_warped_image.png')
        text = '\n'.join([r[1] for r in result])
        ingredients = {}
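    # Claude vision: upload the rectified label as base64-encoded WebP to the Anthropic
    # Messages API, which is prompted to return the ingredient JSON directly.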
    elif ENGINE[0] == 'anthropic':
        decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
        os.remove('_warped_image.png')
        with open('_warped_image.webp', 'rb') as f:
            base64_image = base64.b64encode(f.read()).decode('utf-8')
        response = requests.post(
            url='https://api.anthropic.com/v1/messages',
            headers={
                'x-api-key': os.environ['ANTHROPIC_API_KEY'],
                'anthropic-version': '2023-06-01',
                'content-type': 'application/json',
            },
            data=json.dumps({
                'model': ENGINE[1],
                'max_tokens': 1024,
                'messages': [
                    {
                        'role': 'user', 'content': [
                            {
                                'type': 'image',
                                'source': {
                                    'type': 'base64',
                                    'media_type': 'image/webp',
                                    'data': base64_image,
                                },
                            },
                            {
                                'type': 'text',
                                'text': PROMPT_CLAUDE,
                            },
                        ],
                    },
                ],
            }),
        )
        os.remove('_warped_image.webp')
        try:
            data = response.json()
            ingredients = json.loads('{' + data['content'][0]['text'].split('{', 1)[-1].rsplit('}', 1)[0] + '}')
        except Exception as e:
            print(data)
            raise e
        text = ''
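    # Local vision model: POST the image path to a llama.cpp-based vision endpoint that
    # is assumed to listen on 127.0.0.1:11434; it returns raw text only.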
    elif ENGINE[0] == 'llama_cpp/v2/vision':
        decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
        # os.remove('_warped_image.png')
        response = requests.post(
            url='http://127.0.0.1:11434/llama_cpp/v2/vision',
            headers={
                'x-version': '2024-05-21',
                'content-type': 'application/json',
            },
            data=json.dumps({
                'task': PROMPT_VISION,
                'model': ENGINE[1],
                'image_path': relative_path('_warped_image.webp'),
            }),
        )
        os.remove('_warped_image.webp')
        text: str = response.json()['text']
        ingredients = {}
    else:
        raise ValueError(f'Unknown engine: {ENGINE[0]}')
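    # easyocr and the local vision engine return raw text; normalize it into the expected
    # JSON structure with the local GGUF LLM (the Claude branch already produced JSON).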
    if text != '':
        if DEVICE == 'cuda':
            n_gpu_layers = -1
        else:
            n_gpu_layers = 0
        llm = Llama(
            model_path=relative_path('llm.Q4_K_M.gguf'),
            n_gpu_layers=n_gpu_layers,
        )
        llm_result = llm.create_chat_completion(
            messages=[
                {
                    'role': 'system',
                    'content': SYSTEM_PROMPT,
                },
                {
                    'role': 'user',
                    'content': PROMPT_LLM.replace('{{old_data}}', text),
                },
            ],
            max_tokens=1024,
            temperature=0,
            # grammar=GRAMMAR,
        )
        try:
            ingredients = json.loads(
                '{' + llm_result['choices'][0]['message']['content'].split('{', 1)[-1].rsplit('}', 1)[0] + '}')
        except Exception as e:
            print(f"{llm_result=}")
            raise e
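    # Match the ingredients ('Zutaten') and possible traces ('Verunreinigungen') against
    # the allergen word lists imported from utils and print the report (German output).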
    animal_ingredients = [item for item in ingredients['Zutaten'] if item in ANIMAL]
    sometimes_animal_ingredients = [item for item in ingredients['Zutaten'] if item in SOMETIMES_ANIMAL]
    milk_ingredients = ([item for item in ingredients['Zutaten'] if item in MILK]
                        + [item for item in ingredients['Verunreinigungen'] if item in MILK])
    gluten_ingredients = ([item for item in ingredients['Zutaten'] if item in GLUTEN]
                          + [item for item in ingredients['Verunreinigungen'] if item in GLUTEN])
    print('=' * 64)
    print('Zutaten: ' + ', '.join(ingredients['Zutaten']))
    print('=' * 64)
    print(('Kann Spuren von ' + ', '.join(ingredients['Verunreinigungen']) + ' enthalten.')
          if len(ingredients['Verunreinigungen']) > 0 else 'ohne Verunreinigungen')
    print('=' * 64)
    print('Gefundene tierische Zutaten: '
          + (', '.join(animal_ingredients) if len(animal_ingredients) > 0 else 'keine'))
    print('=' * 64)
    print('Gefundene potenzielle tierische Zutaten: '
          + (', '.join(sometimes_animal_ingredients) if len(sometimes_animal_ingredients) > 0 else 'keine'))
    print('=' * 64)
    print('Gefundene Milchprodukte: ' + (', '.join(milk_ingredients) if len(milk_ingredients) > 0 else 'keine'))
    print('=' * 64)
    print('Gefundene Gluten: ' + (', '.join(gluten_ingredients) if len(gluten_ingredients) > 0 else 'keine'))
    print('=' * 64)
    print(LEGAL_NOTICE)
    print('=' * 64)


if __name__ == '__main__':
    main()