# %%
import matplotlib.style
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import pickle
import torch
from pathlib import Path
from PIL import Image
from PIL import ImageDraw
import numpy as np
from collections import namedtuple
from logging import getLogger
logger = getLogger(__name__)
# %%
class Florence:
    def __init__(self, model_id:str, hack=False):
        if hack:
            return
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_id, trust_remote_code=True, torch_dtype="auto"
            )
            .eval()
            .cuda()
        )
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.model_id = model_id
    def run(self, img:Image, task_prompt:str, extra_text:str|None=None):
        logger.debug(f"run {task_prompt} {extra_text}")
        model, processor = self.model, self.processor
        prompt = task_prompt + (extra_text if extra_text else "")
        inputs = processor(text=prompt, images=img, return_tensors="pt").to(
            "cuda", torch.float16
        )
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
            #temperature=0.1,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(img.width, img.height),
        )
        return parsed_answer
def model_init(hack=False):
    """Create the two Florence wrappers used by the pipeline.

    Returns a ``(large, large_ft)`` tuple built, in that order, from the
    ``microsoft/Florence-2-large`` and ``microsoft/Florence-2-large-ft``
    checkpoints. ``hack`` is forwarded unchanged to both constructors.
    """
    checkpoints = ("microsoft/Florence-2-large", "microsoft/Florence-2-large-ft")
    return tuple(Florence(name, hack=hack) for name in checkpoints)
#%%
# Florence-2 task tokens, passed as `task_prompt` to Florence.run
TASK_OD = "<OD>"
TASK_SEGMENTATION = '<REFERRING_EXPRESSION_SEGMENTATION>'
# NOTE(review): identical to TASK_GROUNDING below and not referenced in the
# visible code — possibly a copy/paste slip for a plain caption task; confirm.
TASK_CAPTION = "<CAPTION_TO_PHRASE_GROUNDING>"
TASK_OCR = "<OCR_WITH_REGION>"
TASK_GROUNDING = "<CAPTION_TO_PHRASE_GROUNDING>"
#%%
# Bundle of everything the models produced for one image:
# original image, cropped meter image, meter box, needle/circle polygons and
# the OCR output of both model variants.
AIModelResult = namedtuple(
    'AIModelResult',
    'img img2 meter_bbox needle_polygons circle_polygons ocr1 ocr2',
)
# Results memoised by image file name (filled by get_ai_model_result).
cached_results:dict[str, AIModelResult] = {}

#%%
def get_meter_bbox(fl:Florence, img:Image):
    """Ground the meter in ``img`` and return its bounding box.

    Runs the phrase-grounding task with the caption "a circular meter with
    white background" and asserts that the answer holds exactly one box whose
    label is 'a circular meter'.

    Returns:
        The single entry of the answer's ``'bboxes'`` list.

    Raises:
        AssertionError: if the answer does not have the expected shape.
    """
    answer = fl.run(img, TASK_GROUNDING, "a circular meter with white background")
    assert len(answer) == 1
    _, grounding = answer.popitem()
    for field in ('bboxes', 'labels'):
        assert field in grounding
    boxes, names = grounding['bboxes'], grounding['labels']
    assert len(boxes) == 1
    assert len(names) == 1
    assert names[0] == 'a circular meter'
    return boxes[0]

def get_circles(fl:Florence, img2:Image, polygons:list):
    """Segment "a circle" inside the area covered by ``polygons``.

    A mask is rasterised from ``polygons``; pixels of ``img2`` outside the
    mask are replaced by 255 and the masked image is sent to the segmentation
    task.

    Returns:
        The single entry of the answer's ``'polygons'`` list.

    Raises:
        AssertionError: if the answer does not contain exactly one entry.
    """
    logger.info('get_circle')
    mask = Image.new('L', img2.size, color = 'black')
    mask_pen = ImageDraw.Draw(mask)
    for poly in polygons:
        mask_pen.polygon(poly, outline='white', width=3, fill='white')
    # extra trailing axis lets the 2-D mask broadcast over the colour channels
    inside = np.array(mask)[:,:,None] > 0
    masked_pixels = np.where(inside, np.array(img2), 255)
    answer = fl.run(Image.fromarray(masked_pixels), TASK_SEGMENTATION, "a circle")
    assert len(answer) == 1
    _, segmentation = answer.popitem()
    assert 'polygons' in segmentation
    found = segmentation['polygons']
    assert len(found) == 1
    return found[0]

