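# OCR-to-flashcard demo: extract text from an image with PaddleOCR, reduce it
# to base-form vocabulary with a DeepSeek chat LLM, generate flashcards, and
# expose the OCR step through a Gradio chat UI.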
from paddleocr import PaddleOCR, draw_ocr
from PIL import Image
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Sample inputs for local testing: an image containing Korean text and a Korean
# sentence for the text pipeline (roughly: "Is it wrong that my way of expressing
# myself is clumsy? I'm a warm woman in a cold city. Can't I even say I simply
# like you? Honestly, I want to speak my mind.")
img = "input_data/ocr_input/korean1.jpg"
text = "ν‘œν˜„μ΄ μ„œνˆ° 것도 잘λͺ»μΈκ°€μš”. λ‚˜ μ°¨κ°€μš΄ λ„μ‹œμ— λ”°λœ»ν•œ μ—¬μž”λ°. κ·Έλƒ₯ μ’‹μ•„ν•œλ‹¨ 말도 μ•ˆ λ˜λŠ”κ°€μš”. μ†”μ§ν•˜κ²Œ λ‚œ λ§ν•˜κ³  μ‹Άμ–΄μš”"
model_id = "deepseek-ai/deepseek-llm-7b-chat"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
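# Note: device_map="auto" requires the accelerate package, and float16 weights
# assume a GPU; on a CPU-only machine, drop device_map and use float32.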
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)

def text_inference(text, language):
    system_prompt = (
        f"Given the following {language} text, convert each word into its base form. "
        "Remove all duplicates. Return the base-form words as a comma-separated list, "
        "and nothing else."
    )
    user_prompt = f"{system_prompt}\n\nText:\n{text}"

    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    output_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Parse response: take the last line and split on commas
    last_line = output_text.strip().split("\n")[-1]
    words = [w.strip() for w in last_line.split(",") if w.strip()]
    return words
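
# Caveat: the parsing in text_inference assumes the final line of the model's
# reply is the comma-separated list; chat models sometimes append commentary,
# which is why the prompt insists on "nothing else".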

def ocr_inference(img, lang):
    # CPU-only PaddleOCR with angle classification enabled for rotated text
    ocr = PaddleOCR(use_angle_cls=True, lang=lang, use_gpu=False)
    result = ocr.ocr(img, cls=True)[0]
    # Each detection is [box, (text, confidence)]; the image, boxes, and scores
    # are only needed if the results are visualized with draw_ocr
    image = Image.open(img).convert("RGB")
    boxes = [line[0] for line in result]
    txts = [line[1][0] for line in result]
    scores = [line[1][1] for line in result]
    return txts

def make_flashcards(words, language):
    system_prompt = (
        f"For each {language} word in the list, write a flashcard in this format: "
        "the word, then its definition, then an example sentence using the word, "
        "and then a translation of the example sentence."
    )
    user_prompt = f"{system_prompt}\n\nWords:\n{words}"

    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": user_prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=256)
    # Decode only the newly generated tokens, not the echoed prompt
    output_text = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Flashcards span multiple lines, so return the generated text as-is
    return output_text.strip()

# Example usage (uncomment to test the pipeline without the UI):
# print("OCR OUTPUT:", ocr_inference(img, "korean"))
# words = text_inference(text, "korean")
# print("TEXT INPUT:", text)
# print("WORD PARSING:", words)
# print("flashcard output:", make_flashcards(words, "korean"))

examples = [
    [{"text": "@RolmOCR OCR the Text in the Image", "files": ["rolm/1.jpeg"]}],
    [{"text": "@RolmOCR Explain the Ad in Detail", "files": ["examples/videoplayback.mp4"]}],
    [{"text": "@RolmOCR OCR the Image", "files": ["rolm/3.jpeg"]}],
    [{"text": "Extract as JSON table from the table", "files": ["examples/4.jpg"]}],
]
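
# Hedged sketch: gr.ChatInterface with multimodal=True passes the user turn as
# a dict with "text" and "files" keys, so ocr_inference cannot be used directly
# as fn. chat_ocr below is an assumed adapter; the language is hard-coded to
# "korean" here to match the sample inputs above.
def chat_ocr(message, history):
    texts = []
    for file_path in message.get("files", []):
        texts.extend(ocr_inference(file_path, "korean"))
    return "\n".join(texts) if texts else "Please attach an image to run OCR."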

demo = gr.ChatInterface(
    fn=chat_ocr,  # adapter for the multimodal chat message format (see above)
    description="# **Multimodal OCR `@RolmOCR and Default Qwen2VL OCR`**",
    examples=examples,
    textbox=gr.MultimodalTextbox(
        label="Query Input", 
        file_types=["image", "video"], 
        file_count="multiple", 
        placeholder="Use tag @RolmOCR for RolmOCR, or leave blank for default Qwen2VL OCR"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    cache_examples=False,
)

demo.launch(debug=True)