dan-durbin committed on
Commit
9aa5f5e
·
1 Parent(s): 9049c49

basic example now working

Browse files
Files changed (1) hide show
  1. app.py +84 -4
app.py CHANGED
@@ -1,7 +1,87 @@
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import torch
3
+ from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor
4
+ from PIL import Image, ImageDraw
5
+ import re
6
  import gradio as gr
7
 
8
# Model/processor configuration for Microsoft's Kosmos-2.5 document OCR model.
repo = "microsoft/kosmos-2.5"
# Fall back to CPU when CUDA is unavailable so the script can still start
# outside a GPU Space (generation will be slow, but it won't crash at load).
device = "cuda" if torch.cuda.is_available() else "cpu"

config = AutoConfig.from_pretrained(repo)
dtype = torch.float16  # half precision keeps GPU memory usage low

model = AutoModelForVision2Seq.from_pretrained(
    repo, device_map=device, torch_dtype=dtype, config=config
)

processor = AutoProcessor.from_pretrained(repo)

# Task prompt understood by Kosmos-2.5.
prompt = "<ocr>"  # Options are '<ocr>' and '<md>'
21
+
22
+
23
@spaces.GPU
def process_image(image_path):
    """Run Kosmos-2.5 on the image at *image_path*.

    Returns whatever ``postprocess`` produces — with the '<ocr>' prompt that
    is the annotated image and the extracted text.
    """
    source = Image.open(image_path)
    encoded = processor(text=prompt, images=source, return_tensors="pt")

    # The processor reports the model-side resolution; compare it against the
    # original image size to recover per-axis scale factors for the boxes.
    model_h = encoded.pop("height")
    model_w = encoded.pop("width")
    orig_w, orig_h = source.size
    h_scale = orig_h / model_h
    w_scale = orig_w / model_w

    # Move every tensor onto the target device (None entries pass through),
    # and cast the patches to the model's dtype.
    encoded = {
        key: (val.to(device) if val is not None else None)
        for key, val in encoded.items()
    }
    encoded["flattened_patches"] = encoded["flattened_patches"].to(dtype)

    token_ids = model.generate(**encoded, max_new_tokens=2048)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)[0]

    return postprocess(decoded, h_scale, w_scale, source)
40
+
41
+
42
def postprocess(y, scale_height, scale_width, original_image):
    """Convert raw generated text into Gradio outputs.

    Parameters:
        y: decoded model output (still containing the task prompt).
        scale_height / scale_width: factors mapping model-space box
            coordinates back to the original image resolution.
        original_image: the PIL image the model was run on.

    Returns:
        (image, text) — matching the Gradio outputs [Image, Textbox]:
        for '<ocr>' an annotated copy of the image plus the box/text dump,
        for '<md>' the untouched image plus the markdown text.
    """
    y = y.replace(prompt, "")

    if "<md>" in prompt:
        # BUG FIX: the Gradio outputs are [Image, Textbox], so the image must
        # come first. The original returned (text, image) here, swapped
        # relative to the '<ocr>' branch below.
        return original_image, y

    # Kosmos-2.5 interleaves text lines with bbox tags like
    # <bbox><x_12><y_34><x_56><y_78></bbox>.
    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    bboxs_raw = re.findall(pattern, y)

    lines = re.split(pattern, y)[1:]
    bboxs = [[int(j) for j in re.findall(r"\d+", i)] for i in bboxs_raw]

    info = ""

    # Draw on a copy so the caller's image is left untouched.
    image_with_boxes = original_image.copy()
    draw = ImageDraw.Draw(image_with_boxes)

    # zip (rather than indexing) tolerates a mismatched tag/line count.
    for line, box in zip(lines, bboxs):
        x0, y0, x1, y1 = box

        # Skip degenerate boxes (zero or negative width/height).
        if x0 >= x1 or y0 >= y1:
            continue

        x0 = int(x0 * scale_width)
        y0 = int(y0 * scale_height)
        x1 = int(x1 * scale_width)
        y1 = int(y1 * scale_height)
        # Four corners clockwise from top-left, then the recognized text.
        info += f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{line}\n"

        # Draw rectangle on the image
        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)

    return image_with_boxes, info
76
+
77
+
78
# Wire the OCR pipeline into a simple Gradio UI: one image in, an annotated
# image plus the recognized text out.
annotated_image = gr.Image(type="pil", label="Image with Bounding Boxes")
extracted_text = gr.Textbox(label="Extracted Text")

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=[annotated_image, extracted_text],
)

iface.launch()