def get_needle_polygons(fl:Florence, img2:Image):
    """Segment the meter needle in ``img2``.

    Returns:
        The single entry of the segmentation answer's ``'polygons'`` list.

    Raises:
        AssertionError: if the answer does not contain exactly one entry.
    """
    # The prompt text is model input and is kept verbatim.
    needle_prompt = "the long narrow black needle hand pass through the center of the cicular meter"
    answer = fl.run(img2, TASK_SEGMENTATION, needle_prompt)
    assert len(answer) == 1
    _, segmentation = answer.popitem()
    assert 'polygons' in segmentation
    found = segmentation['polygons']
    assert len(found) == 1
    return found[0]

def get_ocr(fl:Florence, img2:Image):
    """Run the OCR-with-region task on ``img2`` and return its payload.

    The answer must hold exactly one entry; that entry's value is returned.
    """
    logger.info('get_ocr')
    answer = fl.run(img2, TASK_OCR)
    assert len(answer)==1
    _, ocr_payload = answer.popitem()
    return ocr_payload

def get_ai_model_result(img:Image.Image|Path|str, fl:Florence, fl_ft:Florence):
    """Run every model step for one image, reusing cached results by file name.

    Args:
        img: a PIL image, or a ``Path``/``str`` pointing at one. Only path
            inputs are cached; the cache key is the last path component.
        fl: wrapper used for grounding, segmentation and the first OCR pass.
        fl_ft: wrapper used for the second OCR pass.

    Returns:
        An ``AIModelResult``.
    """
    logger.info("get_ai_model_result")
    from_file = isinstance(img, (Path, str))
    key = None
    if isinstance(img, Path):
        key = img.parts[-1]
    elif from_file:
        key = img.split('/')[-1]

    if key is not None and key in cached_results:
        return cached_results[key]

    if from_file:
        img = Image.open(img)

    meter_bbox = get_meter_bbox(fl, img)
    logger.info("get meter_bbox")
    # every later step works on the crop around the meter
    img2 = img.crop(meter_bbox)
    needle_polygons = get_needle_polygons(fl, img2)
    logger.info("get needle_polygons")
    circle_polygons = get_circles(fl, img2, needle_polygons)
    ocr_first = get_ocr(fl, img2)
    ocr_second = get_ocr(fl_ft, img2)
    result = AIModelResult(
        img=img,
        img2=img2,
        meter_bbox=meter_bbox,
        needle_polygons=needle_polygons,
        circle_polygons=circle_polygons,
        ocr1=ocr_first,
        ocr2=ocr_second,
    )
    if key is not None:
        cached_results[key] = result
    return result
#%%
from skimage.measure import regionprops
from skimage.measure import EllipseModel
from skimage.draw import ellipse_perimeter
def get_regionprops(polygons:list) -> regionprops:
    """Rasterise ``polygons`` and return the first region's properties.

    The polygons are filled onto a black canvas just large enough to hold
    their largest coordinates (plus a 2-pixel margin); the resulting binary
    image is handed to ``skimage.measure.regionprops``.
    """
    logger.info('get_regionprops')
    points = np.concatenate(polygons).reshape(-1, 2)
    canvas_size = tuple((points.max(axis=0)+2).astype('int'))
    canvas = Image.new('L', canvas_size, color = 'black')
    pen = ImageDraw.Draw(canvas)
    for poly in polygons:
        pen.polygon(poly, outline='white', width=1, fill='white')
    # any non-zero pixel belongs to the single labelled region
    binary = (np.array(canvas)>0).astype(np.uint8)
    return regionprops(binary)[0]
