import datetime
import json
import logging
import os
import random
import threading
import time
import uuid
from functools import partial

import gradio as gr
import huggingface_hub
import requests

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# API configuration
API_URL = os.getenv("API_URL")
API_KEY = os.getenv("API_KEY")

# Fall back to token=True so huggingface_hub uses the locally cached login
# when no TOKEN environment variable is set.
auth_token = os.environ.get("TOKEN") or True

hf_repo = "Iker/Feedback_FactCheking"

huggingface_hub.create_repo(
    repo_id=hf_repo,
    repo_type="dataset",
    token=auth_token,
    exist_ok=True,
    private=True,
)

headers = {"X-API-Key": API_KEY, "Content-Type": "application/json"}
def update_models():
    # Pick two distinct model configurations at random for a blind A/B comparison.
    models = random.sample(["pro", "pro2", "turbo", "turbo2"], 2)
    logger.info(f"Models updated: {models}")
    return models
# Function to submit a fact-checking request
def submit_fact_check(article_topic, config, language, location):
    endpoint = f"{API_URL}/fact-check"
    payload = {
        "article_topic": article_topic,
        "config": config,
        "language": language,
        "location": location,
    }
    response = requests.post(endpoint, json=payload, headers=headers)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()["job_id"]
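
# Assumed API contract, inferred from how the responses are consumed in this
# file (the backend itself is not documented here):
#   POST /fact-check      -> {"job_id": "..."}
#   GET  /result/<job_id> -> {"status": "completed" | "failed" | <in progress>,
#                             "result": {...}, "error": "..."}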
# Function to get the result of a fact-checking job
def get_fact_check_result(job_id):
    endpoint = f"{API_URL}/result/{job_id}"
    response = requests.get(endpoint, headers=headers)
    response.raise_for_status()  # Raise an exception for HTTP errors
    return response.json()
def fact_checking(article_topic, config):
    language = "es"
    location = "es"

    logger.info(f"Submitting fact-checking request for article: {article_topic}")
    try:
        job_id = submit_fact_check(article_topic, config, language, location)
        logger.info(f"Fact-checking job submitted. Job ID: {job_id}")

        # Poll for results
        start_time = time.time()
        while True:
            try:
                result = get_fact_check_result(job_id)
                if result["status"] == "completed":
                    logger.info("Fact-checking completed:")
                    logger.info(f"Response object: {result}")
                    logger.info(
                        f"Result: {json.dumps(result['result'], indent=4, ensure_ascii=False)}"
                    )
                    return result["result"]
                elif result["status"] == "failed":
                    logger.error("Fact-checking failed:")
                    logger.error(f"Response object: {result}")
                    logger.error(f"Error message: {result['error']}")
                    return None
                else:
                    elapsed_time = time.time() - start_time
                    logger.info(
                        f"Fact-checking in progress. Elapsed time: {elapsed_time:.2f} seconds"
                    )
                    time.sleep(2)  # Wait for 2 seconds before checking again
            except requests.exceptions.RequestException as e:
                logger.error(f"Error while polling for results: {e}")
                time.sleep(2)  # Wait before retrying
    except requests.exceptions.RequestException as e:
        logger.error(f"An error occurred while submitting the request: {e}")
        return None
def format_response(response):
    title = response["metadata"]["title"]
    main_claim = response["metadata"]["main_claim"]
    fc = response["answer"]
    rq = response["related_questions"]

    rq_block = []
    for q, a in rq.items():
        rq_block.append(f"**{q}**\n{a}")

    # Join outside the f-string: a backslash inside an f-string expression is
    # a SyntaxError on Python versions before 3.12.
    rq_text = "\n".join(rq_block)
    return f"## {title}\n\n### {main_claim}\n\n{fc}\n\n{rq_text}"
def do_both_fact_checking(msg):
    models = update_models()
    results = [None, None]
    threads = []

    def fact_checking_1_thread():
        results[0] = fact_checking(msg, config=models[0])

    def fact_checking_2_thread():
        results[1] = fact_checking(msg, config=models[1])

    # Start the threads
    thread1 = threading.Thread(target=fact_checking_1_thread)
    thread2 = threading.Thread(target=fact_checking_2_thread)
    threads.append(thread1)
    threads.append(thread2)
    thread1.start()
    thread2.start()

    # Wait for the threads to complete
    for thread in threads:
        thread.join()

    # Format the responses
    response_1 = format_response(results[0]) if results[0] else None
    response_2 = format_response(results[1]) if results[1] else None

    history_a = [(msg, response_1)]
    history_b = [(msg, response_2)]
    return ("", history_a, history_b, models)
def save_history(
    models,
    history_0,
    history_1,
    max_new_tokens=None,
    temperature=None,
    top_p=None,
    repetition_penalty=None,
    winner=None,
):
    # Write each feedback record to a uniquely named JSON file under data/.
    os.makedirs("data", exist_ok=True)
    path = os.path.join("data", f"history_{uuid.uuid4()}.json")

    data = {
        "timestamp": datetime.datetime.now().isoformat(),
        "model_a": models[0],
        "model_b": models[1],
        "hyperparameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
        },
        "message": history_0[0][0],
        "fc_a": history_0[0][1],
        "fc_b": history_1[0][1],
        "winner": winner,
    }

    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)

    huggingface_hub.upload_file(
        repo_id=hf_repo,
        repo_type="dataset",
        token=auth_token,
        path_in_repo=path,
        path_or_fileobj=path,
    )
    gr.Info("Feedback sent successfully! Thank you for your help.")
with gr.Blocks(
    theme="gradio/soft",
    fill_height=True,
    fill_width=True,
    analytics_enabled=False,
    title="Fact Checking Demo",
    css=".center-text { text-align: center; } footer {visibility: hidden;} .avatar-container {width: 50px; height: 50px; border: none;}",
) as demo:
    gr.Markdown("# Fact Checking Arena", elem_classes="center-text")
    models = gr.State([])

    with gr.Row():
        with gr.Column():
            chatbot_a = gr.Chatbot(
                height=800,
                show_copy_all_button=True,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )
        with gr.Column():
            chatbot_b = gr.Chatbot(
                show_copy_all_button=True,
                height=800,
                avatar_images=[
                    None,
                    "https://upload.wikimedia.org/wikipedia/commons/a/ac/Green_tick.svg",
                ],
            )

    msg = gr.Textbox(
        label="Introduce qué quieres verificar",
        placeholder="Los coches eléctricos contaminan más que los coches de gasolina",
        autofocus=True,
    )

    # Labels point at the matching side: left column is model_a, right is model_b.
    with gr.Row():
        with gr.Column():
            left = gr.Button("👈 Izquierda mejor")
        with gr.Column():
            tie = gr.Button("🤝 Igual de buenos")
        with gr.Column():
            fail = gr.Button("👎 Igual de malos")
        with gr.Column():
            right = gr.Button("👉 Derecha mejor")
    msg.submit(
        do_both_fact_checking,
        inputs=[msg],
        outputs=[msg, chatbot_a, chatbot_b, models],
    )

    left.click(
        partial(save_history, winner="model_a"),
        inputs=[models, chatbot_a, chatbot_b],
    )
    tie.click(
        partial(save_history, winner="tie"),
        inputs=[models, chatbot_a, chatbot_b],
    )
    fail.click(
        partial(save_history, winner="tie (both bad)"),
        inputs=[models, chatbot_a, chatbot_b],
    )
    right.click(
        partial(save_history, winner="model_b"),
        inputs=[models, chatbot_a, chatbot_b],
    )

    demo.load(update_models, inputs=[], outputs=[models])
# Only enable basic auth when both credentials are provided; otherwise an
# unset environment variable would yield a broken (None, None) auth tuple.
gradio_username = os.getenv("GRADIO_USERNAME")
gradio_password = os.getenv("GRADIO_PASSWORD")
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    auth=(gradio_username, gradio_password) if gradio_username and gradio_password else None,
)
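
# To run this Space locally (assuming the file is saved as app.py), set the
# API_URL, API_KEY, TOKEN, GRADIO_USERNAME and GRADIO_PASSWORD environment
# variables and run:
#   python app.py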