File size: 3,518 Bytes
9453eac
5c3f634
4abc449
12a2f23
 
9453eac
c5a772e
 
 
54a29b3
c5a772e
 
6ecc4f4
 
c5a772e
12a2f23
768d260
12a2f23
c5a772e
3b15416
f774dbf
 
12a2f23
f774dbf
12a2f23
6ecc4f4
12a2f23
 
 
57fa964
12a2f23
 
f774dbf
12a2f23
 
 
 
 
 
57fa964
12a2f23
 
 
 
 
57fa964
12a2f23
57fa964
12a2f23
 
c5a772e
12a2f23
 
 
 
 
 
 
 
c5a772e
12a2f23
 
 
 
c5a772e
9453eac
 
c5a772e
12a2f23
 
279ab91
12a2f23
 
28a9f71
57fa964
12a2f23
 
 
57fa964
9453eac
 
4abc449
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import easyocr
import numpy as np
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

class OCRProcessor:
    def __init__(self):
        self.reader = easyocr.Reader(['fa'])

    def extract_text(self, image: np.ndarray) -> str:
        try:
            results = self.reader.readtext(image, detail=0, paragraph=True)
            return "\n".join(results) if results else ""
        except Exception as e:
            return f"خطا در پردازش OCR: {str(e)}"

class PersianQAModel:
    def __init__(self):
        model_name = "OmidSakaki/fa_qa_nlp_model"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        except Exception as e:
            raise RuntimeError(f"خطا در بارگذاری مدل پرسش و پاسخ: {str(e)}")

    def answer_question(self, context: str, question: str) -> str:
        if not context.strip() or not question.strip():
            return "متن یا سوال وارد نشده است."
        try:
            inputs = self.tokenizer.encode_plus(
                question, context, return_tensors='pt', truncation=True, max_length=512
            )
            input_ids = inputs["input_ids"].tolist()[0]
            outputs = self.model(**inputs)
            answer_start = torch.argmax(outputs.start_logits)
            answer_end = torch.argmax(outputs.end_logits) + 1
            answer = self.tokenizer.convert_tokens_to_string(
                self.tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end])
            )
            # حذف توکن‌های اضافی یا فاصله
            answer = answer.strip()
            if not answer or answer in ['[CLS]', '[SEP]', '[PAD]']:
                return "جوابی یافت نشد."
            return answer
        except Exception as e:
            return f"خطا در مدل پرسش و پاسخ: {str(e)}"

ocr_processor = OCRProcessor()
qa_model = PersianQAModel()

def pipeline(image, question):
    # استخراج متن از تصویر
    context = ocr_processor.extract_text(image)
    # پاسخ به سوال بر اساس متن
    answer = qa_model.answer_question(context, question)
    return context, answer

with gr.Blocks(title="استخراج متن و پاسخ به سوال از تصویر فارسی") as app:
    gr.Markdown("""
    # سیستم هوشمند پرسش و پاسخ از روی تصویر فارسی
    1. تصویر را بارگذاری کنید تا متن استخراج شود.
    2. سوال خود را به فارسی تایپ کنید.
    3. دکمه «پاسخ» را بزنید.
    """)
    with gr.Row():
        with gr.Column():
            img_input = gr.Image(label="تصویر ورودی", type="numpy")
            question_input = gr.Textbox(label="سوال شما به فارسی", placeholder="مثلاً: نویسنده این متن کیست؟", lines=1)
            process_btn = gr.Button("پاسخ")
        with gr.Column():
            context_output = gr.Textbox(label="متن استخراج شده", lines=10, max_lines=None, interactive=False)
            answer_output = gr.Textbox(label="پاسخ مدل", lines=3, max_lines=None, interactive=False)

    process_btn.click(
        fn=pipeline,
        inputs=[img_input, question_input],
        outputs=[context_output, answer_output]
    )

if __name__ == "__main__":
    app.launch()