def estimate_ellipse(coords, enlarge_factor=1.0):
    """Fit an ellipse to ``coords`` and rasterise its perimeter.

    The fit runs on ``coords`` with its two columns swapped, so the centre is
    reported in swapped column order as well.

    Args:
        coords: (N, 2) array of points.
        enlarge_factor: multiplier applied to both semi-axes before drawing.

    Returns:
        ``(em_params, theta, perimeter)``: ``em_params`` is the rounded
        integer array ``[centre0, centre1, a, b]``, ``theta`` the fitted
        rotation, and ``perimeter`` the pair of index arrays returned by
        ``ellipse_perimeter``.
    """
    model = EllipseModel()
    model.estimate(coords[:, ::-1])
    center_y, center_x, semi_a, semi_b, theta = model.params
    scaled = [center_y, center_x, semi_a*enlarge_factor, semi_b*enlarge_factor]
    em_params = np.round(scaled).astype('int')
    first_idx, second_idx = ellipse_perimeter(*em_params, orientation=-theta)
    return em_params, theta, (first_idx, second_idx)
def estimate_line(coords):
    """Fit a ``LineModelND`` to ``coords`` and return its ``params``.

    ``read_meter`` unpacks the result as ``(origin, direction)``.
    """
    model = LineModelND()
    model.estimate(coords)
    return model.params
#%%
#%%
from matplotlib import pyplot as plt
import matplotlib
from skimage.measure import LineModelND, ransac
# module-level side effect: switch matplotlib to its 'dark_background' style
matplotlib.style.use('dark_background')
def rotate_theta(theta):
    """Shift an angle by three quarters of a turn and express it in degrees.

    ``theta`` is in radians (scalar or array). The shifted angle is wrapped
    into a single turn before being rescaled to degrees.
    """
    full_turn = 2*np.pi
    shifted = theta + 3*np.pi/2
    wrapped = shifted % full_turn
    return wrapped/full_turn*360
# OCR strings accepted as kg/cm2 scale labels
kg_cm2_labels = [str(value) for value in (1, 3, 5, 7, 9, 11)]
# OCR strings accepted as psi scale labels: '20', '40', ..., '160'
psi_labels = [str(value) for value in range(20, 180, 20)]

# Everything read_meter computes, kept together so the visualization code can
# reuse the intermediate values (loose coupling between the two stages).
MeterResult = namedtuple(
    'MeterResult',
    # model output, needle reading and needle geometry
    'result needle_psi needle_kg_cm2 needle_theta orign direction center '
    # fitted angle -> psi line model and the points it accepted
    'lm inliers '
    # recognised scale labels: boxes, box centres, angles and values
    'kg_cm2_texts psi_texts kg_cm2_centers psi_centers '
    'kg_cm2_theta psi_theta kg_cm2_psi psi',
)

