import re
import unicodedata

import gradio as gr
import requests
import soundfile as sf
import torch
import wikipediaapi
from num2words import num2words
from PIL import Image
from transformers import (
    AutoTokenizer,
    CLIPModel,
    CLIPProcessor,
    VitsModel,
    pipeline,
)
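
# Pipeline overview:
#   1. CLIP matches an uploaded photo against the known landmark titles.
#   2. Wikipedia provides an English summary of the recognized landmark.
#   3. BART condenses the summary, which is then translated into Russian.
#   4. A VITS model (facebook/mms-tts-rus) voices the Russian text.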


def load_attractions_json(url):
    """Download and parse the JSON list of landmark titles."""
    response = requests.get(url)
    response.raise_for_status()
    return response.json()
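
# The file is assumed to contain a flat JSON array of landmark names,
# e.g. ["Eiffel Tower", "Big Ben", ...], which CLIP uses as candidate labels.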


url = "https://raw.githubusercontent.com/nktssk/tourist-helper/refs/heads/main/landmarks.json"
landmark_titles = load_attractions_json(url)
print(landmark_titles)


def clean_text(text):
    """Strip IPA notes, bracketed references, diacritics, and stray symbols."""
    # Drop Russian IPA markers ("МФА: [...]") and any bracketed references.
    text = re.sub(r'МФА:?\s?\[.*?\]', '', text)
    text = re.sub(r'\[.*?\]', '', text)

    def remove_diacritics(char):
        if unicodedata.category(char) == 'Mn':
            return ''
        return char

    # Decompose characters, drop combining marks, then recompose.
    text = unicodedata.normalize('NFD', text)
    text = ''.join(remove_diacritics(char) for char in text)
    text = unicodedata.normalize('NFC', text)

    # Collapse whitespace and keep only word characters and basic punctuation.
    text = re.sub(r'\s+', ' ', text)
    text = re.sub(r'[^\w\s.,!?-]', '', text)

    return text.strip()
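
# For example: clean_text("Tower [1] café") -> "Tower cafe".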


def replace_numbers_with_text(input_string):
    """Spell out digits in Russian so the TTS model can pronounce them."""
    def convert_number(match):
        number = match.group(0)
        try:
            return num2words(float(number) if '.' in number else int(number), lang='ru')
        except Exception:
            return number

    return re.sub(r'\d+(\.\d+)?', convert_number, input_string)
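
# For example: replace_numbers_with_text("высота 300 метров") -> "высота триста метров".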


# Load all models once at startup; reloading them on every request would be slow.
summarization_model = pipeline("summarization", model="facebook/bart-large-cnn")
wiki = wikipediaapi.Wikipedia("Nikita", "en")
translator = pipeline("translation_en_to_ru", model="Helsinki-NLP/opus-mt-en-ru")

# Russian text-to-speech (VITS).
tts_model = VitsModel.from_pretrained("facebook/mms-tts-rus")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-rus")

# CLIP for zero-shot landmark recognition.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Pre-compute L2-normalized CLIP text embeddings for every landmark title.
text_inputs = clip_processor(
    text=landmark_titles,
    images=None,
    return_tensors="pt",
    padding=True
)
with torch.no_grad():
    text_embeds = clip_model.get_text_features(**text_inputs)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
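
# Recognition is zero-shot: an image embedding is matched to whichever title
# embedding has the highest cosine similarity, with no landmark-specific training.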


def text_to_speech(text, output_path="speech.wav"):
    """Synthesize Russian speech for `text` and write it to a WAV file."""
    text = replace_numbers_with_text(text)
    inputs = tts_tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        output = tts_model(**inputs).waveform.squeeze().numpy()

    sf.write(output_path, output, samplerate=tts_model.config.sampling_rate)
    return output_path
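
# Example: text_to_speech("привет") -> "speech.wav" (mono, at the model's sampling rate).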


def fetch_wikipedia_summary(landmark):
    """Return a cleaned English summary of the landmark, or None if no page exists."""
    page = wiki.page(landmark)
    if page.exists():
        return clean_text(page.summary)
    return None


def recognize_landmark_clip(image):
    """Match the image against the landmark titles and return (title, similarity)."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    image_inputs = clip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        image_embed = clip_model.get_image_features(**image_inputs)
        image_embed = image_embed / image_embed.norm(p=2, dim=-1, keepdim=True)

    # Cosine similarity against every pre-computed title embedding.
    similarity = (image_embed @ text_embeds.T).squeeze(0)
    best_idx = similarity.argmax().item()
    best_score = similarity[best_idx].item()
    return landmark_titles[best_idx], best_score
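
# For example (hypothetical file): recognize_landmark_clip(Image.open("photo.jpg"))
# might return ("Eiffel Tower", 0.31).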


def tourist_helper_with_russian(landmark):
    """Summarize, translate, and voice a Wikipedia article about the landmark."""
    wiki_text = fetch_wikipedia_summary(landmark)
    if wiki_text is None:
        return None

    print(wiki_text)
    summarized_text = summarization_model(wiki_text, min_length=20, max_length=210)[0]["summary_text"]
    print(summarized_text)

    translated = translator(summarized_text, max_length=1000)[0]["translation_text"]
    print(translated)

    return text_to_speech(translated)


def process_image_clip(image):
    recognized, score = recognize_landmark_clip(image)
    print(f"[CLIP] Recognized: {recognized}, score={score:.2f}")
    return tourist_helper_with_russian(recognized)


def process_text_clip(landmark):
    return tourist_helper_with_russian(landmark)


with gr.Blocks() as demo:
    gr.Markdown("## Помощь туристу")

    with gr.Tabs():
        with gr.Tab("CLIP + Sum + Translate + T2S"):
            gr.Markdown("### Распознавание (CLIP) и перевод на русский")

            with gr.Row():
                image_input_c = gr.Image(label="Загрузите фото", type="pil")
                text_input_c = gr.Textbox(label="Или введите название")

            audio_output_c = gr.Audio(label="Результат")

            with gr.Row():
                btn_recognize_c = gr.Button("Распознать и перевести на русский")
                btn_text_c = gr.Button("Поиск по тексту")

            btn_recognize_c.click(
                fn=process_image_clip,
                inputs=image_input_c,
                outputs=audio_output_c
            )
            btn_text_c.click(
                fn=process_text_clip,
                inputs=text_input_c,
                outputs=audio_output_c
            )

demo.launch(debug=True)