Spaces:
Sleeping
Sleeping
ThierryLapouge
commited on
Commit
·
a1eba61
1
Parent(s):
7ccff21
qualiclip+ model
Browse files- .gitignore +62 -0
- app.py +209 -7
- core.py +263 -16
- requirements.txt +8 -1
.gitignore
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
*.so
|
6 |
+
.Python
|
7 |
+
build/
|
8 |
+
develop-eggs/
|
9 |
+
dist/
|
10 |
+
downloads/
|
11 |
+
eggs/
|
12 |
+
.eggs/
|
13 |
+
lib/
|
14 |
+
lib64/
|
15 |
+
parts/
|
16 |
+
sdist/
|
17 |
+
var/
|
18 |
+
wheels/
|
19 |
+
share/python-wheels/
|
20 |
+
*.egg-info/
|
21 |
+
.installed.cfg
|
22 |
+
*.egg
|
23 |
+
MANIFEST
|
24 |
+
|
25 |
+
# Virtual Environment
|
26 |
+
venv/
|
27 |
+
env/
|
28 |
+
ENV/
|
29 |
+
env.bak/
|
30 |
+
venv.bak/
|
31 |
+
|
32 |
+
# IDE
|
33 |
+
.vscode/
|
34 |
+
.idea/
|
35 |
+
*.swp
|
36 |
+
*.swo
|
37 |
+
*~
|
38 |
+
|
39 |
+
# Gradio
|
40 |
+
.gradio/
|
41 |
+
|
42 |
+
# Development files
|
43 |
+
CLAUDE.md
|
44 |
+
model_list.py
|
45 |
+
list of models.txt
|
46 |
+
test_*.py
|
47 |
+
|
48 |
+
# OS generated files
|
49 |
+
.DS_Store
|
50 |
+
.DS_Store?
|
51 |
+
._*
|
52 |
+
.Spotlight-V100
|
53 |
+
.Trashes
|
54 |
+
ehthumbs.db
|
55 |
+
Thumbs.db
|
56 |
+
|
57 |
+
# Logs
|
58 |
+
*.log
|
59 |
+
|
60 |
+
# Cache
|
61 |
+
.cache/
|
62 |
+
*.tmp
|
app.py
CHANGED
@@ -1,10 +1,212 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
import
|
3 |
-
|
4 |
-
metric = IQA("nima-vgg16-ava")
|
5 |
|
6 |
-
|
7 |
-
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py - Enhanced with Text Quality Assessment
|
2 |
import gradio as gr
|
3 |
+
from core import IQA, TextAwareIQA
|
4 |
+
import numpy as np
|
|
|
5 |
|
6 |
+
# Initialize both simple and advanced models
|
7 |
+
simple_metric = IQA("qualiclip+")
|
8 |
+
advanced_metric = TextAwareIQA("qualiclip+", text_weight=0.3)
|
9 |
|
10 |
+
def simple_assessment(image):
|
11 |
+
"""Original simple interface - backward compatible"""
|
12 |
+
score = simple_metric(image)
|
13 |
+
print(f"Simple score: {score}") # Debug output
|
14 |
+
return f"Quality Score: {score:.2f}/100" if score is not None else "Error: Invalid image"
|
15 |
+
|
16 |
+
def detailed_assessment(image, text_penalty_mode="balanced"):
|
17 |
+
"""Enhanced assessment with detailed breakdown"""
|
18 |
+
try:
|
19 |
+
# Get detailed analysis
|
20 |
+
details = advanced_metric.evaluate(image, text_penalty_mode=text_penalty_mode)
|
21 |
+
|
22 |
+
if details is None or 'error' in details:
|
23 |
+
return "Error: Could not analyze image", "", ""
|
24 |
+
|
25 |
+
# Format main result
|
26 |
+
main_result = f"**Combined Quality Score: {details['combined_score']:.2f}/100**"
|
27 |
+
|
28 |
+
# Format breakdown
|
29 |
+
breakdown = f"""
|
30 |
+
## Score Breakdown:
|
31 |
+
- **Traditional IQA ({details['model_used']}): {details['traditional_score']:.2f}/100**
|
32 |
+
- **Text Quality: {details['text_analysis']['text_quality_score']:.2f}/100**
|
33 |
+
- **Text Weight in Final Score: {details['text_weight']*100:.0f}%**
|
34 |
+
- **Penalty Mode: {details.get('penalty_mode', 'balanced')}**
|
35 |
+
"""
|
36 |
+
|
37 |
+
# Format text analysis details
|
38 |
+
text_analysis = details['text_analysis']
|
39 |
+
if text_analysis['text_detected']:
|
40 |
+
text_details = f"""
|
41 |
+
## Text Analysis:
|
42 |
+
- **Text Regions Detected: {text_analysis['text_regions']}**
|
43 |
+
- **Low Quality Text Regions: {text_analysis['low_quality_regions']}**
|
44 |
+
- **Average OCR Confidence: {text_analysis['avg_confidence']:.3f}**
|
45 |
+
- **Assessment: {text_analysis['details']}**
|
46 |
+
|
47 |
+
### Text Quality Interpretation:
|
48 |
+
- **90-100**: Excellent text clarity, no distortions
|
49 |
+
- **70-89**: Good text quality, minor artifacts
|
50 |
+
- **50-69**: Readable but noticeable distortions
|
51 |
+
- **30-49**: Poor text quality, significant distortions
|
52 |
+
- **0-29**: Severely distorted or unreadable text
|
53 |
+
"""
|
54 |
+
else:
|
55 |
+
text_details = """
|
56 |
+
## Text Analysis:
|
57 |
+
- **No text detected in image**
|
58 |
+
- **Using traditional IQA score only**
|
59 |
+
"""
|
60 |
+
|
61 |
+
return main_result, breakdown, text_details
|
62 |
+
|
63 |
+
except Exception as e:
|
64 |
+
error_msg = f"Error: {str(e)}"
|
65 |
+
return error_msg, "", ""
|
66 |
+
|
67 |
+
def batch_comparison(image):
|
68 |
+
"""Compare different penalty modes side by side"""
|
69 |
+
try:
|
70 |
+
modes = ['lenient', 'balanced', 'strict']
|
71 |
+
results = []
|
72 |
+
|
73 |
+
for mode in modes:
|
74 |
+
details = advanced_metric.evaluate(image, text_penalty_mode=mode)
|
75 |
+
if details and 'error' not in details:
|
76 |
+
results.append({
|
77 |
+
'mode': mode.title(),
|
78 |
+
'score': details['combined_score'],
|
79 |
+
'traditional': details['traditional_score'],
|
80 |
+
'text': details['text_analysis']['text_quality_score'],
|
81 |
+
'text_detected': details['text_analysis']['text_detected']
|
82 |
+
})
|
83 |
+
|
84 |
+
if not results:
|
85 |
+
return "Error: Could not analyze image"
|
86 |
+
|
87 |
+
# Format comparison
|
88 |
+
comparison = "## Penalty Mode Comparison:\n\n"
|
89 |
+
for result in results:
|
90 |
+
comparison += f"**{result['mode']} Mode:**\n"
|
91 |
+
comparison += f"- Combined Score: {result['score']:.2f}/100\n"
|
92 |
+
comparison += f"- Traditional IQA: {result['traditional']:.2f}/100\n"
|
93 |
+
comparison += f"- Text Quality: {result['text']:.2f}/100\n\n"
|
94 |
+
|
95 |
+
if results[0]['text_detected']:
|
96 |
+
comparison += """
|
97 |
+
### Mode Explanations:
|
98 |
+
- **Lenient**: Minimal penalty for text issues (10% weight)
|
99 |
+
- **Balanced**: Moderate penalty for text issues (30% weight)
|
100 |
+
- **Strict**: Heavy penalty for text issues (50% weight)
|
101 |
+
"""
|
102 |
+
else:
|
103 |
+
comparison += "\n*No text detected - all modes return same score*"
|
104 |
+
|
105 |
+
return comparison
|
106 |
+
|
107 |
+
except Exception as e:
|
108 |
+
return f"Error in batch comparison: {str(e)}"
|
109 |
+
|
110 |
+
# Create Gradio interface with tabs
|
111 |
+
with gr.Blocks(title="Enhanced Image Quality Assessment", theme=gr.themes.Soft()) as demo:
|
112 |
+
gr.Markdown("""
|
113 |
+
# 🔍 Enhanced Image Quality Assessment
|
114 |
+
|
115 |
+
This tool evaluates image quality using both traditional perceptual metrics **AND** text-specific quality assessment.
|
116 |
+
Perfect for detecting letter distortions in AI-generated images that traditional IQA tools miss.
|
117 |
+
""")
|
118 |
+
|
119 |
+
with gr.Tabs():
|
120 |
+
# Simple tab - backward compatible
|
121 |
+
with gr.TabItem("🚀 Quick Assessment"):
|
122 |
+
gr.Markdown("### Simple quality score (backward compatible with your original setup)")
|
123 |
+
with gr.Row():
|
124 |
+
with gr.Column():
|
125 |
+
input_image_simple = gr.Image(
|
126 |
+
label="Upload Image",
|
127 |
+
type="pil",
|
128 |
+
format="png"
|
129 |
+
)
|
130 |
+
assess_btn_simple = gr.Button("Assess Quality", variant="primary")
|
131 |
+
with gr.Column():
|
132 |
+
output_simple = gr.Textbox(
|
133 |
+
label="Quality Score",
|
134 |
+
lines=2
|
135 |
+
)
|
136 |
+
|
137 |
+
assess_btn_simple.click(
|
138 |
+
simple_assessment,
|
139 |
+
inputs=input_image_simple,
|
140 |
+
outputs=output_simple
|
141 |
+
)
|
142 |
+
|
143 |
+
# Detailed tab - new enhanced features
|
144 |
+
with gr.TabItem("🔬 Detailed Analysis"):
|
145 |
+
gr.Markdown("### Comprehensive quality assessment with text-specific evaluation")
|
146 |
+
with gr.Row():
|
147 |
+
with gr.Column():
|
148 |
+
input_image_detailed = gr.Image(
|
149 |
+
label="Upload Image",
|
150 |
+
type="pil",
|
151 |
+
format="png"
|
152 |
+
)
|
153 |
+
penalty_mode = gr.Radio(
|
154 |
+
choices=["lenient", "balanced", "strict"],
|
155 |
+
value="balanced",
|
156 |
+
label="Text Penalty Mode",
|
157 |
+
info="How harshly to penalize text quality issues"
|
158 |
+
)
|
159 |
+
assess_btn_detailed = gr.Button("Detailed Analysis", variant="primary")
|
160 |
+
|
161 |
+
with gr.Column():
|
162 |
+
output_main = gr.Markdown(label="Final Score")
|
163 |
+
output_breakdown = gr.Markdown(label="Score Breakdown")
|
164 |
+
output_text_details = gr.Markdown(label="Text Analysis")
|
165 |
+
|
166 |
+
assess_btn_detailed.click(
|
167 |
+
detailed_assessment,
|
168 |
+
inputs=[input_image_detailed, penalty_mode],
|
169 |
+
outputs=[output_main, output_breakdown, output_text_details]
|
170 |
+
)
|
171 |
+
|
172 |
+
# Comparison tab
|
173 |
+
with gr.TabItem("⚖️ Mode Comparison"):
|
174 |
+
gr.Markdown("### Compare how different penalty modes affect the final score")
|
175 |
+
with gr.Row():
|
176 |
+
with gr.Column():
|
177 |
+
input_image_comparison = gr.Image(
|
178 |
+
label="Upload Image",
|
179 |
+
type="pil",
|
180 |
+
format="png"
|
181 |
+
)
|
182 |
+
compare_btn = gr.Button("Compare All Modes", variant="primary")
|
183 |
+
|
184 |
+
with gr.Column():
|
185 |
+
output_comparison = gr.Markdown(label="Comparison Results")
|
186 |
+
|
187 |
+
compare_btn.click(
|
188 |
+
batch_comparison,
|
189 |
+
inputs=input_image_comparison,
|
190 |
+
outputs=output_comparison
|
191 |
+
)
|
192 |
+
|
193 |
+
gr.Markdown("""
|
194 |
+
---
|
195 |
+
### 💡 How It Works:
|
196 |
+
- **Traditional IQA**: Uses qalign (or your chosen model) for overall perceptual quality
|
197 |
+
- **Text Quality**: Uses OCR confidence scores to detect letter distortions and rendering issues
|
198 |
+
- **Combined Score**: Weighted combination that penalizes text quality problems
|
199 |
+
|
200 |
+
### 🎯 Perfect For:
|
201 |
+
- Detecting letter distortions in AI-generated images
|
202 |
+
- Evaluating text rendering quality in synthetic images
|
203 |
+
- Getting more accurate quality assessments for images containing text
|
204 |
+
|
205 |
+
### ⚙️ Penalty Modes:
|
206 |
+
- **Lenient**: Light penalty for text issues (good for artistic images)
|
207 |
+
- **Balanced**: Moderate penalty (recommended for most use cases)
|
208 |
+
- **Strict**: Heavy penalty for any text problems (best for document quality)
|
209 |
+
""")
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
demo.launch(share=True)
|
core.py
CHANGED
@@ -1,23 +1,270 @@
|
|
1 |
-
|
2 |
import pyiqa
|
3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
class
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
self.model = pyiqa.create_metric(model_name, device=device)
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
if __name__ == "__main__":
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
image_files = glob.glob("samples/*")
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# core.py - Enhanced with Text Quality AssessmentR
|
2 |
import pyiqa
|
3 |
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import glob
|
6 |
+
import logging
|
7 |
+
import numpy as np
|
8 |
+
import cv2
|
9 |
+
import easyocr
|
10 |
+
from typing import Dict, List, Tuple, Optional
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
|
14 |
+
class TextQualityAssessor:
|
15 |
+
"""Specialized text quality assessment using OCR confidence scores"""
|
16 |
+
|
17 |
+
def __init__(self):
|
18 |
+
self.ocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
|
19 |
+
|
20 |
+
def assess_text_quality(self, image: Image.Image) -> Dict:
|
21 |
+
"""Assess text quality using OCR confidence and detection metrics"""
|
22 |
+
try:
|
23 |
+
# Convert PIL to OpenCV format
|
24 |
+
cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
25 |
+
|
26 |
+
# Perform OCR with confidence scores
|
27 |
+
results = self.ocr_reader.readtext(cv_image, detail=1)
|
28 |
+
|
29 |
+
if not results:
|
30 |
+
return {
|
31 |
+
'text_detected': False,
|
32 |
+
'text_quality_score': 100.0, # No text = no text quality issues
|
33 |
+
'avg_confidence': 1.0,
|
34 |
+
'text_regions': 0,
|
35 |
+
'low_quality_regions': 0,
|
36 |
+
'details': "No text detected"
|
37 |
+
}
|
38 |
+
|
39 |
+
confidences = [result[2] for result in results]
|
40 |
+
avg_confidence = np.mean(confidences)
|
41 |
+
|
42 |
+
# Count low quality text regions (confidence < 0.8)
|
43 |
+
low_quality_threshold = 0.8
|
44 |
+
low_quality_regions = sum(1 for conf in confidences if conf < low_quality_threshold)
|
45 |
+
|
46 |
+
# Calculate text quality score based on confidence distribution
|
47 |
+
# Higher penalties for very low confidence text
|
48 |
+
quality_penalties = []
|
49 |
+
for conf in confidences:
|
50 |
+
if conf >= 0.9:
|
51 |
+
quality_penalties.append(0) # Excellent text
|
52 |
+
elif conf >= 0.8:
|
53 |
+
quality_penalties.append(5) # Good text
|
54 |
+
elif conf >= 0.6:
|
55 |
+
quality_penalties.append(15) # Readable but poor quality
|
56 |
+
elif conf >= 0.4:
|
57 |
+
quality_penalties.append(30) # Heavily distorted
|
58 |
+
else:
|
59 |
+
quality_penalties.append(50) # Severely distorted/unreadable
|
60 |
+
|
61 |
+
avg_penalty = np.mean(quality_penalties) if quality_penalties else 0
|
62 |
+
text_quality_score = max(0, 100 - avg_penalty)
|
63 |
+
|
64 |
+
# Additional penalty for high proportion of low-quality regions
|
65 |
+
if len(confidences) > 0:
|
66 |
+
low_quality_ratio = low_quality_regions / len(confidences)
|
67 |
+
if low_quality_ratio > 0.5: # More than half regions are poor quality
|
68 |
+
text_quality_score *= 0.7 # 30% additional penalty
|
69 |
+
|
70 |
+
return {
|
71 |
+
'text_detected': True,
|
72 |
+
'text_quality_score': text_quality_score,
|
73 |
+
'avg_confidence': avg_confidence,
|
74 |
+
'text_regions': len(results),
|
75 |
+
'low_quality_regions': low_quality_regions,
|
76 |
+
'details': f"Detected {len(results)} text regions, avg confidence: {avg_confidence:.3f}"
|
77 |
+
}
|
78 |
+
|
79 |
+
except Exception as e:
|
80 |
+
logging.error(f"Text quality assessment error: {str(e)}")
|
81 |
+
return {
|
82 |
+
'text_detected': False,
|
83 |
+
'text_quality_score': 50.0, # Neutral score on error
|
84 |
+
'avg_confidence': 0.0,
|
85 |
+
'text_regions': 0,
|
86 |
+
'low_quality_regions': 0,
|
87 |
+
'details': f"Error: {str(e)}"
|
88 |
+
}
|
89 |
+
|
90 |
+
class HybridIQA:
|
91 |
+
"""Enhanced IQA with text-specific quality assessment"""
|
92 |
+
|
93 |
+
def __init__(self, model_name="qualiclip+", text_weight=0.3):
|
94 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
95 |
self.model = pyiqa.create_metric(model_name, device=device)
|
96 |
+
self.text_assessor = TextQualityAssessor()
|
97 |
+
self.text_weight = text_weight # Weight for text quality in final score
|
98 |
+
self.model_name = model_name
|
99 |
+
|
100 |
+
logging.basicConfig(level=logging.INFO)
|
101 |
+
self.logger = logging.getLogger(__name__)
|
102 |
+
self.logger.info(f"Hybrid IQA loaded: {model_name} + Text Quality Assessment on {device}")
|
103 |
+
|
104 |
+
def __call__(self, image, return_details=False):
|
105 |
+
"""
|
106 |
+
Evaluate image quality with both traditional IQA and text-specific assessment
|
107 |
+
|
108 |
+
Args:
|
109 |
+
image: PIL Image or path to image
|
110 |
+
return_details: If True, return detailed breakdown
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
If return_details=False: Combined quality score (0-100)
|
114 |
+
If return_details=True: Dict with detailed scores and analysis
|
115 |
+
"""
|
116 |
+
try:
|
117 |
+
# Ensure image is PIL Image
|
118 |
+
if not isinstance(image, Image.Image):
|
119 |
+
image = Image.open(image).convert("RGB")
|
120 |
+
else:
|
121 |
+
image = image.convert("RGB")
|
122 |
+
|
123 |
+
# Get traditional IQA score
|
124 |
+
# Get traditional IQA score
|
125 |
+
if self.model_name == 'qalign':
|
126 |
+
# Q-Align has special interface for quality assessment
|
127 |
+
traditional_score = self.model(image, task_='quality')
|
128 |
+
else:
|
129 |
+
traditional_score = self.model(image)
|
130 |
+
|
131 |
+
if hasattr(traditional_score, 'item'):
|
132 |
+
traditional_score = traditional_score.item()
|
133 |
+
|
134 |
+
# Normalize traditional score to 0-100 range
|
135 |
+
if 0 <= traditional_score <= 1:
|
136 |
+
traditional_score *= 100
|
137 |
+
|
138 |
+
# Get text quality assessment
|
139 |
+
text_analysis = self.text_assessor.assess_text_quality(image)
|
140 |
+
|
141 |
+
# Calculate combined score
|
142 |
+
if text_analysis['text_detected']:
|
143 |
+
# If text is detected, combine scores
|
144 |
+
combined_score = (
|
145 |
+
(1 - self.text_weight) * traditional_score +
|
146 |
+
self.text_weight * text_analysis['text_quality_score']
|
147 |
+
)
|
148 |
+
|
149 |
+
# Apply additional penalty if text quality is very poor
|
150 |
+
if text_analysis['text_quality_score'] < 30:
|
151 |
+
combined_score *= 0.8 # 20% additional penalty for severely poor text
|
152 |
+
|
153 |
+
else:
|
154 |
+
# No text detected, use traditional score only
|
155 |
+
combined_score = traditional_score
|
156 |
+
|
157 |
+
if return_details:
|
158 |
+
return {
|
159 |
+
'combined_score': combined_score,
|
160 |
+
'traditional_score': traditional_score,
|
161 |
+
'text_analysis': text_analysis,
|
162 |
+
'model_used': self.model_name,
|
163 |
+
'text_weight': self.text_weight
|
164 |
+
}
|
165 |
+
else:
|
166 |
+
return combined_score
|
167 |
+
|
168 |
+
except Exception as e:
|
169 |
+
self.logger.error(f"Error processing image: {str(e)}")
|
170 |
+
return None if not return_details else {'error': str(e)}
|
171 |
+
|
172 |
+
# Backward compatibility - maintain original IQA interface
|
173 |
+
class IQA(HybridIQA):
|
174 |
+
"""Backward compatible IQA class with enhanced text assessment"""
|
175 |
+
|
176 |
+
def __init__(self, model_name="qualiclip+"):
|
177 |
+
super().__init__(model_name, text_weight=0.3)
|
178 |
+
|
179 |
+
def __call__(self, image):
|
180 |
+
"""Maintain original interface - returns single score"""
|
181 |
+
return super().__call__(image, return_details=False)
|
182 |
+
|
183 |
+
def detailed_analysis(self, image):
|
184 |
+
"""New method for detailed analysis"""
|
185 |
+
return super().__call__(image, return_details=True)
|
186 |
+
|
187 |
+
# Advanced usage class for power users
|
188 |
+
class TextAwareIQA:
|
189 |
+
"""Advanced interface with configurable text assessment parameters"""
|
190 |
+
|
191 |
+
def __init__(self, model_name="qualiclip+", text_weight=0.3, text_threshold=0.8):
|
192 |
+
self.hybrid_iqa = HybridIQA(model_name, text_weight)
|
193 |
+
self.text_threshold = text_threshold
|
194 |
+
|
195 |
+
def evaluate(self, image, text_penalty_mode='balanced'):
|
196 |
+
"""
|
197 |
+
Evaluate with different text penalty modes
|
198 |
+
|
199 |
+
Args:
|
200 |
+
image: PIL Image or path
|
201 |
+
text_penalty_mode: 'strict', 'balanced', or 'lenient'
|
202 |
+
"""
|
203 |
+
details = self.hybrid_iqa(image, return_details=True)
|
204 |
+
|
205 |
+
if details is None or 'error' in details:
|
206 |
+
return details
|
207 |
+
|
208 |
+
# Adjust text penalties based on mode
|
209 |
+
if details['text_analysis']['text_detected']:
|
210 |
+
text_score = details['text_analysis']['text_quality_score']
|
211 |
+
traditional_score = details['traditional_score']
|
212 |
+
|
213 |
+
if text_penalty_mode == 'strict':
|
214 |
+
# Heavily penalize any text quality issues
|
215 |
+
weight = 0.5
|
216 |
+
if text_score < 70:
|
217 |
+
text_score *= 0.6
|
218 |
+
elif text_penalty_mode == 'lenient':
|
219 |
+
# Only penalize severe text issues
|
220 |
+
weight = 0.1
|
221 |
+
if text_score > 40:
|
222 |
+
text_score = min(text_score * 1.2, 100)
|
223 |
+
else: # balanced
|
224 |
+
weight = 0.3
|
225 |
+
|
226 |
+
combined_score = (1 - weight) * traditional_score + weight * text_score
|
227 |
+
details['combined_score'] = combined_score
|
228 |
+
details['penalty_mode'] = text_penalty_mode
|
229 |
+
|
230 |
+
return details
|
231 |
|
232 |
if __name__ == "__main__":
|
233 |
+
# Test both interfaces
|
234 |
+
print("Testing Hybrid IQA System")
|
235 |
+
print("=" * 50)
|
236 |
+
|
237 |
+
# Original interface (backward compatible)
|
238 |
+
print("\n1. Original Interface (Backward Compatible):")
|
239 |
+
iqa_metric = IQA(model_name="qualiclip+")
|
240 |
+
|
241 |
+
# Advanced interface
|
242 |
+
print("\n2. Advanced Interface:")
|
243 |
+
advanced_iqa = TextAwareIQA(model_name="qualiclip+", text_weight=0.4)
|
244 |
+
|
245 |
image_files = glob.glob("samples/*")
|
246 |
+
if not image_files:
|
247 |
+
print("No images found in samples directory. Please add images or adjust the path.")
|
248 |
+
else:
|
249 |
+
for image_file in image_files[:3]: # Test first 3 images
|
250 |
+
print(f"\nAnalyzing: {image_file}")
|
251 |
+
|
252 |
+
# Original score
|
253 |
+
score = iqa_metric(image_file)
|
254 |
+
if score is not None:
|
255 |
+
print(f" Simple Score: {score:.2f}/100")
|
256 |
+
|
257 |
+
# Detailed analysis
|
258 |
+
details = iqa_metric.detailed_analysis(image_file)
|
259 |
+
if details and 'error' not in details:
|
260 |
+
print(f" Traditional IQA: {details['traditional_score']:.2f}/100")
|
261 |
+
print(f" Text Quality: {details['text_analysis']['text_quality_score']:.2f}/100")
|
262 |
+
print(f" Combined Score: {details['combined_score']:.2f}/100")
|
263 |
+
print(f" Text Details: {details['text_analysis']['details']}")
|
264 |
+
|
265 |
+
if details['text_analysis']['text_detected']:
|
266 |
+
print(f" Text Regions: {details['text_analysis']['text_regions']}")
|
267 |
+
print(f" Low Quality Regions: {details['text_analysis']['low_quality_regions']}")
|
268 |
+
print(f" Avg OCR Confidence: {details['text_analysis']['avg_confidence']:.3f}")
|
269 |
+
|
270 |
+
print("-" * 30)
|
requirements.txt
CHANGED
@@ -1 +1,8 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==5.38.2
|
2 |
+
pyiqa==0.1.14.1
|
3 |
+
torch==2.3.1
|
4 |
+
pillow==10.4.0
|
5 |
+
numpy==1.26.4
|
6 |
+
requests==2.32.3
|
7 |
+
easyocr==1.7.1
|
8 |
+
opencv-python-headless==4.9.0.80
|