def read_meter(img:Image.Image|str|Path, fl, fl_ft):
    """Read the needle value of a pressure meter from an image.

    Steps: obtain the model results (possibly cached), fit a line through the
    needle polygon, take the centroid of the circle polygon as meter center,
    collect the OCR'd kg/cm2 and psi scale labels, express needle and labels
    as angles around the center, fit a linear angle -> psi model with RANSAC
    and evaluate it at the needle angle.

    Args:
        img: PIL image or path, forwarded to ``get_ai_model_result``.
        fl: Florence wrapper forwarded to ``get_ai_model_result``.
        fl_ft: second Florence wrapper forwarded to ``get_ai_model_result``.

    Returns:
        A ``MeterResult`` holding the reading (``needle_psi``,
        ``needle_kg_cm2``) together with all intermediate values.

    Raises:
        ValueError: if fewer than three scale labels were recognised, or if
            RANSAC could not fit a line to them.
    """
    # ai model results
    logger.info("read_meter")
    result = get_ai_model_result(img, fl, fl_ft)
    logger.info('ai model done')

    # needle direction
    coords = np.concatenate(result.needle_polygons).reshape(-1, 2)
    orign, direction = estimate_line(coords)
    logger.info('needle direction done')

    # calculate the meter center; centroid is reversed to match the (x, y)
    # order of the polygon coordinates
    circle_props = get_regionprops(result.circle_polygons)
    center = circle_props.centroid[::-1]

    # XXX: the needle direction is from center to orign
    if (orign - center) @ direction < 0:
        direction = -direction

    # calculate the needle theta
    needle_theta = rotate_theta(np.arctan2(direction[1], direction[0]))

    # go through the ocr texts of both models to find kg/cm2 and psi labels;
    # a label seen twice keeps the box that comes last
    ocr1, ocr2 = result.ocr1, result.ocr2
    kg_cm2_texts = {}
    psi_texts = {}
    quad_boxes = ocr1['quad_boxes']+ocr2['quad_boxes']
    labels = ocr1['labels']+ocr2['labels']
    for qbox, label in zip(quad_boxes, labels):
        if label in kg_cm2_labels:
            kg_cm2_texts[int(label)]=qbox
        if label in psi_labels:
            psi_texts[int(label)]=qbox
    # calculate the center of kg/cm2 and psi labels (mean of the 4 box corners)
    kg_cm2_centers = np.array(list(kg_cm2_texts.values())).reshape(-1, 4, 2).mean(axis=1)
    psi_centers = np.array(list(psi_texts.values())).reshape(-1, 4, 2).mean(axis=1)

    # convert kg/cm2 and psi labels to polar coordinates, origin is the center of the meter
    # the angle is in degree which is more intuitive
    kg_cm2_coords = kg_cm2_centers - center
    kg_cm2_theta = rotate_theta(np.arctan2(kg_cm2_coords[:, 1], kg_cm2_coords[:, 0]))
    psi_coords = psi_centers - center
    psi_theta = rotate_theta(np.arctan2(psi_coords[:, 1], psi_coords[:, 0]))

    # convert kg_cm2 to psi for fitting a line model
    kg_cm2 = np.array(list(kg_cm2_texts.keys()))
    kg_cm2_psi = kg_cm2 * 14.223
    # combine kg/cm2 and psi labels to fit a line model
    psi = np.array(list(psi_texts.keys()))
    Y = np.concatenate([kg_cm2_psi, psi])
    X = np.concatenate([kg_cm2_theta, psi_theta])
    data = np.stack([X, Y], axis=1)

    min_samples = 3
    if len(data) < min_samples:
        raise ValueError(
            f"need at least {min_samples} recognised scale labels to fit the "
            f"scale, got {len(data)}"
        )
    # run ransac to robustly fit a line model
    # NOTE(review): only 2 trials are run and the sampling is random, so the
    # fit may differ between runs — confirm this is intended.
    lm, inliers = ransac(data, LineModelND, min_samples=min_samples,
           residual_threshold=15,
           max_trials=2)
    # ransac hands back (None, None) when it finds no valid model
    if lm is None:
        raise ValueError("RANSAC could not fit an angle -> psi line to the scale labels")

    # use the model to calculated the needle psi and kg/cm2
    needle_psi = lm.predict(needle_theta)[1]
    needle_kg_cm2 = needle_psi / 14.223
    logger.info('meter result done')

    return MeterResult(result=result,
                          needle_psi=needle_psi,
                          needle_kg_cm2=needle_kg_cm2,
                          needle_theta=needle_theta,
                          orign=orign,
                          direction=direction,
                          center=center,
                          lm=lm,
                          inliers=data[inliers].T,
                          kg_cm2_texts=kg_cm2_texts,
                          psi_texts=psi_texts,
                          kg_cm2_centers=kg_cm2_centers,
                          psi_centers=psi_centers,
                          kg_cm2_theta=kg_cm2_theta,
                          psi_theta=psi_theta,
                          kg_cm2_psi=kg_cm2_psi,
                          psi=psi,
    )


