Skip to content
Snippets Groups Projects

FLAIT: Detector API test script

  • Clone with SSH
  • Clone with HTTPS
  • Embed
  • Share
    The snippet can be accessed without any authentication.
    Authored by Thomas Johansen
    Edited
    test_serving_api.py 2.24 KiB
    from __future__ import annotations
    
    import argparse
    import json
    import typing
    from pathlib import Path
    from typing import Any, Literal, TypedDict
    
    import requests
    from PIL import Image, ImageDraw
    
    
    def main(args: argparse.Namespace) -> None:
        """Send one image to the detector's ``/predict`` endpoint and print the result.

        Args:
            args: Parsed CLI arguments. Reads ``host``, ``port``, ``image_path``
                and ``visualize``.

        The content type is resolved by ``resolve_content_type`` before any
        request is made, so unsupported extensions fail without network I/O.
        The JSON response is pretty-printed to stdout. When ``args.visualize``
        is set, the predictions are also drawn on the image.

        Raises:
            requests.HTTPError: if the API answers with a 4xx/5xx status.
        """
        api_endpoint = f"http://{args.host}:{args.port}/predict"

        content_type = resolve_content_type(args.image_path)

        # Keep the upload handle in a context manager so it is closed even if
        # the request raises (connection refused, timeout, ...).
        with open(args.image_path, "rb") as image_file:
            files = {
                "file": (args.image_path.name, image_file, content_type),
            }
            response = requests.post(api_endpoint, files=files)
        response.raise_for_status()

        predictions = response.json()
        print(json.dumps(predictions, indent=2))

        if args.visualize:
            visualize_predictions(args.image_path, predictions)
    
    
    def resolve_content_type(image_path: Path) -> str:
        """Return the MIME type to upload ``image_path`` with, based on its suffix.

        The suffix is matched case-insensitively, so ``photo.JPG`` is accepted.
        Both ``.jpg`` and ``.jpeg`` map to ``image/jpeg``.

        Raises:
            ValueError: if the file extension is not a supported image type.
        """
        content_types = {
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".png": "image/png",
        }
        suffix = image_path.suffix.lower()
        try:
            return content_types[suffix]
        except KeyError:
            supported = ", ".join(sorted(content_types))
            # A bare KeyError('.gif') gives the user no hint about what is allowed.
            raise ValueError(
                f"Unsupported image extension {suffix!r} for {image_path}; "
                f"expected one of: {supported}"
            ) from None
    
    
    # Allowed values for Relationship.predicate. The runtime tuple below is
    # derived from the Literal, so the type and the value list cannot drift apart.
    Predicate = Literal["connected to", "belongs to"]
    PREDICATES: tuple[Predicate, ...] = tuple(typing.get_args(Predicate))
    
    
    class CocoAnnotation(TypedDict):
        """One prediction as returned by the ``/predict`` endpoint.

        Field names follow the COCO object-annotation format, with two extra
        fields: ``confidence`` and ``relationships``.
        NOTE(review): schema is inferred from this script's usage; confirm it
        against the serving API.
        """

        id: int  # presumably the prediction id referenced by Relationship.subject/object — confirm
        image_id: int
        category_id: int  # printed as the box label by visualize_predictions
        iscrowd: int
        segmentation: list[Any]  # not used here; assumes COCO polygon/RLE data — TODO confirm
        area: float
        bbox: list[float]  # [x, y, width, height]: visualize_predictions adds [2]/[3] to [0]/[1]
        confidence: float  # detector score, compared against confidence_threshold
        relationships: list[Relationship]  # forward ref is fine thanks to `from __future__ import annotations`
    
    
    class Relationship(TypedDict):
        """A directed relation ``subject --predicate--> object`` between two predictions."""

        subject: int  # presumably the ``id`` of a CocoAnnotation — verify against the API
        predicate: Predicate  # one of the strings in PREDICATES
        object: int  # presumably the ``id`` of a CocoAnnotation — verify against the API
    
    
    def visualize_predictions(
        image_path: Path,
        predictions: list[CocoAnnotation],
        confidence_threshold: float = 0.5,
    ) -> None:
        """Draw predicted boxes on the image and open it in the system image viewer.

        Args:
            image_path: The image that was sent to the detector.
            predictions: Annotations returned by the ``/predict`` endpoint; each
                ``bbox`` is read as ``[x, y, width, height]``.
            confidence_threshold: Predictions whose ``confidence`` is below this
                value are skipped.

        Each kept prediction is drawn as a red rectangle labelled
        ``"<category_id> (<confidence>)"`` at its top-left corner.
        """
        # Image.open keeps the file handle open until the image is closed; the
        # context manager releases it. convert() returns an independent copy,
        # so ``image`` stays usable after the source is closed. Converting to
        # RGB also makes "red" really red: on grayscale ('L'/'1') images the
        # colour would otherwise collapse to a grey level.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        draw = ImageDraw.Draw(image)

        for prediction in predictions:
            score = prediction["confidence"]
            if score < confidence_threshold:
                continue

            label = prediction["category_id"]

            # COCO-style [x, y, w, h] -> PIL's (x0, y0, x1, y1) corner box.
            x, y, width, height = prediction["bbox"]
            box = (x, y, x + width, y + height)

            draw.rectangle(box, outline="red")
            draw.text((x, y), f"{label} ({score:.2f})", fill="red")

        image.show()
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(
            description="Send an image to the detector serving API and print its predictions.",
        )
        parser.add_argument(
            "--host",
            type=str,
            default="localhost",
            help="API host (default: %(default)s)",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=3000,
            help="API port (default: %(default)s)",
        )
        parser.add_argument(
            "--visualize",
            action="store_true",
            help="display the image with the predicted boxes drawn on it",
        )
        parser.add_argument(
            "image_path",
            type=Path,
            help="path to the image to send (JPEG or PNG)",
        )

        cli_args = parser.parse_args()
        # Report a missing file as a usage error instead of a traceback from
        # open() deep inside main().
        if not cli_args.image_path.is_file():
            parser.error(f"image file not found: {cli_args.image_path}")

        main(cli_args)
    0% Loading or .
    You are about to add 0 people to the discussion. Proceed with caution.
    Please register or sign in to comment