As the title says, I wrote a Python program that uses an LLM to proofread text. The code is below:
- """
- 文本自动校对
- 版本: 1.1
- 作者:shadowmage
- 创建日期: 2025年11月
- 配置文件示例 (config.ini):
- ==============
- [API]
- base_url = https://api.siliconflow.cn
- api_key = your_api_key_here
- model_id = deepseek-ai/DeepSeek-R1-0528-Qwen3-8B
- chat_endpoint = /v1/chat/completions
- [Processing]
- max_chunk_size = 1500
- max_retries = 3
- backoff_factor = 1
- timeout = 600
- long_text_threshold = 1000
- [Paths]
- [Spacy]
- language_model = zh_core_web_sm
- fallback_model = en_core_web_sm
- 许可证: MIT License
- """
- import os
- import json
- import requests
- import difflib
- import re
- import spacy
- import configparser
- from pathlib import Path
- import time
- import random
- import datetime
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import threading
- from requests.adapters import HTTPAdapter
- from urllib3.util.retry import Retry
- import csv
- from collections import defaultdict
- class OptimizedConfigManager:
- def __init__(self, config_path='config.ini'):
- self.config_path = Path(config_path)
- self.config = configparser.ConfigParser()
- self.load_config()
-
- def load_config(self):
- if not self.config_path.exists():
- raise FileNotFoundError(f"配置文件 {self.config_path} 不存在")
-
- self.config.read(self.config_path, encoding='utf-8')
-
- required_sections = ['API', 'Processing']
- for section in required_sections:
- if not self.config.has_section(section):
- raise ValueError(f"配置文件中缺少必要的 [{section}] 部分")
-
- def get_api_config(self):
- return {
- 'base_url': self.config.get('API', 'base_url', fallback='https://api.siliconflow.cn'),
- 'api_key': self.config.get('API', 'api_key'),
- 'model_id': self.config.get('API', 'model_id', fallback='deepseek-ai/DeepSeek-R1-0528-Qwen3-8B'),
- 'chat_endpoint': self.config.get('API', 'chat_endpoint', fallback='/v1/chat/completions'),
- 'batch_size': self.config.getint('API', 'batch_size', fallback=3),
- 'max_workers': self.config.getint('API', 'max_workers', fallback=5)
- }
-
- def get_processing_config(self):
- return {
- 'max_chunk_size': self.config.getint('Processing', 'max_chunk_size', fallback=1500),
- 'max_retries': self.config.getint('Processing', 'max_retries', fallback=3),
- 'backoff_factor': self.config.getfloat('Processing', 'backoff_factor', fallback=1),
- 'timeout': self.config.getint('Processing', 'timeout', fallback=600),
- 'long_text_threshold': self.config.getint('Processing', 'long_text_threshold', fallback=1000),
- 'enable_cache': self.config.getboolean('Processing', 'enable_cache', fallback=True),
- 'noun_dict_path': self.config.get('Processing', 'noun_dict_path', fallback='noun_correction_dict.csv')
- }
-
- def get_paths_config(self):
- return {
- 'source_suffix': self.config.get('Paths', 'source_suffix', fallback='_校对结果'),
- 'report_prefix': self.config.get('Paths', 'report_prefix', fallback='校对报告_')
- }
-
- def get_spacy_config(self):
- return {
- 'language_model': self.config.get('Spacy', 'language_model', fallback='zh_core_web_sm'),
- 'fallback_model': self.config.get('Spacy', 'fallback_model', fallback='en_core_web_sm')
- }
- class NounCorrectionManager:
- def __init__(self, dict_path='noun_correction_dict.csv'):
- self.dict_path = Path(dict_path)
- self.correction_dict = {}
- self.load_dictionary()
-
- def load_dictionary(self):
- if self.dict_path.exists():
- try:
- with open(self.dict_path, 'r', encoding='utf-8', newline='') as f:
- reader = csv.reader(f)
- for row in reader:
- if len(row) >= 2:
- original, corrected = row[0], row[1]
- self.correction_dict[original] = corrected
- print(f"已加载名词修正字典: {len(self.correction_dict)} 条记录")
- except Exception as e:
- print(f"加载名词修正字典失败: {e}")
-
- def save_dictionary(self):
- try:
- with open(self.dict_path, 'w', encoding='utf-8', newline='') as f:
- writer = csv.writer(f)
- for original, corrected in self.correction_dict.items():
- writer.writerow([original, corrected])
- print(f"名词修正字典已保存: {len(self.correction_dict)} 条记录")
- except Exception as e:
- print(f"保存名词修正字典失败: {e}")
-
- def add_correction(self, original, corrected):
- if original and corrected and original != corrected:
- self.correction_dict[original] = corrected
-
- def apply_corrections(self, text):
- if not self.correction_dict:
- return text
-
- for original, corrected in self.correction_dict.items():
- # \b does not match between CJK characters, so the old r'\b...\b' pattern silently
- # skipped Chinese terms; a plain substring replacement is what we want here
- text = text.replace(original, corrected)
- return text
- class OptimizedTextProofreader:
- def __init__(self, config_path='config.ini'):
- self._log_lock = threading.Lock()
- self.config_manager = OptimizedConfigManager(config_path)
- self.api_config = self.config_manager.get_api_config()
- self.processing_config = self.config_manager.get_processing_config()
- self.paths_config = self.config_manager.get_paths_config()
- self.spacy_config = self.config_manager.get_spacy_config()
-
- self.noun_manager = NounCorrectionManager(self.processing_config['noun_dict_path'])
- self.chat_endpoint = f"{self.api_config['base_url']}{self.api_config['chat_endpoint']}"
- self.text_cache = {} if self.processing_config['enable_cache'] else None
- self.session = self._create_session()
- self.nlp = self._initialize_spacy()
- self.noun_changes = defaultdict(list)
-
- self._log("优化版文本校对器初始化完成")
- self._log(f"API模型: {self.api_config['model_id']}")
- self._log(f"最大工作线程: {self.api_config['max_workers']}")
- self._log(f"批量大小: {self.api_config['batch_size']}")
- self._log(f"缓存启用: {self.processing_config['enable_cache']}")
- self._log(f"名词修正字典: {self.processing_config['noun_dict_path']}")
-
- def _create_session(self):
- session = requests.Session()
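- # Note: urllib3's Retry does not re-send POST requests by default, so this mainly
- # covers connection-level errors; the manual loop in call_siliconflow_api is what
- # actually retries failed chat requests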
- retry_strategy = Retry(
- total=self.processing_config['max_retries'],
- backoff_factor=self.processing_config['backoff_factor'],
- status_forcelist=[429, 500, 502, 503, 504],
- )
- adapter = HTTPAdapter(max_retries=retry_strategy, pool_connections=10, pool_maxsize=100)
- session.mount("http://", adapter)
- session.mount("https://", adapter)
-
- return session
-
- def _get_timestamp(self):
- return datetime.datetime.now().strftime("%Y/%m/%d-%H:%M:%S")
-
- def _log(self, message):
- timestamp = self._get_timestamp()
- with self._log_lock:
- print(f"{timestamp} {message}")
-
- def _initialize_spacy(self):
- try:
- nlp = spacy.load(self.spacy_config['language_model'])
- self._log(f"spaCy模型加载成功: {self.spacy_config['language_model']}")
- return nlp
- except OSError:
- self._log(f"警告: 未找到spaCy模型 {self.spacy_config['language_model']},尝试备用模型...")
- try:
- nlp = spacy.load(self.spacy_config['fallback_model'])
- self._log(f"使用spaCy备用模型: {self.spacy_config['fallback_model']}")
- return nlp
- except OSError:
- self._log("错误: 未找到任何spaCy模型,将使用基于标点的分句方案")
- return None
-
- def _get_cache_key(self, text):
- # Hash the full text: hashing only the first/last 100 characters can collide for
- # different sentences and would then return the wrong cached correction
- return hash(text)
-
- def sentence_segmentation(self, text):
- sentences = []
-
- if self.nlp:
- doc = self.nlp(text)
- for sent in doc.sents:
- sentence_text = sent.text.strip()
- if sentence_text:
- sentences.append(sentence_text)
- else:
- sentence_endings = r'([。!?!?]+)'
- parts = re.split(sentence_endings, text)
-
- current_sentence = ""
- for i, part in enumerate(parts):
- if i % 2 == 0: # 句子内容
- current_sentence = part.strip()
- else: # 句子结束标点
- if current_sentence:
- sentences.append(current_sentence + part)
- current_sentence = ""
-
- if current_sentence:
- sentences.append(current_sentence)
-
- return sentences
-
- def remove_duplicate_paragraphs(self, text):
- if not text.strip():
- return text
-
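- # Caution: only the first occurrence of any repeated paragraph is kept, so
- # repetition that is deliberate in the source text will also be removed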
- paragraphs = re.split(r'\n\s*\n', text)
- unique_paragraphs = []
- seen_paragraphs = set()
-
- for para in paragraphs:
- clean_para = re.sub(r'\s+', ' ', para.strip())
- if clean_para and clean_para not in seen_paragraphs:
- seen_paragraphs.add(clean_para)
- unique_paragraphs.append(para)
-
- if len(unique_paragraphs) < len(paragraphs):
- self._log(f"删除了 {len(paragraphs) - len(unique_paragraphs)} 个重复段落")
-
- return '\n\n'.join(unique_paragraphs)
-
- def batch_call_api(self, text_sentences):
- if not text_sentences:
- return []
-
- if self.text_cache is not None:
- cached_results = []
- uncached_sentences = []
- uncached_indices = []
-
- for i, sentence in enumerate(text_sentences):
- cache_key = self._get_cache_key(sentence)
- if cache_key in self.text_cache:
- cached_results.append((i, self.text_cache[cache_key]))
- else:
- uncached_sentences.append(sentence)
- uncached_indices.append(i)
-
- if cached_results:
- self._log(f"缓存命中: {len(cached_results)}/{len(text_sentences)} 个句子")
- else:
- uncached_sentences = text_sentences
- uncached_indices = list(range(len(text_sentences)))
- cached_results = []
-
- if uncached_sentences:
- results = self._parallel_process_sentences(uncached_sentences, uncached_indices)
- all_results = cached_results + results
- all_results.sort(key=lambda x: x[0])
- return [result[1] for result in all_results]
- else:
- cached_results.sort(key=lambda x: x[0])
- return [result[1] for result in cached_results]
-
- def _parallel_process_sentences(self, sentences, indices):
- results = []
-
- batch_size = self.api_config['batch_size']
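- # Note: batch_size is read here but never used below; every sentence is
- # submitted to the thread pool individually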
- max_workers = min(self.api_config['max_workers'], len(sentences))
-
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_index = {
- executor.submit(self._safe_call_api, sentence): idx
- for sentence, idx in zip(sentences, indices)
- }
-
- completed = 0
- for future in as_completed(future_to_index):
- idx = future_to_index[future]
- try:
- result = future.result()
- results.append((idx, result))
- completed += 1
-
- if self.text_cache is not None and result is not None:
- cache_key = self._get_cache_key(sentences[indices.index(idx)])
- self.text_cache[cache_key] = result
-
- if completed % 10 == 0: # 每10个句子报告一次进度
- self._log(f"进度: {completed}/{len(sentences)} 个句子处理完成")
- except Exception as e:
- self._log(f"句子处理失败: {str(e)}")
- results.append((idx, None)) # 失败时返回None
-
- return results
-
- def _safe_call_api(self, text):
- try:
- return self.call_siliconflow_api(text)
- except Exception as e:
- self._log(f"API调用异常: {str(e)}")
- return None
-
- def _clean_markdown_annotations(self, text):
- text = re.sub(r'\{[^}]*?(?:句式修正|补全|修正|优化)[^}]*?\}', '', text)
- # Strip the Markdown markers but keep the wrapped words; deleting the whole match
- # (markers plus text) silently loses content from the corrected sentence
- text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)   # bold
- text = re.sub(r'\*(.*?)\*', r'\1', text)       # italic
- text = re.sub(r'`(.*?)`', r'\1', text)         # inline code
- text = re.sub(r'^#+\s*', '', text, flags=re.MULTILINE)  # heading markers
-
- # Collapse runs of spaces/tabs only; keeping newlines preserves paragraph breaks
- text = re.sub(r'[ \t]+', ' ', text).strip()
-
- return text
-
- def _extract_noun_changes(self, original, corrected):
- if not original or not corrected or original == corrected:
- return
-
- if self.nlp:
- try:
- doc_orig = self.nlp(original)
- doc_corr = self.nlp(corrected)
-
- orig_nouns = [token.text for token in doc_orig if token.pos_ in ['NOUN', 'PROPN']]
- corr_nouns = [token.text for token in doc_corr if token.pos_ in ['NOUN', 'PROPN']]
-
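- # Caution: this pairs every noun that disappeared with every noun that appeared,
- # so unrelated words can be linked; those pairs go straight into the correction
- # dictionary and are auto-applied to later files, which can amplify bad edits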
- for orig_noun in orig_nouns:
- if orig_noun not in corr_nouns:
- # 在修正文本中查找可能对应的名词
- for corr_noun in corr_nouns:
- if corr_noun not in orig_nouns and len(corr_noun) > 1:
- self.noun_changes[orig_noun].append(corr_noun)
- self.noun_manager.add_correction(orig_noun, corr_noun)
- self._log(f"名词修改记录: '{orig_noun}' -> '{corr_noun}'")
- except Exception as e:
- self._log(f"名词分析失败: {e}")
-
- def call_siliconflow_api(self, text):
- headers = {
- "Authorization": f"Bearer {self.api_config['api_key']}",
- "Content-Type": "application/json"
- }
- system_prompt = """你是一位专业的小说文本校对专家。请严格按照以下要求进行校对:
- ## 校对要求
- 1. 只进行错别字修正、标点符号校正、语法错误修复
- 2. 保持原文意思和风格不变
- 3. 不要添加任何说明、注释、标记或解释
- 4. 绝对禁止使用 {句式修正}、{补全} 等任何格式的标记
- 5. 不要添加任何MD格式(如**加粗**、*斜体*、`代码`等)
- 6. 输出必须是纯净的校对后文本
- ## 输出格式
- 直接输出校对后的纯净文本,不要添加任何额外内容。"""
- messages = [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": f"请对以下文本进行专业校对,只修正语言错误,不要改变意思和风格,不要添加任何标记:\n\n{text}"}
- ]
- payload = {
- "model": self.api_config['model_id'],
- "messages": messages,
- "temperature": 0.1,
- "max_tokens": 4000,
- "top_p": 0.9,
- "stream": False
- }
- max_retries = self.processing_config['max_retries']
-
- for attempt in range(max_retries + 1):
- try:
- self._log(f"正在调用校对API,文本长度: {len(text)} 字符 (第{attempt + 1}次尝试)")
-
- response = self.session.post(
- self.chat_endpoint,
- headers=headers,
- json=payload,
- timeout=self.processing_config['timeout']
- )
-
- if response.status_code != 200:
- if attempt < max_retries:
- wait_time = random.randint(5, 30)
- time.sleep(wait_time)
- continue
- else:
- return None
-
- response.raise_for_status()
- result = response.json()
-
- corrected_text = result['choices'][0]['message']['content'].strip()
-
- corrected_text = self._clean_markdown_annotations(corrected_text)
-
- self._extract_noun_changes(text, corrected_text)
-
- return corrected_text
-
- except requests.exceptions.RequestException as e:
- if attempt < max_retries:
- wait_time = random.randint(5, 30)
- time.sleep(wait_time)
- continue
- else:
- return None
- except (KeyError, IndexError, json.JSONDecodeError) as e:
- return None
- def process_text_sentences(self, text):
- preprocessed_text = self.noun_manager.apply_corrections(text)
- if preprocessed_text != text:
- self._log("已应用名词修正字典进行预处理")
-
- sentences = self.sentence_segmentation(preprocessed_text)
- self._log(f"检测到句子数: {len(sentences)}")
-
- if len(sentences) == 0:
- return text
-
- start_time = time.time()
- corrected_sentences = self.batch_call_api(sentences)
- processing_time = time.time() - start_time
-
- self._log(f"句子处理完成,耗时: {processing_time:.2f}秒")
-
- final_sentences = []
- for original, corrected in zip(sentences, corrected_sentences):
- if corrected is None:
- final_sentences.append(original)
- else:
- final_sentences.append(corrected)
-
- # Join without spaces: Chinese text does not put spaces between sentences,
- # and ' '.join would sprinkle them throughout the output
- result_text = ''.join(final_sentences)
-
- result_text = self.remove_duplicate_paragraphs(result_text)
-
- return result_text
-
- def process_chapter(self, input_path, output_dir):
- try:
- with open(input_path, 'r', encoding='utf-8') as f:
- original_text = f.read()
- self._log(f"读取文件成功,长度: {len(original_text)} 字符")
- except Exception as e:
- self._log(f"文件读取失败: {str(e)}")
- return False
-
- start_time = time.time()
- corrected_text = self.process_text_sentences(original_text)
- processing_time = time.time() - start_time
-
- if corrected_text is None:
- self._log("处理失败,跳过此文件")
- return False
-
- self._log(f"文本处理完成,耗时: {processing_time:.2f}秒")
-
- if self.noun_changes:
- self.noun_manager.save_dictionary()
-
- output_path = output_dir / input_path.name
- try:
- with open(output_path, 'w', encoding='utf-8') as f:
- f.write(corrected_text)
- self._log(f"已保存: {output_path.name}")
- return True
- except Exception as e:
- self._log(f"文件保存失败: {str(e)}")
- return False
-
- def main(self):
- if self.nlp is None:
- self._log("警告: spaCy模型未正确加载,使用备用方案")
-
- book_name = input("请输入需要校对的书名: ").strip()
- if not book_name:
- print("错误: 书名不能为空")
- return
-
- source_dir = Path(book_name)
- if not source_dir.exists() or not source_dir.is_dir():
- print(f"错误: 找不到 '{book_name}' 文件夹")
- return
-
- output_dir = Path(f"{book_name}{self.paths_config['source_suffix']}")
-
- try:
- output_dir.mkdir(exist_ok=True)
- except Exception as e:
- print(f"创建输出目录失败: {str(e)}")
- return
-
- txt_files = list(source_dir.glob("*.txt"))
- if not txt_files:
- print(f"警告: 没有找到txt文件")
- return
-
- total_start = time.time()
- success_count = 0
- print(f"开始逐句处理 {len(txt_files)} 个文件...")
-
- for i, file_path in enumerate(txt_files, 1):
- print(f"\n[{i}/{len(txt_files)}] 处理: {file_path.name}")
- if self.process_chapter(file_path, output_dir):
- success_count += 1
-
- if self.noun_changes:
- self.noun_manager.save_dictionary()
- print(f"\n名词修正字典已更新: {len(self.noun_manager.correction_dict)} 条记录")
-
- total_time = time.time() - total_start
- print("\n" + "="*50)
- print(f"全部处理完成,总耗时: {total_time:.2f}秒")
- print(f"平均每个文件: {total_time/len(txt_files):.2f}秒")
- print(f"结果目录: {output_dir.resolve()}")
- print(f"成功处理: {success_count}/{len(txt_files)} 个文件")
- def main():
- try:
- proofreader = OptimizedTextProofreader('config.ini')
- proofreader.main()
- except FileNotFoundError as e:
- print(f"错误: {e}")
- print("请确保 config.ini 配置文件存在")
- except ValueError as e:
- print(f"配置错误: {e}")
- except Exception as e:
- print(f"程序初始化失败: {e}")
- if __name__ == "__main__":
- main()
So far the throughput still leaves something to be desired, and some passages get rewritten out of hand. Is that a limitation of the LLM itself, or is there a way to improve it?
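A couple of directions I'm considering, sketched below. Both are illustrative only: the helper names (accept_correction, pack_sentences) and the 0.7 threshold are placeholders, not part of the script above.

On the arbitrary rewrites: part of the risk is in the post-processing, since the auto-learned noun dictionary can pick up wrong pairs and apply them to every later file, and beyond that the model will occasionally rewrite freely no matter how strict the prompt is. A cheap guard is to compare each corrected sentence with the original and fall back to the original when the edit is too large; difflib is already imported by the script.

import difflib

def accept_correction(original: str, corrected: str, min_ratio: float = 0.7) -> str:
    # Keep the model's output only if it stays close to the original sentence;
    # proofreading should change a few characters, not rewrite the passage.
    if not corrected:
        return original
    ratio = difflib.SequenceMatcher(None, original, corrected).ratio()
    return corrected if ratio >= min_ratio else original

# usage inside process_text_sentences, when assembling final_sentences:
#     final_sentences.append(accept_correction(original, corrected))

On speed: the dominant cost is one API round-trip per sentence. The config already defines max_chunk_size, which the posted code reads but never uses; packing consecutive sentences into chunks of up to that many characters cuts the number of requests substantially. A rough sketch:

def pack_sentences(sentences, max_chunk_size=1500):
    # Group consecutive sentences into chunks no longer than max_chunk_size characters.
    chunks, current = [], ""
    for sent in sentences:
        if current and len(current) + len(sent) > max_chunk_size:
            chunks.append(current)
            current = sent
        else:
            current += sent
    if current:
        chunks.append(current)
    return chunks

# e.g. feed pack_sentences(sentences, self.processing_config['max_chunk_size'])
# to batch_call_api instead of the raw sentence list; keep an eye on max_tokens
# (currently 4000) so long chunks are not truncated in the response.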