def more_visualization_data(meter_result:MeterResult):
    """Derive the extra values needed only for plotting and debugging.

    Returns:
        ``(inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head,
        direction)`` where ``needle_head`` is the needle polygon point
        farthest from the meter center and ``direction`` is the needle
        direction scaled to that distance.
    """
    logger.info('more visualization')
    hub = meter_result.center

    # needle head = polygon vertex with the largest distance to the center
    needle_points = np.concatenate(meter_result.result.needle_polygons).reshape(-1, 2)
    distances = np.linalg.norm(needle_points - hub, axis=1)
    tip_index = np.argmax(distances)
    needle_head = needle_points[tip_index]
    direction = meter_result.direction * distances[tip_index]

    # points the line model accepted
    inlier_theta, inlier_psi = meter_result.inliers

    # model prediction sampled over the full circle, 0 to 360 degrees
    predict_theta = np.linspace(0, 360, 100)
    predict_psi = meter_result.lm.predict(predict_theta)[:, 1]
    return inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, direction

def _overlay_label_ellipse(canvas, draw, centers, outline, rgb):
    """Fit an ellipse to ``centers``, mark its center and paint its perimeter.

    Returns the new ``(canvas, draw)`` pair, because painting the perimeter
    goes through a numpy copy of the image.
    """
    em_params, _theta, (rows, cols) = estimate_ellipse(centers)
    # take the center from THIS fit before drawing its marker
    y, x = em_params[:2]
    draw.ellipse((x-5, y-5, x+5, y+5), outline=outline, width=1)
    pixels = np.array(canvas)
    height, width = pixels.shape[:2]
    # the fitted ellipse may leave the image: drop those perimeter points
    # instead of raising IndexError or wrapping around via negative indices
    inside = (rows >= 0) & (rows < height) & (cols >= 0) & (cols < width)
    pixels[rows[inside], cols[inside]] = rgb
    canvas = Image.fromarray(pixels)
    return canvas, ImageDraw.Draw(canvas)


def visualization(meter_result:MeterResult):
    """Render a debug overlay image and a theta/psi plot for ``meter_result``.

    The overlay is drawn on a copy of the cropped meter image: needle
    polygons (red), circle polygons (translucent fill), needle direction,
    center and needle head (yellow), the recognised kg/cm2 (blue) and psi
    (green) labels, an ellipse through each label set when it has more than
    four labels, and the reading as text. The plot is drawn on the current
    matplotlib figure.

    Returns:
        ``(overlay_image, figure)`` — the annotated PIL image and
        ``plt.gcf()``.
    """
    logger.info('visualization')
    result = meter_result.result
    center = meter_result.center
    needle_psi, needle_kg_cm2 = meter_result.needle_psi, meter_result.needle_kg_cm2
    inlier_theta, inlier_psi, predict_theta, predict_psi, needle_head, direction = more_visualization_data(meter_result)

    # draw needle polygons on a copy so the cached crop stays untouched
    canvas = result.img2.copy()
    draw = ImageDraw.Draw(canvas)
    for polygon in result.needle_polygons:
        draw.polygon(polygon, outline='red', width=3)

    # draw center circle on a transparent layer and blend it in
    base = canvas.convert('RGBA')
    overlay = Image.new('RGBA', base.size, (0,0,0,0))
    overlay_draw = ImageDraw.Draw(overlay)
    for polygon in result.circle_polygons:
        overlay_draw.polygon(polygon, outline='purple', width=1, fill = (255,128,255,100))
    canvas = Image.alpha_composite(base, overlay).convert('RGB')
    draw = ImageDraw.Draw(canvas)

    # draw needle direction
    draw.line((center[0], center[1], center[0]+direction[0], center[1]+direction[1]), fill='yellow', width=3)
    # draw a dot at center
    draw.ellipse((center[0]-5, center[1]-5, center[0]+5, center[1]+5), outline='yellow', width=3)
    # draw a dot at needle_head
    draw.ellipse((needle_head[0]-5, needle_head[1]-5, needle_head[0]+5, needle_head[1]+5), outline='yellow', width=3)

    for x,y in meter_result.kg_cm2_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='blue', width=3)
    for x,y in meter_result.psi_centers:
        draw.ellipse((x-3, y-3, x+3, y+3), outline='green', width=3)
    for label,quad_box in meter_result.kg_cm2_texts.items():
        draw.polygon(quad_box, outline='blue', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='blue', anchor='ls')
    for label,quad_box in meter_result.psi_texts.items():
        draw.polygon(quad_box, outline='green', width=1)
        draw.text((quad_box[0], quad_box[1]-10), str(label), fill='green', anchor='ls')

    if len(meter_result.kg_cm2_centers) >4:
        # the ellipse of kg/cm2 labels, currently only for visualization
        canvas, draw = _overlay_label_ellipse(
            canvas, draw, meter_result.kg_cm2_centers, 'blue', (0, 0, 255))

    if len(meter_result.psi_centers) >4:
        # the ellipse of psi labels, currently only for visualization
        canvas, draw = _overlay_label_ellipse(
            canvas, draw, meter_result.psi_centers, 'green', (0, 255, 0))

    draw.text((needle_head[0]-10, needle_head[1]-10),
              f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}',anchor='ls',
              fill='yellow')

    # theta (degrees) vs psi plot on the current matplotlib figure
    plt.plot(predict_theta, predict_psi, color='red', alpha=0.5)
    plt.plot(meter_result.kg_cm2_theta, meter_result.kg_cm2_psi, 'o', color='#77F')
    plt.plot(meter_result.psi_theta, meter_result.psi, 'o', color='#7F7')
    plt.plot(inlier_theta, inlier_psi, 'x', color='red', alpha=0.5)
    plt.vlines(meter_result.needle_theta, 0, 160, colors='yellow', alpha=0.5)
    plt.hlines(meter_result.needle_psi, 0, 360, colors='yellow', alpha=0.5)

    plt.text(meter_result.needle_theta-20, meter_result.needle_psi-20,
             f'psi={needle_psi:.1f} kg_cm2={needle_kg_cm2:.2f}', color='yellow')
    plt.xlim(0, 360)
    plt.ylim(0, 160)
    logger.info('visualization done')
    return canvas, plt.gcf()

