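"""Gradio demo: vLLM text generation plus TrOCR handwriting OCR on CPU.

One tab generates text from a prompt with facebook/opt-125m served through
vLLM; the other extracts text from an uploaded image with
microsoft/trocr-small-handwritten. Requires gradio, transformers, vllm,
and Pillow to be installed.
"""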
import gradio as gr
from transformers import AutoProcessor, VisionEncoderDecoderModel
from vllm import LLM, SamplingParams
from PIL import Image

# The language model to serve with vLLM; vLLM loads the matching tokenizer
# internally, so no separate AutoTokenizer instance is needed here
model_name = "facebook/opt-125m"

# Initialize vLLM on CPU (requires a vLLM build with CPU support; the
# `device` argument is version-dependent)
vllm_model = LLM(model=model_name, tensor_parallel_size=1, device="cpu")

# Load the OCR model and processor
ocr_model_name = "microsoft/trocr-small-handwritten"
ocr_model = VisionEncoderDecoderModel.from_pretrained(ocr_model_name)
ocr_processor = AutoProcessor.from_pretrained(ocr_model_name)
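# Note: the TrOCR processor bundles an image feature extractor with a text
# tokenizer, so it handles both the preprocessing and the decoding in
# ocr_image() below.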

def generate_response(prompt, max_tokens, temperature, top_p):
    # Define sampling parameters (cast max_tokens to int in case the slider
    # delivers a float)
    sampling_params = SamplingParams(
        max_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
    )
    
    # Generate text with vLLM (a raw prompt string is accepted directly)
    output = vllm_model.generate(prompt, sampling_params)

    # Extract the generated text from the first request's first completion
    generated_text = output[0].outputs[0].text
    return generated_text

def ocr_image(image_path):
    # Guard against the button being clicked before an image is uploaded
    if image_path is None:
        return "Please upload an image first."

    # Open the image from the file path and normalize it to RGB
    image = Image.open(image_path).convert("RGB")
    
    # Preprocess the image for the OCR model
    pixel_values = ocr_processor(images=image, return_tensors="pt").pixel_values
    
    # Perform OCR inference
    outputs = ocr_model.generate(pixel_values)
    
    # Decode the generated tokens into text
    text = ocr_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return text

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# πŸš€ Hugging Face Integration with vLLM and OCR (CPU)")
    gr.Markdown("Upload an image to extract text using OCR or generate text using the vLLM integration.")

    with gr.Tab("Text Generation"):
        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                )
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                )
                submit_button = gr.Button("Generate")
            
            with gr.Column():
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False,
                )
        
        submit_button.click(
            generate_response,
            inputs=[prompt_input, max_tokens, temperature, top_p],
            outputs=output_text,
        )

    with gr.Tab("OCR"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Image",
                    type="filepath",  # Corrected type
                    image_mode="RGB",
                )
                ocr_submit_button = gr.Button("Extract Text")
            
            with gr.Column():
                ocr_output = gr.Textbox(
                    label="Extracted Text",
                    lines=10,
                    interactive=False,
                )
        
        ocr_submit_button.click(
            ocr_image,
            inputs=[image_input],
            outputs=ocr_output,
        )

# Launch the app
demo.launch()
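
# Gradio binds to http://127.0.0.1:7860 by default; passing
# server_name="0.0.0.0" to demo.launch() exposes the app on the local network.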