File size: 4,839 Bytes
d2f3f0b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import torch
import torchvision
import os
import json
from PIL import Image
from datetime import datetime
# Names exported by ``from <module> import *``: the helper functions, the
# model class, runtime settings, prompt/grammar texts and JSON-loaded data.
__all__ = [
    'current_time', 'relative_path', 'NeuralNet',
    'DEVICE', 'IMAGE_SIZE', 'TRANSFORM', 'MARGIN', 'GRID_SIZE',
    'decrease_size',
    'PROMPT_LLM', 'PROMPT_CLAUDE', 'PROMPT_VISION',
    'EOS', 'GRAMMAR', 'SYSTEM_PROMPT',
    'ANIMAL', 'SOMETIMES_ANIMAL', 'MILK', 'GLUTEN',
    'LEGAL_NOTICE',
]
def _read_text(filename):
    """Return the full UTF-8 text content of *filename*."""
    with open(filename, 'r', encoding='utf-8') as handle:
        return handle.read()


def _read_json(filename):
    """Return the parsed JSON content of the UTF-8 file *filename*."""
    with open(filename, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Plain numeric settings.
GRID_SIZE = 4096
MARGIN = 0.1
IMAGE_SIZE = (224, 224)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image preprocessing pipeline; it currently consists of ToTensor only.
TRANSFORM = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Literal text constants.
EOS = '\n<|im_end|>'
SYSTEM_PROMPT = 'Du bist ein hilfreicher assistant.'
LEGAL_NOTICE = (
    'Dieses Programm ist nur für Forschungszwecke gedacht. Fehler können nicht ausgeschlossen werden und '
    'sind wahrscheinlich vorhanden. Die Erkennung von Zutaten und Verunreinigungen ist nur zum schnellen '
    'Aussortieren und nicht zum Überprüfen gedacht.'
)

# Resources read at import time. The file names are relative, so they are
# resolved against the current working directory of the process.
PROMPT_LLM = _read_text('prompt_llm.md')
PROMPT_CLAUDE = _read_text('prompt_claude.md')
PROMPT_VISION = _read_text('prompt_vision.md')
GRAMMAR = _read_text('grammar.gbnf')
ANIMAL = _read_json('animal.json')
SOMETIMES_ANIMAL = _read_json('sometimes_animal.json')
MILK = _read_json('milk.json')
GLUTEN = _read_json('gluten.json')
def current_time() -> str:
    """Return the local date and time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
    return f"{datetime.now():%Y-%m-%d_%H-%M-%S}"
def relative_path(string: str) -> str:
    """Join *string* onto the directory that contains this source file."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, string)
class NeuralNet(torch.nn.Module):
    """ResNet-18 feature extractor followed by a small convolutional head.

    ``forward`` runs the ResNet-18 stages up to and including ``layer4`` and
    passes the resulting feature map through ``head``, which reduces it to
    12 values per sample. The backbone's replaced ``fc`` layer is created in
    ``__init__`` but is not called by ``forward``.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Pre-trained ResNet-18 whose final linear layer is swapped for one
        # with 12 outputs.
        resnet = torchvision.models.resnet18(pretrained=True)
        resnet.fc = nn.Linear(resnet.fc.in_features, 12)
        self.backbone = resnet

        # Key-point head: three 3x3 conv + ReLU stages narrowing the channels
        # 512 -> 256 -> 128 -> 64, a 1x1 conv down to 12 channels, and a
        # global average pool to a 1x1 spatial size.
        widths = (512, 256, 128, 64)
        layers = []
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(widths[-1], 12, kernel_size=1))
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``[batch, 12]`` tensor for an image or a batch of images.

        A 3-dimensional input is treated as a single ``[C, H, W]`` image and
        gets a leading batch dimension. Inputs whose last two dimensions are
        not 224x224 are bilinearly resized first.
        """
        target_hw = (224, 224)

        # Promote a single image to a batch of one.
        if x.ndim == 3:
            x = x.unsqueeze(0)

        # Bring the spatial size to 224x224 when it differs.
        if x.shape[-2:] != target_hw:
            x = torch.nn.functional.interpolate(x, size=target_hw, mode='bilinear', align_corners=False)

        # Run the ResNet stem and the four residual stages; the backbone's
        # own pooling and fc layers are skipped.
        net = self.backbone
        stages = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
        )
        for stage in stages:
            x = stage(x)

        # The head ends in a 1x1 pool, so flattening leaves one row per sample.
        pooled = self.head(x)
        return pooled.view(pooled.size(0), -1)
def _save_format(output_path):
    """Return the Pillow format name to use when saving to *output_path*.

    The extension is looked up in Pillow's extension registry so that aliases
    such as ``.jpg`` or ``.tif`` resolve to real format names (``JPEG``,
    ``TIFF``). An unregistered extension falls back to the upper-cased text
    after the last dot.
    """
    extension = os.path.splitext(output_path)[1].lower()
    registered = Image.registered_extensions()
    if extension in registered:
        return registered[extension]
    return output_path.split('.')[-1].upper()


def decrease_size(input_path, output_path, max_size, max_side):
    """Write a copy of an image that respects a byte and a side-length limit.

    The image at *input_path* is saved to *output_path*. If the input already
    satisfies both limits it is saved at its original resolution. Otherwise
    it is re-saved at progressively smaller resolutions (each step scales
    both sides by 0.9) until the output file is at most *max_size* bytes and
    both sides are at most *max_side* pixels, or until a side would drop to
    24 pixels or fewer.

    Args:
        input_path: Path of the image to read.
        output_path: Path to write; its extension selects the output format.
        max_size: Maximum allowed size of the output file in bytes.
        max_side: Maximum allowed width and height in pixels.

    Returns:
        None. Progress is reported with ``print``.

    Raises:
        ValueError: If no output could be produced or the last output written
            is still larger than *max_size*.
    """
    save_format = _save_format(output_path)

    with Image.open(input_path) as img:
        width, height = img.size
        saved = False  # whether this call has written output_path at all

        input_fits = (
            os.path.getsize(input_path) <= max_size
            and width <= max_side
            and height <= max_side
        )
        if input_fits:
            img.save(output_path, format=save_format)
            saved = True
            # Re-encoding can change the byte size, so confirm the written
            # file before declaring success; otherwise fall through and shrink.
            if os.path.getsize(output_path) <= max_size:
                print("Image is already below the maximum size.")
                return

        while width > 24 and height > 24:
            # Always resize from the original image to avoid compounding
            # resampling loss across iterations.
            img_resized = img.resize((width, height), Image.Resampling.LANCZOS)
            img_resized.save(output_path, format=save_format)
            saved = True

            output_size = os.path.getsize(output_path)
            if output_size <= max_size and width <= max_side and height <= max_side:
                print(f"Reduced image size to {output_size} bytes.")
                return

            width, height = int(width * 0.9), int(height * 0.9)

        # Without the ``saved`` check, an image too small to enter the loop
        # would either hit a missing file here or be judged by a stale one.
        if not saved or os.path.getsize(output_path) > max_size:
            raise ValueError("Could not reduce PNG size below max_size by reducing resolution.")
|