def clear_cache():
    """Empty the module-level ``cached_results`` dict in place."""
    cached_results.clear()
def save_cache():
    """Pickle ``cached_results`` to ``cached_results.pkl`` in the working directory."""
    # context manager so the file is flushed and closed even if dump() fails
    with open('cached_results.pkl', 'wb') as cache_file:
        pickle.dump(cached_results, cache_file)
def load_cache():
    """Replace ``cached_results`` with the content of ``cached_results.pkl``.

    Raises whatever ``open``/``pickle.load`` raise, e.g. ``FileNotFoundError``
    when no cache file exists in the working directory.
    """
    global cached_results
    # NOTE: unpickling can execute arbitrary code; only load a cache file
    # that this program wrote itself.
    with open('cached_results.pkl', 'rb') as cache_file:
        cached_results = pickle.load(cache_file)
#%%
if __name__ == '__main__':
    # Demo driver: read every JPEG under images/good and display the overlay
    # image and the theta/psi plot for each one.
    from io import BytesIO
    from IPython.display import display

    fl, fl_ft = model_init(hack=False)
    # start from an empty cache; load_cache() could be called here instead
    clear_cache()
    W, H = 640, 480
    imgs = list(Path('images/good').glob('*.jpg'))
    for img_fn in imgs:
        print(img_fn)
        meter_result = read_meter(img_fn, fl, fl_ft)
        img, fig = visualization(meter_result)

        # scale the overlay to fit inside WxH while keeping the aspect ratio
        w, h = meter_result.result.img2.size
        w, h = (W, int(h*W/w)) if w/W > h/H else (int(w*H/h), H)
        display(img.resize((w, h)))

        # render the figure to an in-memory PNG and show it as a PIL image
        buf = BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        fig_img = Image.open(buf)
        display(fig_img)

        # reset the shared figure before the next image
        plt.clf()






# %%