lenamerkli committed on
Commit
a876c78
·
verified ·
1 Parent(s): d2f3f0b

Fix several bugs in main.py

Browse files
Files changed (1) hide show
  1. main.py +241 -240
main.py CHANGED
@@ -1,240 +1,241 @@
1
- import numpy as np
2
- import torch
3
- import math
4
- import easyocr
5
- import cv2
6
- import os
7
- import base64
8
- import json
9
- import requests
10
- from llama_cpp import Llama
11
- from PIL import Image
12
- from dotenv import load_dotenv
13
-
14
- from utils import *
15
-
16
- load_dotenv()
17
-
18
- SCALE_FACTOR = 4
19
- MAX_SIZE = 5_000_000
20
- MAX_SIDE = 8_000
21
- # ENGINE = ['easyocr']
22
- # ENGINE = ['anthropic', 'claude-3-5-sonnet-20240620']
23
- ENGINE = ['llama_cpp/v2/vision', 'qwen-vl-next_b2583']
24
-
25
-
26
- def main() -> None:
27
- model_weights = torch.load(relative_path('vision_model.pt'))
28
- model = NeuralNet()
29
- model.load_state_dict(model_weights)
30
- model.to(DEVICE)
31
- model.eval()
32
- with torch.no_grad():
33
- file_path = input('Enter file path: ')
34
- with Image.open(file_path) as image:
35
- image_size = image.size
36
- image = image.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
37
- image = TRANSFORM(image).to(DEVICE)
38
- output = model(image).tolist()[0]
39
- data = {
40
- 'top': {
41
- 'left': {
42
- 'x': output[0] * image_size[0],
43
- 'y': output[1] * image_size[1],
44
- },
45
- 'right': {
46
- 'x': output[2] * image_size[0],
47
- 'y': output[3] * image_size[1],
48
- },
49
- },
50
- 'bottom': {
51
- 'left': {
52
- 'x': output[4] * image_size[0],
53
- 'y': output[5] * image_size[1],
54
- },
55
- 'right': {
56
- 'x': output[6] * image_size[0],
57
- 'y': output[7] * image_size[1],
58
- },
59
- },
60
- 'curvature': {
61
- 'top': {
62
- 'x': output[8] * image_size[0],
63
- 'y': output[9] * image_size[1],
64
- },
65
- 'bottom': {
66
- 'x': output[10] * image_size[0],
67
- 'y': output[11] * image_size[1],
68
- },
69
- },
70
- }
71
- print(f"{data=}")
72
- image = cv2.imread(file_path)
73
- size_x = ((data['top']['right']['x'] - data['top']['left']['x']) +
74
- (data['bottom']['right']['x'] - data['bottom']['left']['x'])) / 2
75
- size_y = ((data['top']['right']['y'] - data['top']['left']['y']) +
76
- (data['bottom']['right']['y'] - data['bottom']['left']['y'])) / 2
77
- margin_x = size_x * MARGIN
78
- margin_y = size_y * MARGIN
79
- points = np.array([
80
- (max(data['top']['left']['x'] - margin_x, 0),
81
- max(data['top']['left']['y'] - margin_y, 0)),
82
- (min(data['top']['right']['x'] + margin_x, image_size[0]),
83
- max(data['top']['right']['y'] - margin_y, 0)),
84
- (min(data['bottom']['right']['x'] + margin_x, image_size[0]),
85
- min(data['bottom']['right']['y'] + margin_y, image_size[1])),
86
- (max(data['bottom']['left']['x'] - margin_x, 0),
87
- min(data['bottom']['left']['y'] + margin_y, image_size[1])),
88
- (data['curvature']['top']['x'],
89
- max(data['curvature']['top']['y'] - margin_y, 0)),
90
- (data['curvature']['bottom']['x'],
91
- min(data['curvature']['bottom']['y'] + margin_y, image_size[1])),
92
- ], dtype=np.float32)
93
- points_float: list[list[float]] = points.tolist()
94
- max_height = int(max([ # y: top left - bottom left, top right - bottom right, curvature top - curvature bottom
95
- abs(points_float[0][1] - points_float[3][1]),
96
- abs(points_float[1][1] - points_float[4][1]),
97
- abs(points_float[2][1] - points_float[5][1]),
98
- ])) * SCALE_FACTOR
99
- max_width = int(max([ # x: top left - top right, bottom left - bottom right
100
- abs(points_float[0][0] - points_float[1][0]),
101
- abs(points_float[3][0] - points_float[2][0]),
102
- ])) * SCALE_FACTOR
103
- destination_points = np.array([
104
- [0, 0],
105
- [max_width - 1, 0],
106
- [max_width - 1, max_height - 1],
107
- [0, max_height - 1],
108
- [max_width // 2, 0],
109
- [max_width // 2, max_height - 1],
110
- ], dtype=np.float32)
111
- homography, _ = cv2.findHomography(points, destination_points)
112
- warped_image = cv2.warpPerspective(image, homography, (max_width, max_height))
113
- cv2.imwrite('_warped_image.png', warped_image)
114
- del data
115
- if ENGINE[0] == 'easyocr':
116
- reader = easyocr.Reader(['de', 'fr', 'en'], gpu=True)
117
- result = reader.readtext('_warped_image.png')
118
- # os.remove('_warped_image.png')
119
- text = '\n'.join([r[1] for r in result])
120
- ingredients = {}
121
- elif ENGINE[0] == 'anthropic':
122
- decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
123
- # os.remove('_warped_image.png')
124
- with open('_warped_image.webp', 'rb') as f:
125
- base64_image = base64.b64encode(f.read()).decode('utf-8')
126
- response = requests.post(
127
- url='https://api.anthropic.com/v1/messages',
128
- headers={
129
- 'x-api-key': os.environ['ANTHROPIC_API_KEY'],
130
- 'anthropic-version': '2023-06-01',
131
- 'content-type': 'application/json',
132
- },
133
- data=json.dumps({
134
- 'model': ENGINE[1],
135
- 'max_tokens': 1024,
136
- 'messages': [
137
- {
138
- 'role': 'user', 'content': [
139
- {
140
- 'type': 'image',
141
- 'source': {
142
- 'type': 'base64',
143
- 'media_type': 'image/webp',
144
- 'data': base64_image,
145
- },
146
- },
147
- {
148
- 'type': 'text',
149
- 'text': PROMPT_CLAUDE,
150
- },
151
- ],
152
- },
153
- ],
154
- }),
155
- )
156
- # os.remove('_warped_image.webp')
157
- try:
158
- data = response.json()
159
- ingredients = json.loads('{' + data['content'][0]['text'].split('{', 1)[-1].rsplit('}', 1)[0] + '}')
160
- except Exception as e:
161
- print(data)
162
- raise e
163
- text = ''
164
- elif ENGINE[0] == 'llama_cpp/v2/vision':
165
- decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
166
- # os.remove('_warped_image.png')
167
- response = requests.post(
168
- url='http://127.0.0.1:11434/llama_cpp/v2/vision',
169
- headers={
170
- 'x-version': '2024-05-21',
171
- 'content-type': 'application/json',
172
- },
173
- data=json.dumps({
174
- 'task': PROMPT_VISION,
175
- 'model': ENGINE[1],
176
- 'image_path': relative_path('_warped_image.webp'),
177
- }),
178
- )
179
- # os.remove('_warped_image.webp')
180
- text: str = response.json()['text']
181
- ingredients = {}
182
- else:
183
- raise ValueError(f'Unknown engine: {ENGINE[0]}')
184
- if text != '':
185
- if DEVICE == 'cuda':
186
- n_gpu_layers = -1
187
- else:
188
- n_gpu_layers = 0
189
- llm = Llama(
190
- model_path=relative_path('llm.Q4_K_M.gguf'),
191
- n_gpu_layers=n_gpu_layers,
192
- )
193
- llm_result = llm.create_chat_completion(
194
- messages=[
195
- {
196
- 'role': 'system',
197
- 'content': SYSTEM_PROMPT,
198
- },
199
- {
200
- 'role': 'user',
201
- 'content': PROMPT_LLM.replace('{{old_data}}', text),
202
- },
203
- ],
204
- max_tokens=1024,
205
- temperature=0,
206
- # grammar=GRAMMAR,
207
- )
208
- try:
209
- ingredients = json.loads(
210
- '{' + llm_result['choices'][0]['message']['content'].split('{', 1)[-1].rsplit('}', 1)[0] + '}')
211
- except Exception as e:
212
- print(f"{llm_result=}")
213
- raise e
214
- animal_ingredients = [item for item in ingredients['Zutaten'] if item in ANIMAL]
215
- sometimes_animal_ingredients = [item for item in ingredients['Zutaten'] if item in SOMETIMES_ANIMAL]
216
- milk_ingredients = ([item for item in ingredients['Zutaten'] if item in MILK]
217
- + [item for item in ingredients['Verunreinigungen'] if item in MILK])
218
- gluten_ingredients = ([item for item in ingredients['Zutaten'] if item in GLUTEN]
219
- + [item for item in ingredients['Verunreinigungen'] if item in GLUTEN])
220
- print('=' * 64)
221
- print('Zutaten: ' + ', '.join(ingredients['Zutaten']))
222
- print('=' * 64)
223
- print('Kann Spuren von ' + ', '.join(ingredients['Verunreinigungen']) + ' enthalten.')
224
- print('=' * 64)
225
- print('Gefundene tierische Zutaten: '
226
- + (', '.join(animal_ingredients) if len(animal_ingredients) > 0 else 'keine'))
227
- print('=' * 64)
228
- print('Gefundene potenzielle tierische Zutaten: '
229
- + (', '.join(sometimes_animal_ingredients) if len(sometimes_animal_ingredients) > 0 else 'keine'))
230
- print('=' * 64)
231
- print('Gefundene Milchprodukte: ' + ', '.join(milk_ingredients) if len(milk_ingredients) > 0 else 'keine')
232
- print('=' * 64)
233
- print('Gefundene Gluten: ' + ', '.join(gluten_ingredients) if len(gluten_ingredients) > 0 else 'keine')
234
- print('=' * 64)
235
- print(LEGAL_NOTICE)
236
- print('=' * 64)
237
-
238
-
239
- if __name__ == '__main__':
240
- main()
 
 
1
+ import numpy as np
2
+ import torch
3
+ import math
4
+ import easyocr
5
+ import cv2
6
+ import os
7
+ import base64
8
+ import json
9
+ import requests
10
+ from llama_cpp import Llama
11
+ from PIL import Image
12
+ from dotenv import load_dotenv
13
+
14
+ from utils import *
15
+
16
+ load_dotenv()
17
+
18
+ SCALE_FACTOR = 4
19
+ MAX_SIZE = 5_000_000
20
+ MAX_SIDE = 8_000
21
+ # ENGINE = ['easyocr']
22
+ # ENGINE = ['anthropic', 'claude-3-5-sonnet-20240620']
23
+ ENGINE = ['llama_cpp/v2/vision', 'qwen-vl-next_b2583']
24
+
25
+
26
def _extract_json_object(raw: str) -> dict:
    """Parse the outermost ``{...}`` span of *raw* as JSON.

    Everything before the first ``{`` and after the last ``}`` is discarded,
    so a reply that wraps the JSON object in extra text still parses.

    Raises:
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    return json.loads('{' + raw.split('{', 1)[-1].rsplit('}', 1)[0] + '}')


def _join_or_none(items: list) -> str:
    """Return *items* joined by ``', '``, or ``'keine'`` when *items* is empty."""
    return ', '.join(items) if len(items) > 0 else 'keine'


def main() -> None:
    """Run the interactive pipeline on one image and print the findings.

    Steps, in order:
      1. Load ``vision_model.pt`` and predict 12 normalised coordinates
         (four corners plus a top and a bottom curvature point).
      2. Rectify the region with a homography and write ``_warped_image.png``.
      3. Read the rectified image with the engine selected by ``ENGINE``.
      4. If the engine produced plain text, let the local LLM turn it into
         the ``Zutaten`` / ``Verunreinigungen`` JSON structure.
      5. Print the ingredient lists and the matches against the
         ``ANIMAL``, ``SOMETIMES_ANIMAL``, ``MILK`` and ``GLUTEN`` collections.

    Raises:
        ValueError: if ``ENGINE[0]`` names an unknown engine.
    """
    # map_location lets a checkpoint saved on another device load on this host.
    model_weights = torch.load(relative_path('vision_model.pt'), map_location=DEVICE)
    model = NeuralNet()
    model.load_state_dict(model_weights)
    model.to(DEVICE)
    model.eval()

    file_path = input('Enter file path: ')
    with torch.no_grad():
        with Image.open(file_path) as source_image:
            image_size = source_image.size
            resized = source_image.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
            tensor = TRANSFORM(resized).to(DEVICE)
        output = model(tensor).tolist()[0]

    # The model output is scaled back to pixel coordinates of the original image.
    width, height = image_size
    data = {
        'top': {
            'left': {
                'x': output[0] * width,
                'y': output[1] * height,
            },
            'right': {
                'x': output[2] * width,
                'y': output[3] * height,
            },
        },
        'bottom': {
            'left': {
                'x': output[4] * width,
                'y': output[5] * height,
            },
            'right': {
                'x': output[6] * width,
                'y': output[7] * height,
            },
        },
        'curvature': {
            'top': {
                'x': output[8] * width,
                'y': output[9] * height,
            },
            'bottom': {
                'x': output[10] * width,
                'y': output[11] * height,
            },
        },
    }
    print(f"{data=}")

    top_left, top_right = data['top']['left'], data['top']['right']
    bottom_left, bottom_right = data['bottom']['left'], data['bottom']['right']
    curve_top, curve_bottom = data['curvature']['top'], data['curvature']['bottom']

    image = cv2.imread(file_path)
    # Average width (right minus left) and average height (bottom minus top).
    # The height previously used the left-to-right y offset, i.e. the slant of
    # the edges, which made the vertical margin close to zero or negative.
    size_x = ((top_right['x'] - top_left['x']) + (bottom_right['x'] - bottom_left['x'])) / 2
    size_y = ((bottom_left['y'] - top_left['y']) + (bottom_right['y'] - top_right['y'])) / 2
    margin_x = size_x * MARGIN
    margin_y = size_y * MARGIN

    # Order: top left, top right, bottom right, bottom left, curvature top,
    # curvature bottom. Each point is pushed outwards by the margin and
    # clamped to the image bounds.
    points = np.array([
        (max(top_left['x'] - margin_x, 0), max(top_left['y'] - margin_y, 0)),
        (min(top_right['x'] + margin_x, width), max(top_right['y'] - margin_y, 0)),
        (min(bottom_right['x'] + margin_x, width), min(bottom_right['y'] + margin_y, height)),
        (max(bottom_left['x'] - margin_x, 0), min(bottom_left['y'] + margin_y, height)),
        (curve_top['x'], max(curve_top['y'] - margin_y, 0)),
        (curve_bottom['x'], min(curve_bottom['y'] + margin_y, height)),
    ], dtype=np.float32)
    points_float: list[list[float]] = points.tolist()
    max_height = int(max([
        abs(points_float[0][1] - points_float[3][1]),  # top left - bottom left
        abs(points_float[1][1] - points_float[2][1]),  # top right - bottom right
        abs(points_float[4][1] - points_float[5][1]),  # curvature top - curvature bottom
    ])) * SCALE_FACTOR
    max_width = int(max([
        abs(points_float[0][0] - points_float[1][0]),  # top left - top right
        abs(points_float[3][0] - points_float[2][0]),  # bottom left - bottom right
    ])) * SCALE_FACTOR
    # Same point order as `points`; the curvature points map to the middle of
    # the top and bottom edges of the output rectangle.
    destination_points = np.array([
        [0, 0],
        [max_width - 1, 0],
        [max_width - 1, max_height - 1],
        [0, max_height - 1],
        [max_width // 2, 0],
        [max_width // 2, max_height - 1],
    ], dtype=np.float32)
    homography, _ = cv2.findHomography(points, destination_points)
    warped_image = cv2.warpPerspective(image, homography, (max_width, max_height))
    cv2.imwrite('_warped_image.png', warped_image)

    if ENGINE[0] == 'easyocr':
        reader = easyocr.Reader(['de', 'fr', 'en'], gpu=True)
        result = reader.readtext('_warped_image.png')
        os.remove('_warped_image.png')
        text = '\n'.join([r[1] for r in result])
        ingredients = {}
    elif ENGINE[0] == 'anthropic':
        decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
        os.remove('_warped_image.png')
        try:
            with open('_warped_image.webp', 'rb') as f:
                base64_image = base64.b64encode(f.read()).decode('utf-8')
            response = requests.post(
                url='https://api.anthropic.com/v1/messages',
                headers={
                    'x-api-key': os.environ['ANTHROPIC_API_KEY'],
                    'anthropic-version': '2023-06-01',
                    'content-type': 'application/json',
                },
                data=json.dumps({
                    'model': ENGINE[1],
                    'max_tokens': 1024,
                    'messages': [
                        {
                            'role': 'user', 'content': [
                                {
                                    'type': 'image',
                                    'source': {
                                        'type': 'base64',
                                        'media_type': 'image/webp',
                                        'data': base64_image,
                                    },
                                },
                                {
                                    'type': 'text',
                                    'text': PROMPT_CLAUDE,
                                },
                            ],
                        },
                    ],
                }),
            )
        finally:
            # Remove the temporary file even when the request fails.
            os.remove('_warped_image.webp')
        payload = None
        try:
            payload = response.json()
            ingredients = _extract_json_object(payload['content'][0]['text'])
        except Exception:
            # Show whatever was received. If the body was not JSON at all,
            # `payload` is still None, so fall back to the raw response text.
            print(payload if payload is not None else response.text)
            raise
        text = ''
    elif ENGINE[0] == 'llama_cpp/v2/vision':
        decrease_size('_warped_image.png', '_warped_image.webp', MAX_SIZE, MAX_SIDE)
        # Only the .webp is sent on, so drop the .png like the other engines do.
        os.remove('_warped_image.png')
        try:
            # NOTE(review): the file is written relative to the working
            # directory but referenced through relative_path() here - confirm
            # both resolve to the same location.
            response = requests.post(
                url='http://127.0.0.1:11434/llama_cpp/v2/vision',
                headers={
                    'x-version': '2024-05-21',
                    'content-type': 'application/json',
                },
                data=json.dumps({
                    'task': PROMPT_VISION,
                    'model': ENGINE[1],
                    'image_path': relative_path('_warped_image.webp'),
                }),
            )
        finally:
            # Remove the temporary file even when the request fails.
            os.remove('_warped_image.webp')
        text: str = response.json()['text']
        ingredients = {}
    else:
        raise ValueError(f'Unknown engine: {ENGINE[0]}')

    # Engines that only return raw text need the local LLM to structure it.
    if text != '':
        if DEVICE == 'cuda':
            n_gpu_layers = -1
        else:
            n_gpu_layers = 0
        llm = Llama(
            model_path=relative_path('llm.Q4_K_M.gguf'),
            n_gpu_layers=n_gpu_layers,
        )
        llm_result = llm.create_chat_completion(
            messages=[
                {
                    'role': 'system',
                    'content': SYSTEM_PROMPT,
                },
                {
                    'role': 'user',
                    'content': PROMPT_LLM.replace('{{old_data}}', text),
                },
            ],
            max_tokens=1024,
            temperature=0,
            # grammar=GRAMMAR,
        )
        try:
            ingredients = _extract_json_object(llm_result['choices'][0]['message']['content'])
        except Exception:
            print(f"{llm_result=}")
            raise

    animal_ingredients = [item for item in ingredients['Zutaten'] if item in ANIMAL]
    sometimes_animal_ingredients = [item for item in ingredients['Zutaten'] if item in SOMETIMES_ANIMAL]
    milk_ingredients = ([item for item in ingredients['Zutaten'] if item in MILK]
                        + [item for item in ingredients['Verunreinigungen'] if item in MILK])
    gluten_ingredients = ([item for item in ingredients['Zutaten'] if item in GLUTEN]
                          + [item for item in ingredients['Verunreinigungen'] if item in GLUTEN])

    separator = '=' * 64
    print(separator)
    print('Zutaten: ' + ', '.join(ingredients['Zutaten']))
    print(separator)
    print(('Kann Spuren von ' + ', '.join(ingredients['Verunreinigungen']) + ' enthalten.')
          if len(ingredients['Verunreinigungen']) > 0 else 'ohne Verunreinigungen')
    print(separator)
    print('Gefundene tierische Zutaten: ' + _join_or_none(animal_ingredients))
    print(separator)
    print('Gefundene potenzielle tierische Zutaten: ' + _join_or_none(sometimes_animal_ingredients))
    print(separator)
    print('Gefundene Milchprodukte: ' + _join_or_none(milk_ingredients))
    print(separator)
    print('Gefundene Gluten: ' + _join_or_none(gluten_ingredients))
    print(separator)
    print(LEGAL_NOTICE)
    print(separator)
238
+
239
+
240
+ if __name__ == '__main__':
241
+ main()