Upload 2 files
Browse files- handler.py +45 -0
 - requirements.txt +10 -0
 
    	
        handler.py
    ADDED
    
    | 
         @@ -0,0 +1,45 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from typing import Dict, Any
         
     | 
| 2 | 
         
            +
            from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
         
     | 
| 3 | 
         
            +
            from PIL import Image
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            class EndpointHandler():
                """Zero-shot object-detection handler for an inference endpoint.

                The model and processor are loaded once in ``__init__``; every call
                then answers a payload of the form ``{"image": ..., "text": ...}``.
                """

                # Score cut-offs used when the payload does not override them.
                DEFAULT_BOX_THRESHOLD = 0.4
                DEFAULT_TEXT_THRESHOLD = 0.3

                def __init__(self, path=""):
                    """Load the model and its processor.

                    Args:
                        path: Model directory or identifier handed to ``from_pretrained``.
                    """
                    self.device = "cuda" if torch.cuda.is_available() else "cpu"
                    self.model = AutoModelForZeroShotObjectDetection.from_pretrained(path).to(self.device)
                    self.processor = AutoProcessor.from_pretrained(path)

                def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
                    """Run detection for one request.

                    Args:
                        data: Payload with ``"image"`` (http(s) URL, base64 / data-URI
                            string, local file path, raw bytes, or an already decoded
                            image object) and ``"text"`` (one query string or a list of
                            query strings). Optional ``"box_threshold"`` and
                            ``"text_threshold"`` override the default cut-offs.

                    Returns:
                        ``{"detections": [...]}`` with tensor values converted to plain
                        lists, or ``{"error": "..."}`` when the payload is unusable.
                    """
                    if "image" not in data or "text" not in data:
                        return {"error": "Payload must contain 'image' (base64 or URL) and 'text' (queries)."}

                    # Validate the cheap inputs first so bad requests never reach
                    # image decoding or the model.
                    try:
                        prompt = self._format_queries(data["text"])
                    except TypeError as exc:
                        return {"error": str(exc)}
                    if not prompt:
                        return {"error": "'text' must contain at least one non-empty query."}

                    try:
                        box_threshold = float(data.get("box_threshold", self.DEFAULT_BOX_THRESHOLD))
                        text_threshold = float(data.get("text_threshold", self.DEFAULT_TEXT_THRESHOLD))
                    except (TypeError, ValueError):
                        return {"error": "'box_threshold' and 'text_threshold' must be numbers."}

                    try:
                        image = self._load_image(data["image"])
                    except (ValueError, OSError) as exc:
                        return {"error": f"Could not load 'image': {exc}"}

                    # Prepare the model inputs on the same device as the model.
                    inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)

                    # Inference only: no gradients needed.
                    with torch.no_grad():
                        outputs = self.model(**inputs)

                    # image.size is (width, height); the processor is given it reversed.
                    results = self.processor.post_process_grounded_object_detection(
                        outputs,
                        inputs.input_ids,
                        box_threshold=box_threshold,
                        text_threshold=text_threshold,
                        target_sizes=[image.size[::-1]],
                    )

                    # Tensors cannot be JSON-encoded, so convert them before returning.
                    return {"detections": [self._to_serializable(item) for item in results]}

                @staticmethod
                def _format_queries(text_queries):
                    """Build the text prompt: lowercase phrases, each ending in one dot.

                    Accepts a single string (split on ``"."``) or a list/tuple of
                    strings. Returns ``""`` when no non-empty phrase remains and raises
                    ``TypeError`` for any other input type.
                    """
                    if isinstance(text_queries, str):
                        phrases = text_queries.split(".")
                    elif isinstance(text_queries, (list, tuple)) and all(
                        isinstance(item, str) for item in text_queries
                    ):
                        phrases = text_queries
                    else:
                        raise TypeError("'text' must be a string or a list of strings.")

                    # Drop any trailing dot the caller supplied so joining never yields "..".
                    cleaned = [phrase.strip().rstrip(".").strip().lower() for phrase in phrases]
                    return " ".join(f"{phrase}." for phrase in cleaned if phrase)

                @staticmethod
                def _load_image(source):
                    """Return *source* as an RGB image.

                    Raises ``ValueError`` or ``OSError`` when the input cannot be decoded.
                    """
                    import base64
                    import io
                    import os
                    import urllib.request

                    if isinstance(source, str):
                        if source.startswith(("http://", "https://")):
                            # NOTE(review): fetching caller-supplied URLs can reach internal
                            # services (SSRF); restrict hosts if the endpoint is public.
                            with urllib.request.urlopen(source, timeout=10) as response:
                                source = response.read()
                        elif os.path.exists(source):
                            # Keeps the original behaviour of treating a string as a path.
                            # NOTE(review): this lets a caller read local image files.
                            with Image.open(source) as img:
                                return img.convert("RGB")
                        else:
                            if source.startswith("data:"):
                                # data URI: keep only the part after the first comma.
                                source = source.partition(",")[2]
                            source = base64.b64decode("".join(source.split()), validate=True)

                    if isinstance(source, (bytes, bytearray)):
                        with Image.open(io.BytesIO(source)) as img:
                            return img.convert("RGB")

                    if not hasattr(source, "convert"):
                        raise ValueError("unsupported image input type")
                    return source.convert("RGB")

                @staticmethod
                def _to_serializable(detection):
                    """Copy one detection dict, turning array-like values into lists."""
                    return {
                        key: value.tolist() if hasattr(value, "tolist") else value
                        for key, value in detection.items()
                    }
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,10 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            torch
         
     | 
| 2 | 
         
            +
            torchvision
         
     | 
| 3 | 
         
            +
            transformers
         
     | 
| 4 | 
         
            +
            addict
         
     | 
| 5 | 
         
            +
            yapf
         
     | 
| 6 | 
         
            +
            timm
         
     | 
| 7 | 
         
            +
            numpy
         
     | 
| 8 | 
         
            +
            opencv-python
         
     | 
| 9 | 
         
            +
            supervision>=0.22.0
         
     | 
| 10 | 
         
            +
            pillow
         
     |