|
马上注册,结交更多好友,享用更多功能^_^
您需要 登录 才可以下载或查看,没有账号?立即注册
x
本帖最后由 zhangchenyvn 于 2025-8-4 14:07 编辑
由于我的文档实在是依托答辩,于是,我就请AI整理了我的代码和注释文档。
由于是发到GH上了,所以就整理成了英文,先凑合着看吧。GH的话后面会开源周边组件滴……
先来讲讲为啥开发:
最近我和一个同学正在搞一个聊天室,顺便就把我之前学过的一些技能拿出来用了。就写了个这玩意。
然后呢,就得到了依托冒着热气的屎山。这玩意第一稿能运行了,但是根本没法看。基本上都是乱的。
于是,就让AI帮我整理了代码,写了点文档。很灰色幽默,让AI整理AI的代码……
具体的逻辑还是有点乱,最近会找一个时间把这个给整理成一篇文章。
使用方法是安装这些模块:
- numpy
- pandas
- torch
- transformers
- scikit-learn
- seaborn
- tqdm
- datasets
- windows-curses # windows请安装这个
复制代码
import curses
import os
import pickle
import random
import re
import sys
import warnings
from curses import wrapper
from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
# noinspection PyPep8Naming
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
# import json

warnings.filterwarnings('ignore')
# Set random seed to ensure reproducible results
def set_seed(seed=42):
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA), and
    switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed (int): Random seed value
    """
    # Python's own RNG is seeded too; anything that draws from the ``random``
    # module would otherwise still differ between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
# Detect device: use CUDA when PyTorch reports it as available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# 1. Dataset download and preprocessing module
class DatasetManager:
    """
    Dataset management class for downloading and preprocessing various datasets

    This class handles downloading, merging, cleaning, and saving datasets
    for training the hostility detection model. It supports multiple datasets
    including ToxicChat (English) and ASAP (Chinese).

    Labels produced by the loaders use -1 for hostile, 0 for neutral and
    1 for friendly.
    """

    def __init__(self):
        """
        Initialize DatasetManager with empty dataset containers
        """
        self.toxic_chat_data = None
        self.asap_data = None
        # No method of this class fills cped_data; merge_datasets only uses it
        # when it has been assigned from outside.
        self.cped_data = None
        self.raw_data = None
        self.cleaned_data = None

    @staticmethod
    def download_and_unzip(url, extract_to='.'):
        """
        Download and unzip ZIP file from a URL

        The whole archive is read into memory before extraction.

        Args:
            url (str): URL to download the ZIP file from
            extract_to (str): Directory to extract the ZIP file to
        """
        print(f"Downloading {url}...")
        # Context managers make sure the connection and the archive are closed
        # even when reading or extracting fails.
        with urlopen(url) as http_response:
            payload = BytesIO(http_response.read())
        print(f"Extracting to {extract_to}...")
        with ZipFile(payload) as archive:
            archive.extractall(path=extract_to)
        print("Download and extraction completed!")

    def download_toxic_chat(self):
        """
        Download ToxicChat dataset (English, CC-BY license)

        Each training item becomes one sample whose text is the user input
        and the model output joined by a space. A toxicity value of 0 maps
        to label 0 (neutral); any other value maps to -1 (hostile).

        Returns:
            dict | None: Dataset with 'texts', 'labels' and 'lang' keys, or
            None when loading failed (the error is printed, not raised)
        """
        print("Downloading ToxicChat dataset...")
        try:
            # ``load_dataset`` here is the module-level import from the
            # Hugging Face ``datasets`` package, not the method of this class.
            dataset = load_dataset("lmsys/toxic-chat", "toxicchat0124")
            texts = []
            labels = []  # -1: hostile, 0: neutral, 1: friendly
            for item in dataset['train']:
                # ``or ''`` also covers fields that are present but None,
                # which would otherwise break the concatenation.
                user_input = item.get('user_input') or ''
                model_output = item.get('model_output') or ''
                texts.append(user_input + ' ' + model_output)
                labels.append(0 if item.get('toxicity', 0) == 0 else -1)
            self.toxic_chat_data = {
                'texts': texts,
                'labels': labels,
                'lang': 'en'
            }
            print(f"ToxicChat dataset loaded, total {len(texts)} samples")
            return self.toxic_chat_data
        except Exception as e:
            # Deliberate best effort: report the failure and let the caller
            # continue with whatever other datasets are available.
            print(f"Failed to download ToxicChat dataset: {e}")
            return None

    def download_asap(self):
        """
        Download ASAP dataset (Chinese, Apache-2.0 license)

        Reads the 'review' and 'star' columns of the training CSV and converts
        the star rating to a label: <= 2 is hostile (-1), 3 is neutral (0),
        anything higher is friendly (1). Rows with a missing review or rating
        are skipped.

        Returns:
            dict | None: Dataset with 'texts', 'labels' and 'lang' keys, or
            None when loading failed (the error is printed, not raised)
        """
        print("Downloading ASAP dataset...")
        try:
            repo_url = "https://github.com/Meituan-Dianping/asap/archive/refs/heads/master.zip"
            extract_path = "./AloneCheck/asap_data"
            csv_path = os.path.join(extract_path, "asap-master", "data", "train.csv")
            # Test for the CSV itself, not just the directory: a directory left
            # behind by a failed download would otherwise block every retry.
            if not os.path.exists(csv_path):
                os.makedirs(extract_path, exist_ok=True)
                self.download_and_unzip(repo_url, extract_path)
            data = pd.read_csv(csv_path)
            texts = []
            labels = []  # -1: hostile, 0: neutral, 1: friendly
            if 'review' in data.columns and 'star' in data.columns:
                # A NaN rating fails both comparisons below and would be
                # counted as friendly, so incomplete rows are dropped first.
                usable = data.dropna(subset=['review', 'star'])
                for review, rating in zip(usable['review'], usable['star']):
                    texts.append(str(review))
                    if rating <= 2:
                        labels.append(-1)  # hostile
                    elif rating == 3:
                        labels.append(0)  # neutral
                    else:
                        labels.append(1)  # friendly
            self.asap_data = {
                'texts': texts,
                'labels': labels,
                'lang': 'zh'
            }
            print(f"ASAP dataset loaded, total {len(texts)} samples")
            return self.asap_data
        except Exception as e:
            # Deliberate best effort, same as download_toxic_chat.
            print(f"Failed to download ASAP dataset: {e}")
            return None

    @staticmethod
    def load_custom_dataset(texts, labels, langs=None):
        """
        Load custom dataset from provided texts and labels

        When ``langs`` is omitted, a text is tagged 'zh' if it contains at
        least one character in U+4E00..U+9FFF and 'en' otherwise.

        Args:
            texts (list): List of text samples
            labels (list): List of corresponding labels
            langs (list, optional): List of language identifiers for each text

        Returns:
            dict: 'texts', 'labels', 'lang' (language of the first text, 'en'
            for an empty dataset) and 'langs' (one entry per text, so the
            result can be passed straight to clean_data)
        """
        if langs is None:
            langs = [
                'zh' if any('\u4e00' <= char <= '\u9fff' for char in text) else 'en'
                for text in texts
            ]
        return {
            'texts': texts,
            'labels': labels,
            'lang': langs[0] if langs else 'en',
            # clean_data reads the per-text 'langs' key.
            'langs': list(langs)
        }

    def merge_datasets(self):
        """
        Merge all downloaded datasets into a single dataset

        Datasets that are None or empty are skipped. ToxicChat samples are
        tagged 'en'; ASAP and CPED samples are tagged 'zh'.

        Returns:
            dict: Merged dataset with 'texts', 'labels' and 'langs' keys
        """
        all_texts = []
        all_labels = []
        all_langs = []
        sources = (
            (self.toxic_chat_data, 'en'),
            (self.asap_data, 'zh'),
            (self.cped_data, 'zh'),
        )
        for source, lang in sources:
            if not source:
                continue
            all_texts.extend(source['texts'])
            all_labels.extend(source['labels'])
            all_langs.extend([lang] * len(source['texts']))
        self.raw_data = {
            'texts': all_texts,
            'labels': all_labels,
            'langs': all_langs
        }
        return self.raw_data

    @staticmethod
    def preprocess_text(text, lang='zh'):
        """
        Text preprocessing function

        Lowercases the text, removes every character that is not a letter or
        whitespace (for 'zh', characters in U+4E00..U+9FA5 are kept as well)
        and collapses whitespace runs into single spaces.

        Args:
            text (str): Input text to preprocess
            lang (str): Language code ('zh' for Chinese, anything else is
                treated as English)

        Returns:
            str: Preprocessed text (may be empty)
        """
        text = text.lower()
        if lang == 'zh':
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z\s]', '', text)
        else:
            text = re.sub(r'[^a-zA-Z\s]', '', text)
        return re.sub(r'\s+', ' ', text).strip()

    def clean_data(self, data=None):
        """
        Clean and preprocess the dataset

        Samples whose text is empty after preprocessing are dropped together
        with their label and language.

        Args:
            data (dict, optional): Dataset with 'texts', 'labels' and 'langs'
                keys. If None, uses self.raw_data

        Returns:
            dict: Cleaned dataset, also stored in self.cleaned_data

        Raises:
            ValueError: If no data is given and nothing has been merged yet
        """
        if data is None:
            data = self.raw_data
        if data is None:
            raise ValueError("No data to clean: pass a dataset or call merge_datasets() first")
        cleaned_texts = []
        cleaned_labels = []
        cleaned_langs = []
        for text, label, lang in zip(data['texts'], data['labels'], data['langs']):
            cleaned_text = self.preprocess_text(text, lang)
            if not cleaned_text:
                continue
            cleaned_texts.append(cleaned_text)
            cleaned_labels.append(label)
            cleaned_langs.append(lang)
        self.cleaned_data = {
            'texts': cleaned_texts,
            'labels': cleaned_labels,
            'langs': cleaned_langs
        }
        return self.cleaned_data

    def save_dataset(self, path='./AloneCheck/dataset.pkl'):
        """
        Save raw and cleaned data to a pickle file

        Args:
            path (str): Path to save the dataset
        """
        directory = os.path.dirname(path)
        # A bare file name has no directory part and os.makedirs('') raises.
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(path, 'wb') as f:
            # noinspection PyTypeChecker
            pickle.dump({
                'raw_data': self.raw_data,
                'cleaned_data': self.cleaned_data
            }, f)
        print(f"Dataset saved to {path}")

    def load_dataset(self, path='./AloneCheck/dataset.pkl'):
        """
        Load dataset from file

        Args:
            path (str): Path to load the dataset from

        Returns:
            bool: True if the file existed and was loaded, False otherwise
        """
        if not os.path.exists(path):
            return False
        with open(path, 'rb') as f:
            # SECURITY: unpickling can execute arbitrary code. Only load files
            # written by save_dataset, never files from an untrusted source.
            data = pickle.load(f)
        self.raw_data = data['raw_data']
        self.cleaned_data = data['cleaned_data']
        print(f"Dataset loaded from {path}")
        return True
# 2. Custom dataset class
class ConversationDataset(Dataset):
    """
    Custom dataset class for loading conversation data

    This class extends PyTorch's Dataset class. Each item is tokenized on
    access and returned together with its label remapped from the external
    -1/0/1 convention to the class indices 0/1/2.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512, lang='zh'):
        """
        Initialize the ConversationDataset

        Args:
            texts (list): List of text samples
            labels (list): List of corresponding labels (-1, 0 or 1)
            tokenizer: Callable used to tokenize a text; it is called with
                truncation, padding, max_length and return_tensors keywords
            max_length (int): Maximum sequence length
            lang (str): Language code (stored, not used by this class)

        Raises:
            ValueError: If texts and labels differ in length
        """
        # Fail at construction time instead of with an IndexError somewhere
        # in the middle of an epoch.
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.lang = lang
        # Convert labels from -1,0,1 to 0,1,2 for cross-entropy calculation
        self.label_mapping = {-1: 0, 0: 1, 1: 2}

    def __len__(self):
        """
        Return the number of samples in the dataset

        Returns:
            int: Number of samples
        """
        return len(self.texts)

    def __getitem__(self, idx):
        """
        Get a single sample from the dataset

        Args:
            idx (int): Index of the sample to retrieve

        Returns:
            dict: 'input_ids' and 'attention_mask' flattened to 1-D tensors,
            and 'labels' as a scalar long tensor holding the class index

        Raises:
            KeyError: If the stored label is not -1, 0 or 1
        """
        mapped_label = self.label_mapping[self.labels[idx]]
        encoding = self.tokenizer(
            str(self.texts[idx]),
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            # The tokenizer is asked for tensors ('pt'); flatten() removes the
            # batch dimension so a DataLoader can stack the samples itself.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(mapped_label, dtype=torch.long)
        }
# 3. Context Jumping Correlation Attention Mechanism
class ContextJumpingAttention(nn.Module):
    """
    Context Jumping Correlation Attention Mechanism

    Multi-head self-attention whose result is blended with the input through
    a learned sigmoid gate. The gate is computed from each token's hidden
    state concatenated with the mean hidden state of the sequence.
    """

    def __init__(self, hidden_size, num_heads=8, dropout=0.1):
        """
        Initialize the Context Jumping Attention mechanism

        Args:
            hidden_size (int): Size of the hidden state
            num_heads (int): Number of attention heads
            dropout (float): Dropout rate

        Raises:
            ValueError: If hidden_size is not divisible by num_heads
        """
        super().__init__()
        # Raise explicitly: an assert would be stripped under ``python -O``.
        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size must be divisible by num_heads")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Linear transformation layers
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        # Context jumping gating mechanism
        self.context_gate = nn.Linear(hidden_size * 2, hidden_size)
        self.gate_activation = nn.Sigmoid()
        # Output layer
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_size)

    # noinspection PyPep8Naming
    def forward(self, hidden_states, attention_mask=None):
        """
        Forward pass of the Context Jumping Attention mechanism

        Args:
            hidden_states (torch.Tensor): Input of shape
                (batch_size, seq_len, hidden_size)
            attention_mask (torch.Tensor, optional): Mask of shape
                (batch_size, seq_len); positions equal to 0 are excluded as
                attention keys

        Returns:
            torch.Tensor: Output with the same shape as hidden_states
        """
        batch_size, seq_len, hidden_size = hidden_states.size()

        def split_heads(projected):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query(hidden_states))
        K = split_heads(self.key(hidden_states))
        V = split_heads(self.value(hidden_states))
        # Scaled dot-product attention scores
        attention_scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            # Use the dtype's own minimum: a fixed -1e9 cannot be represented
            # in float16 and makes masked_fill fail in half precision.
            fill_value = torch.finfo(attention_scores.dtype).min
            # noinspection PyTypeChecker
            attention_scores = attention_scores.masked_fill(
                attention_mask[:, None, None, :] == 0, fill_value
            )
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Weighted sum of values, then merge the heads back together
        context = torch.matmul(attention_probs, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        # Context jumping mechanism
        # NOTE(review): this mean runs over every position, including those
        # that attention_mask marks as padding. Left unchanged because masking
        # it would alter the outputs of already trained checkpoints.
        global_context = torch.mean(hidden_states, dim=1, keepdim=True).expand(-1, seq_len, -1)
        gate_input = torch.cat([hidden_states, global_context], dim=-1)
        gate = self.gate_activation(self.context_gate(gate_input))
        # gate -> 1 keeps the attended context, gate -> 0 keeps the input
        gated_context = gate * context + (1 - gate) * hidden_states
        # Output transformation
        output = self.out(gated_context)
        output = self.dropout(output)
        return self.layer_norm(output + hidden_states)  # Residual connection
# 4. CNN module
class TextCNN(nn.Module):
    """
    Text CNN module for extracting local features from text

    Applies one Conv1d per kernel size along the sequence axis, followed by
    ReLU and max-over-time pooling, and concatenates the pooled features.
    """

    def __init__(self, embedding_dim, num_filters, filter_sizes, dropout=0.1):
        """
        Initialize the Text CNN module

        Args:
            embedding_dim (int): Dimension of input embeddings
            num_filters (int): Number of filters per kernel size
            filter_sizes (tuple): Tuple of kernel sizes to use
            dropout (float): Dropout rate
        """
        super().__init__()
        filter_sizes = tuple(filter_sizes)
        # Shortest sequence every convolution can process without padding
        self.min_seq_len = max(filter_sizes, default=1)
        # Multiple convolution kernel sizes
        self.convs = nn.ModuleList([
            nn.Conv1d(embedding_dim, num_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass of the Text CNN module

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embedding_dim)

        Returns:
            torch.Tensor: Tensor of shape (batch_size, len(filter_sizes) * num_filters)
        """
        # Conv1d expects (batch_size, embedding_dim, seq_len)
        x = x.transpose(1, 2)
        missing = self.min_seq_len - x.size(2)
        if missing > 0:
            # A sequence shorter than the largest kernel makes Conv1d raise;
            # right-pad with zeros so short inputs are still classifiable.
            x = F.pad(x, (0, missing))
        # Convolution + ReLU, then take the maximum over the sequence axis
        pooled_outputs = [F.relu(conv(x)).max(dim=2).values for conv in self.convs]
        # Concatenate all pooled outputs
        output = torch.cat(pooled_outputs, dim=1)
        return self.dropout(output)
# 5. Structural Pattern Recognition module
class StructuralPatternRecognition(nn.Module):
    """
    Structural Pattern Recognition module

    Projects every token through ``num_patterns`` independent linear layers,
    scores each projection with a shared linear layer, and combines the
    projections with the softmax of those scores taken over the patterns.
    """

    def __init__(self, hidden_size, num_patterns, dropout=0.1):
        """
        Initialize the Structural Pattern Recognition module

        Args:
            hidden_size (int): Size of the hidden state
            num_patterns (int): Number of pattern recognition layers
            dropout (float): Dropout rate

        Raises:
            ValueError: If num_patterns is smaller than 1
        """
        super().__init__()
        # With no pattern layers there is nothing to stack in forward().
        if num_patterns < 1:
            raise ValueError("num_patterns must be at least 1")
        self.hidden_size = hidden_size
        self.num_patterns = num_patterns
        # Pattern recognition layers
        self.pattern_layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size)
            for _ in range(num_patterns)
        ])
        # Pattern attention
        self.pattern_attention = nn.Linear(hidden_size, 1)
        # Output layer
        self.output_layer = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        # Layer normalization
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """
        Forward pass of the Structural Pattern Recognition module

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size)

        Returns:
            torch.Tensor: Output with the same shape as x
        """
        # (batch_size, seq_len, num_patterns, hidden_size)
        pattern_stack = torch.stack([layer(x) for layer in self.pattern_layers], dim=2)
        # One score per pattern, normalised across the pattern axis. Working
        # on the 4-D tensor avoids flattening and restoring batch/seq dims.
        pattern_attention_weights = F.softmax(self.pattern_attention(pattern_stack), dim=2)
        # (batch_size, seq_len, hidden_size)
        weighted_patterns = torch.sum(pattern_stack * pattern_attention_weights, dim=2)
        # Output transformation
        output = self.output_layer(weighted_patterns)
        output = self.dropout(output)
        return self.layer_norm(output + x)  # Residual connection
# 6. Overall model
class ConversationHostilityDetector(nn.Module):
    """
    Chat Context Hostile/Friendly Detection Model

    Pipeline: pre-trained encoder -> Context Jumping Attention -> Structural
    Pattern Recognition -> ([CLS] vector concatenated with Text CNN features)
    -> two-layer classifier.

    Class indices are 0: hostile, 1: neutral, 2: friendly.
    """

    def __init__(self,
                 bert_model_name='bert-base-multilingual-cased',
                 hidden_size=768,
                 num_classes=3,  # 0: hostile, 1: neutral, 2: friendly (internal representation)
                 num_attention_heads=8,
                 num_filters=100,
                 filter_sizes=(2, 3, 4, 5),
                 num_patterns=5,
                 dropout=0.1,
                 temperature=1.0):  # Add temperature parameter for smoothing probabilities
        """
        Initialize the Conversation Hostility Detector model

        Args:
            bert_model_name (str): Name of the pre-trained BERT model
            hidden_size (int): Width of the classifier's hidden layer (the
                attention, pattern and CNN modules use the encoder's own
                hidden size instead)
            num_classes (int): Number of output classes
            num_attention_heads (int): Number of attention heads
            num_filters (int): Number of CNN filters per kernel size
            filter_sizes (tuple): Tuple of CNN filter sizes
            num_patterns (int): Number of structural patterns
            dropout (float): Dropout rate
            temperature (float): Divisor applied to the logits before softmax

        Raises:
            ValueError: If temperature is not positive
        """
        super().__init__()
        # The logits are divided by the temperature in forward().
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        # BERT encoder
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.bert_hidden_size = self.bert.config.hidden_size
        # Context Jumping Correlation Attention
        # noinspection PyTypeChecker
        self.context_attention = ContextJumpingAttention(
            hidden_size=self.bert_hidden_size,
            num_heads=num_attention_heads,
            dropout=dropout
        )
        # Text CNN
        # noinspection PyTypeChecker
        self.text_cnn = TextCNN(
            embedding_dim=self.bert_hidden_size,
            num_filters=num_filters,
            filter_sizes=filter_sizes,
            dropout=dropout
        )
        # Structural Pattern Recognition
        # noinspection PyTypeChecker
        self.pattern_recognition = StructuralPatternRecognition(
            hidden_size=self.bert_hidden_size,
            num_patterns=num_patterns,
            dropout=dropout
        )
        # Classifier: input is the [CLS] vector plus the pooled CNN features
        # noinspection PyTypeChecker
        self.classifier = nn.Sequential(
            nn.Linear(len(filter_sizes) * num_filters + self.bert_hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes)
        )
        # Temperature parameter for smoothing probability distribution
        self.temperature = temperature

    def forward(self, input_ids, attention_mask):
        """
        Forward pass of the model

        Args:
            input_ids (torch.Tensor): Token IDs
            attention_mask (torch.Tensor): Attention mask

        Returns:
            tuple: (logits, probabilities). The logits are unscaled; the
            probabilities are the softmax of logits / temperature.
        """
        # BERT encoding
        bert_outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        sequence_output = bert_outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
        # Context Jumping Correlation Attention
        context_output = self.context_attention(sequence_output, attention_mask)
        # Structural Pattern Recognition
        pattern_output = self.pattern_recognition(context_output)
        # Position 0 is used as the whole-sequence ([CLS]) representation
        cls_output = pattern_output[:, 0, :]  # (batch_size, hidden_size)
        # Apply CNN to the entire sequence
        cnn_output = self.text_cnn(pattern_output)  # (batch_size, len(filter_sizes) * num_filters)
        # Concatenate CLS output and CNN output
        combined_output = torch.cat([cls_output, cnn_output], dim=1)
        # Classification
        logits = self.classifier(combined_output)
        # Apply temperature scaling and calculate softmax probabilities
        probabilities = F.softmax(logits / self.temperature, dim=-1)
        return logits, probabilities

    def predict_with_probabilities(self, input_ids, attention_mask):
        """
        Predict and return probabilities and classification results

        Runs in eval mode without gradients and then restores the train/eval
        mode the module had before the call.

        Args:
            input_ids (torch.Tensor): Token IDs
            attention_mask (torch.Tensor): Attention mask

        Returns:
            dict: 'logits' and 'probabilities' (tensors), 'smooth_scores'
            (NumPy array, p_friendly - p_hostile per sample),
            'predicted_labels' (list of -1/0/1) and 'class_probabilities'
            (dict of NumPy arrays keyed 'hostile', 'neutral', 'friendly')
        """
        # Remember the current mode so that calling this helper in the middle
        # of training does not silently leave dropout disabled afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits, probabilities = self(input_ids, attention_mask)
        finally:
            self.train(was_training)
        # 0: hostile, 1: neutral, 2: friendly (internal representation)
        prob_hostile = probabilities[:, 0].cpu().numpy()
        prob_neutral = probabilities[:, 1].cpu().numpy()
        prob_friendly = probabilities[:, 2].cpu().numpy()
        # Expected value of the external label:
        # -1 * p_hostile + 0 * p_neutral + 1 * p_friendly
        smooth_scores = prob_friendly - prob_hostile
        # Map internal classes back to external classes
        reverse_mapping = {0: -1, 1: 0, 2: 1}
        predicted_classes = torch.argmax(probabilities, dim=1).cpu().tolist()
        predicted_labels = [reverse_mapping[c] for c in predicted_classes]
        return {
            'logits': logits,
            'probabilities': probabilities,
            'smooth_scores': smooth_scores,
            'predicted_labels': predicted_labels,
            'class_probabilities': {
                'hostile': prob_hostile,
                'neutral': prob_neutral,
                'friendly': prob_friendly
            }
        }
# 7. Python API for model calling
class HostilityDetectorAPI:
    """
    Chat Context Hostile/Friendly Detection API

    This class provides a convenient interface for loading a pre-trained model
    and making predictions on new text. It supports both single text prediction
    and conversation analysis.
    """

    def __init__(self, model_path='./AloneCheck/model.pth',
                 bert_model_name='bert-base-multilingual-cased',
                 device='auto'):
        """
        Initialize the detector

        Args:
            model_path (str): Path to the pre-trained model
            bert_model_name (str): Name of the BERT model
            device (str): Device to use ('auto', 'cpu', 'cuda')

        Raises:
            FileNotFoundError: If model_path does not exist
        """
        # Set device
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"Using device: {self.device}")
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        # Load model
        self.model = self._load_model(model_path, bert_model_name)
        self.model.to(self.device)
        self.model.eval()
        # Class name mapping
        self.class_names = ['hostile', 'neutral', 'friendly']
        self.label_mapping = {0: -1, 1: 0, 2: 1}  # Internal label to external label mapping
        self.reverse_label_mapping = {-1: 0, 0: 1, 1: 2}  # External label to internal label mapping

    def _load_model(self, model_path, bert_model_name):
        """
        Load a pre-trained model

        Args:
            model_path (str): Path to the model file
            bert_model_name (str): Name of the BERT model

        Returns:
            ConversationHostilityDetector: Model with the saved weights loaded

        Raises:
            FileNotFoundError: If model_path does not exist
        """
        # Check first: building the model loads the whole pre-trained encoder,
        # which is wasted work when the checkpoint is missing anyway.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        # noinspection PyTypeChecker
        model = ConversationHostilityDetector(
            bert_model_name=bert_model_name,
            hidden_size=768,
            num_classes=3,
            num_attention_heads=8,
            num_filters=100,
            filter_sizes=[2, 3, 4, 5],
            num_patterns=5,
            dropout=0.1,
            temperature=1.0
        )
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # a trusted source.
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        print(f"Successfully loaded model: {model_path}")
        return model

    def _preprocess_text(self, text):
        """
        Text preprocessing

        Lowercases the text, treats it as Chinese when it contains a character
        in U+4E00..U+9FFF, strips everything except letters and whitespace
        (plus characters in U+4E00..U+9FA5 for Chinese), and collapses
        whitespace.

        Args:
            text (str): Input text

        Returns:
            str: Preprocessed text
        """
        text = text.lower()
        # Determine language
        is_chinese = any('\u4e00' <= char <= '\u9fff' for char in text)
        # Remove special characters and numbers
        if is_chinese:
            text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z\s]', '', text)
        else:
            text = re.sub(r'[^a-zA-Z\s]', '', text)
        # Remove extra spaces
        return re.sub(r'\s+', ' ', text).strip()

    def predict(self, text):
        """
        Predict the hostility/friendliness of a single text

        Args:
            text (str): Input text

        Returns:
            dict: Dictionary containing prediction results
                - predicted_label: Predicted label ('hostile', 'neutral', 'friendly')
                - smooth_score: Smooth score (-1 to 1)
                - probabilities: Probability dictionary for each class
                - confidence: Prediction confidence (probability of the highest class)
        """
        cleaned_text = self._preprocess_text(text)
        encoding = self.tokenizer(
            cleaned_text,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        with torch.no_grad():
            results = self.model.predict_with_probabilities(input_ids, attention_mask)
        # The batch holds a single text, so every result is read at index 0
        predicted_label = results['predicted_labels'][0]
        predicted_class = self.class_names[predicted_label + 1]  # +1 because -1 corresponds to index 0
        hostile_prob = results['class_probabilities']['hostile'][0]
        neutral_prob = results['class_probabilities']['neutral'][0]
        friendly_prob = results['class_probabilities']['friendly'][0]
        smooth_score = results['smooth_scores'][0]
        # Confidence is the probability of the most likely class
        # noinspection PyTypeChecker
        confidence = max(hostile_prob, neutral_prob, friendly_prob)
        return {
            'predicted_label': predicted_class,
            'smooth_score': float(smooth_score),
            'probabilities': {
                'hostile': float(hostile_prob),
                'neutral': float(neutral_prob),
                'friendly': float(friendly_prob)
            },
            'confidence': float(confidence)
        }

    def batch_predict(self, texts):
        """
        Predict the hostility/friendliness of multiple texts

        Texts are processed one at a time through predict().

        Args:
            texts (list): List of input texts

        Returns:
            list: List of dictionaries containing prediction results for each text
        """
        return [self.predict(text) for text in texts]

    def analyze_conversation(self, conversation):
        """
        Analyze the overall hostility/friendliness of a conversation

        Args:
            conversation (list): List of conversation texts in chronological order

        Returns:
            dict: Conversation analysis results
                - overall_sentiment: 'hostile' if the average score is below
                  -0.33, 'friendly' if above 0.33, otherwise 'neutral'
                - average_score: Average smooth score
                - sentiment_distribution: Fraction of messages per label
                - hostility_trend: 'deteriorating', 'improving' or 'stable'
                  from the slope of a straight-line fit through the scores;
                  'unknown' for fewer than two messages
                - message_predictions: Per-message prediction dictionaries
        """
        predictions = self.batch_predict(conversation)
        if not predictions:
            # An empty conversation has no mean and would divide by zero in
            # the distribution below; report it as neutral/unknown instead.
            return {
                'overall_sentiment': 'neutral',
                'average_score': 0.0,
                'sentiment_distribution': {'hostile': 0.0, 'neutral': 0.0, 'friendly': 0.0},
                'hostility_trend': 'unknown',
                'message_predictions': []
            }
        # Calculate average smooth score
        scores = [p['smooth_score'] for p in predictions]
        average_score = np.mean(scores)
        # Determine overall sentiment
        if average_score < -0.33:
            overall_sentiment = 'hostile'
        elif average_score > 0.33:
            overall_sentiment = 'friendly'
        else:
            overall_sentiment = 'neutral'
        # Calculate sentiment distribution
        sentiment_counts = {'hostile': 0, 'neutral': 0, 'friendly': 0}
        for p in predictions:
            sentiment_counts[p['predicted_label']] += 1
        total = len(predictions)
        sentiment_distribution = {
            name: count / total for name, count in sentiment_counts.items()
        }
        # Calculate hostility trend (based on changes in smooth scores)
        hostility_trend = 'unknown'
        if len(scores) > 1:
            # Slope of a first-degree polynomial fit: a falling score means
            # the conversation is getting more hostile.
            slope, _ = np.polyfit(np.arange(len(scores)), scores, 1)
            if slope < -0.1:
                hostility_trend = 'deteriorating'
            elif slope > 0.1:
                hostility_trend = 'improving'
            else:
                hostility_trend = 'stable'
        return {
            'overall_sentiment': overall_sentiment,
            'average_score': float(average_score),
            'sentiment_distribution': sentiment_distribution,
            'hostility_trend': hostility_trend,
            'message_predictions': predictions
        }

    def get_explanation(self, text, result):
        """
        Get an explanation for the prediction result

        Args:
            text (str): Original text
            result (dict): Prediction result dictionary as returned by predict()

        Returns:
            str: Explanation text, or an error message naming the first
            required key that is missing from result
        """
        # Ensure result contains all necessary keys
        for required_key in ('predicted_label', 'smooth_score', 'confidence'):
            if required_key not in result:
                return f"Error: '{required_key}' key missing in prediction result"
        predicted_label = result['predicted_label']
        smooth_score = result['smooth_score']
        confidence = result['confidence']
        # Single-quoted f-string so the text can be wrapped in double quotes;
        # the line as posted nested unescaped double quotes (a syntax error).
        explanation = f'Text: "{text}"\n'
        explanation += f"Predicted sentiment: {predicted_label}\n"
        explanation += f"Smooth score: {smooth_score:.2f} (range: -1 to 1)\n"
        explanation += f"Confidence: {confidence:.2%}\n\n"
        # Provide explanation based on prediction result
        if predicted_label == 'hostile':
            explanation += "This text shows hostile or negative sentiment."
            if smooth_score < -0.7:
                explanation += " The hostility level is very high."
            elif smooth_score < -0.4:
                explanation += " The hostility level is relatively high."
            else:
                explanation += " The hostility level is relatively low."
        elif predicted_label == 'friendly':
            explanation += "This text shows friendly or positive sentiment."
            if smooth_score > 0.7:
                explanation += " The friendliness level is very high."
            elif smooth_score > 0.4:
                explanation += " The friendliness level is relatively high."
            else:
                explanation += " The friendliness level is relatively low."
        else:
            explanation += "This text shows neutral sentiment."
            if abs(smooth_score) < 0.2:
                explanation += " The sentiment is very neutral."
            else:
                explanation += " Slightly leaning towards " + ("positive" if smooth_score > 0 else "negative") + "."
        return explanation
- # 8. Curses Menu System
class HostilityDetectorMenu:
    """
    Curses-based menu system for Hostility Detector.

    Provides a terminal user interface for training a new model, making
    predictions, and continuing training with the previously saved dataset.
    """

    def __init__(self, stdscr):
        """
        Initialize the menu system.

        Args:
            stdscr: Curses standard screen object.
        """
        self.stdscr = stdscr
        self._set_cursor(0)
        curses.init_pair(1, curses.COLOR_WHITE, curses.COLOR_BLUE)
        curses.init_pair(2, curses.COLOR_BLACK, curses.COLOR_WHITE)
        self.menu_items = [
            "1. Train New Model",
            "2. Use API for Prediction",
            "3. Continue Training (Previous Dataset)",
            "4. Exit"
        ]
        self.current_row = 0

    @staticmethod
    def _set_cursor(visibility):
        """
        Set cursor visibility, ignoring terminals that do not support it.

        Args:
            visibility (int): Value passed to ``curses.curs_set``.
        """
        try:
            curses.curs_set(visibility)
        except curses.error:
            # curs_set raises when the terminal cannot honour the request;
            # the cursor state is cosmetic, so carry on.
            pass

    def _put(self, y, x, text, attr=None, clear=False):
        """
        Write ``text`` at (y, x) without raising when it does not fit.

        The text is clipped to the window width and rows outside the window
        are skipped, so a small terminal degrades the display instead of
        crashing the program.

        Args:
            y (int): Row.
            x (int): Column; negative values are clamped to 0.
            text (str): Text to draw.
            attr: Optional curses attribute for the text.
            clear (bool): Clear the whole row first, so that a shorter string
                does not leave characters from an earlier, longer one.
        """
        height, width = self.stdscr.getmaxyx()
        if not 0 <= y < height:
            return
        x = max(0, x)
        try:
            if clear:
                self.stdscr.move(y, 0)
                self.stdscr.clrtoeol()
            if x >= width:
                return
            # Leave the last column free: writing into the bottom-right cell
            # makes addstr raise even though the text was drawn.
            clipped = text[:max(0, width - x - 1)]
            if attr is None:
                self.stdscr.addstr(y, x, clipped)
            else:
                self.stdscr.addstr(y, x, clipped, attr)
        except curses.error:
            # Best-effort drawing (e.g. wide characters overflowing the row).
            pass

    def draw_menu(self):
        """
        Draw the main menu: title, menu items and key instructions.
        """
        self.stdscr.clear()
        h, w = self.stdscr.getmaxyx()

        # Title
        title = "Chat Context Hostility/Friendly Detection System"
        self._put(1, (w - len(title)) // 2, title, curses.color_pair(1))

        # Menu items, centred; the selected row is highlighted
        top = h // 2 - len(self.menu_items) // 2
        for idx, item in enumerate(self.menu_items):
            attr = curses.color_pair(2) if idx == self.current_row else None
            self._put(top + idx, w // 2 - len(item) // 2, item, attr)

        # Instructions
        instructions = "Use UP/DOWN arrows to navigate, ENTER to select"
        self._put(h - 2, (w - len(instructions)) // 2, instructions)
        self.stdscr.refresh()

    def run(self):
        """
        Run the menu loop until the Exit entry is selected.
        """
        while True:
            self.draw_menu()
            key = self.stdscr.getch()
            if key == curses.KEY_UP and self.current_row > 0:
                self.current_row -= 1
            elif key == curses.KEY_DOWN and self.current_row < len(self.menu_items) - 1:
                self.current_row += 1
            elif key == curses.KEY_ENTER or key in [10, 13]:
                self.handle_selection()
                if self.current_row == 3:  # Exit
                    break

    def handle_selection(self):
        """
        Invoke the action that belongs to the currently selected row.
        """
        actions = (
            self.train_new_model,
            self.use_api_prediction,
            self.continue_training,
            self.exit_program,
        )
        if 0 <= self.current_row < len(actions):
            actions[self.current_row]()

    def _run_outside_curses(self, banner, task):
        """
        Run ``task`` with the terminal in normal mode, then resume curses.

        Curses mode is re-entered whether or not ``task`` raises. If it raises
        an ``Exception`` the message is shown and a key press is awaited.

        Args:
            banner (str): Text shown before leaving curses mode.
            task (callable): Zero-argument callable to execute.
        """
        self.stdscr.clear()
        self._put(1, 1, banner)
        self.stdscr.refresh()

        error = None
        # Temporarily restore normal terminal mode so print()/tqdm output shows
        curses.endwin()
        try:
            task()
        except Exception as e:
            error = e
        finally:
            # Restore curses mode on success, failure and interrupts alike
            curses.initscr()
            self._set_cursor(0)

        if error is not None:
            self.stdscr.clear()
            self._put(1, 1, f"Error: {str(error)}")
            self._put(3, 1, "Press any key to continue...")
            self.stdscr.getch()

    def train_new_model(self):
        """
        Train a new model.
        """
        self._run_outside_curses("Training New Model...", train_new_model_main)

    def use_api_prediction(self):
        """
        Read texts interactively and display the API prediction for each.

        Entering 'quit' returns to the menu; blank input is ignored.
        """
        self.stdscr.clear()
        self._put(1, 1, "API Prediction Mode")
        self._put(3, 1, "Enter text to analyze (or 'quit' to return to menu):")
        self.stdscr.refresh()
        curses.echo()
        self._set_cursor(1)
        try:
            # Initialize API
            api = HostilityDetectorAPI()
            while True:
                self.stdscr.move(5, 1)
                self.stdscr.clrtoeol()
                raw = self.stdscr.getstr(5, 1, 200)
                # Undecodable bytes must not abort the whole session
                text = raw.decode('utf-8', errors='replace')
                stripped = text.strip()
                if stripped.lower() == 'quit':
                    break
                if not stripped:
                    continue
                # Make prediction
                result = api.predict(text)
                probabilities = result['probabilities']
                # Display results; every row is cleared first so that a
                # shorter value never keeps the tail of the previous one.
                rows = (
                    (7, 1, f"Prediction: {result['predicted_label']}"),
                    (8, 1, f"Smooth Score: {result['smooth_score']:.4f}"),
                    (9, 1, "Probabilities:"),
                    (10, 3, f"Hostile: {probabilities['hostile']:.4f}"),
                    (11, 3, f"Neutral: {probabilities['neutral']:.4f}"),
                    (12, 3, f"Friendly: {probabilities['friendly']:.4f}"),
                    (14, 1, "Enter next text or 'quit' to return:"),
                )
                for y, x, line in rows:
                    self._put(y, x, line, clear=True)
                self.stdscr.refresh()
        except Exception as e:
            self._put(16, 1, f"Error: {str(e)}", clear=True)
            self._put(18, 1, "Press any key to continue...", clear=True)
            self.stdscr.getch()
        finally:
            curses.noecho()
            self._set_cursor(0)

    def continue_training(self):
        """
        Continue training with previous dataset.
        """
        self._run_outside_curses(
            "Continue Training with Previous Dataset",
            continue_training_with_previous_dataset,
        )

    def exit_program(self):
        """
        Show the exit message; the menu loop in ``run`` then terminates.
        """
        self.stdscr.clear()
        self._put(1, 1, "Exiting program...")
        self.stdscr.refresh()
- # 9. Training and evaluation functions
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """
    Run one pass over ``loader`` and return (mean batch loss, accuracy).

    The model is put in training mode and updated when ``optimizer`` is
    given; otherwise it is put in eval mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            if training:
                optimizer.zero_grad()
            logits, _ = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            if training:
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            predicted = logits.argmax(dim=1)
            total += labels.size(0)
            # noinspection PyUnresolvedReferences
            correct += (predicted == labels).sum().item()
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError
    return running_loss / max(len(loader), 1), correct / max(total, 1)


def train_model(model, train_loader, val_loader, optimizer, criterion, num_epochs=10, device='cpu',
                save_path='./AloneCheck/model.pth', save_best_only=True):
    """
    Train the model and keep the weights with the best validation accuracy.

    Training can be interrupted with Ctrl+C; the epochs completed so far are
    kept and the function still returns normally.

    Args:
        model: Model to train; called as ``model(input_ids, attention_mask)``
            and expected to return ``(logits, extra)``.
        train_loader: Training data loader yielding dicts with
            'input_ids', 'attention_mask' and 'labels'.
        val_loader: Validation data loader with the same batch layout.
        optimizer: Optimizer
        criterion: Loss function
        num_epochs (int): Number of training epochs
        device (str): Device to use
        save_path (str): Path to save the model
        save_best_only (bool): If True, the file is written each time the
            validation accuracy improves. If False, the weights of the last
            executed epoch are written once, after training.

    Returns:
        tuple: (model, train_losses, val_losses, train_accs, val_accs);
            ``model`` carries the best-validation weights when at least one
            epoch improved on the initial accuracy of 0.0.
    """
    model = model.to(device)
    criterion = criterion.to(device)
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0
    best_model_state = None

    # Create directory if it doesn't exist. A bare file name has an empty
    # dirname, and os.makedirs('') would raise.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    try:
        for epoch in range(num_epochs):
            prefix = f'Epoch {epoch + 1}/{num_epochs}'
            train_loss, train_acc = _run_epoch(
                model, train_loader, criterion, device, f'{prefix} - Training', optimizer)
            train_losses.append(train_loss)
            train_accs.append(train_acc)

            val_loss, val_acc = _run_epoch(
                model, val_loader, criterion, device, f'{prefix} - Validation')
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            print(f'{prefix}:')
            print(f'Training loss: {train_loss:.4f}, Training accuracy: {train_acc:.4f}')
            print(f'Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}')

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # state_dict() tensors share storage with the live parameters,
                # so they must be cloned; a shallow dict copy would silently
                # follow every later optimizer step.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                if save_best_only:
                    torch.save(best_model_state, save_path)
                    print(f"Best model saved to {save_path}")
    except KeyboardInterrupt:
        print("User Interrupt.")

    # Save final model if not saving best only. This happens before the best
    # weights are restored, so the file holds the last-epoch weights.
    if not save_best_only:
        torch.save(model.state_dict(), save_path)
        print(f"Final model saved to {save_path}")

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, train_losses, val_losses, train_accs, val_accs
def continue_training_with_previous_dataset(model_path='./AloneCheck/model.pth',
                                            dataset_path='./AloneCheck/dataset.pkl',
                                            num_epochs=5, batch_size=8, learning_rate=2e-5,
                                            save_path='./AloneCheck/continued_model.pth'):
    """
    Resume training of a saved model on the previously saved dataset.

    The dataset is split 70/15/15 into train/validation/test with a fixed
    random state, the saved weights are loaded into a freshly built model,
    training runs for ``num_epochs`` and the result is scored on the test split.

    Args:
        model_path (str): Path to the pre-trained model weights
        dataset_path (str): Path to the saved dataset
        num_epochs (int): Number of additional training epochs
        batch_size (int): Batch size for all data loaders
        learning_rate (float): AdamW learning rate
        save_path (str): Path to save the continued model

    Returns:
        tuple: (model, train_losses, val_losses, train_accs, val_accs)

    Raises:
        FileNotFoundError: If the dataset manager reports that the dataset
            could not be loaded from ``dataset_path``.
    """
    print("Starting continued training with previous dataset...")

    # Load dataset; a falsy result from load_dataset is treated as "not found"
    manager = DatasetManager()
    dataset_loaded = manager.load_dataset(dataset_path)
    if not dataset_loaded:
        raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please train a model first.")

    # Rebuild the architecture, then load the pre-trained weights into it
    bert_name = 'bert-base-multilingual-cased'
    # noinspection PyTypeChecker
    detector = ConversationHostilityDetector(
        bert_model_name=bert_name,
        hidden_size=768,
        num_classes=3,
        num_attention_heads=8,
        num_filters=100,
        filter_sizes=[2, 3, 4, 5],
        num_patterns=5,
        dropout=0.1,
        temperature=1.0
    )
    saved_state = torch.load(model_path, map_location=device)
    detector.load_state_dict(saved_state)
    detector.to(device)

    # Reuse the cleaned data from the previous training run
    data = manager.cleaned_data
    texts = data['texts']
    targets = data['labels']

    # 70% train, then the remaining 30% halved into validation and test
    train_x, held_x, train_y, held_y = train_test_split(
        texts, targets, test_size=0.3, random_state=42
    )
    val_x, test_x, val_y, test_y = train_test_split(
        held_x, held_y, test_size=0.5, random_state=42
    )
    print(f"Previous dataset size: {len(texts)}")
    print(f"Training set size: {len(train_x)}")
    print(f"Validation set size: {len(val_x)}")
    print(f"Test set size: {len(test_x)}")

    # Create data loaders; only the training loader shuffles
    tokenizer = AutoTokenizer.from_pretrained(bert_name)
    train_loader = DataLoader(ConversationDataset(train_x, train_y, tokenizer),
                              batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(ConversationDataset(val_x, val_y, tokenizer),
                            batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(ConversationDataset(test_x, test_y, tokenizer),
                             batch_size=batch_size, shuffle=False)

    # Define criterion and optimizer
    loss_fn = nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(detector.parameters(), lr=learning_rate)

    # Train model
    # noinspection PyTypeChecker
    history = train_model(
        detector, train_loader, val_loader, adamw, loss_fn,
        num_epochs=num_epochs, device=device, save_path=save_path
    )
    detector, train_losses, val_losses, train_accs, val_accs = history

    # Evaluate on test set
    print("Evaluating continued model...")
    # noinspection PyTypeChecker
    scores = evaluate_model(detector, test_loader, device=device)
    print(f"Test accuracy: {scores['accuracy']:.4f}")
    print(f"Test F1 score: {scores['f1_score']:.4f}")
    return detector, train_losses, val_losses, train_accs, val_accs
def evaluate_model(model, test_loader, device='cpu'):
    """
    Evaluate the model on a data loader.

    Args:
        model: Model to evaluate; must provide
            ``predict_with_probabilities(input_ids, attention_mask)``.
        test_loader: Test data loader yielding dicts with 'input_ids',
            'attention_mask' and 'labels'.
        device (str): Device to use

    Returns:
        dict: accuracy, weighted F1 score, classification report (as dict),
            confusion matrix, the collected labels/predictions/smooth scores,
            the stacked probabilities, and the label values and display
            names that actually occurred.
    """
    model = model.to(device)
    model.eval()

    # Internal labels (0,1,2) are reported as external labels (-1,0,1)
    internal_to_external = {0: -1, 1: 0, 2: 1}
    display_names = ['Hostile', 'Neutral', 'Friendly']

    truths = []
    predictions = []
    smooth_scores = []
    probability_chunks = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluation'):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            # Use probability prediction method
            outputs = model.predict_with_probabilities(ids, mask)

            batch_truths = [internal_to_external[target.item()] for target in targets]
            truths.extend(batch_truths)
            predictions.extend(outputs['predicted_labels'])
            smooth_scores.extend(outputs['smooth_scores'])
            probability_chunks.append(outputs['probabilities'].cpu().numpy())

    # Merge the per-batch probability arrays into one array
    stacked_probabilities = np.vstack(probability_chunks)

    # Calculate evaluation metrics
    accuracy = accuracy_score(truths, predictions)
    weighted_f1 = f1_score(truths, predictions, average='weighted')

    # Only classes that occur in the labels or the predictions are reported
    present = sorted(set(truths + predictions))
    # +1 because external label -1 corresponds to index 0
    present_names = [display_names[value + 1] for value in present]

    report = classification_report(
        truths,
        predictions,
        labels=present,
        target_names=present_names,
        output_dict=True
    )
    matrix = confusion_matrix(truths, predictions, labels=present)

    return {
        'accuracy': accuracy,
        'f1_score': weighted_f1,
        'classification_report': report,
        'confusion_matrix': matrix,
        'all_labels': truths,
        'all_predictions': predictions,
        'smooth_scores': smooth_scores,
        'probabilities': stacked_probabilities,
        'unique_labels': present,
        'target_names': present_names
    }
def plot_results(train_losses, val_losses, train_accs, val_accs, confusion_matrix, class_names):
    """
    Plot the training curves and the confusion matrix in a 2x2 figure.

    The figure is written to './AloneCheck/results.png' and then shown.

    Args:
        train_losses (list): Training losses per epoch
        val_losses (list): Validation losses per epoch
        train_accs (list): Training accuracies per epoch
        val_accs (list): Validation accuracies per epoch
        confusion_matrix (array): Confusion matrix to draw as a heatmap
        class_names (list): Tick labels for both heatmap axes
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    heatmap_ax = axes[1, 0]
    unused_ax = axes[1, 1]

    # Top row: loss curves on the left, accuracy curves on the right
    curve_panels = (
        (axes[0, 0], 'Loss', train_losses, val_losses),
        (axes[0, 1], 'Accuracy', train_accs, val_accs),
    )
    for ax, metric, train_values, val_values in curve_panels:
        ax.plot(train_values, label=f'Training {metric}')
        ax.plot(val_values, label=f'Validation {metric}')
        ax.set_title(f'Training and Validation {metric}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()

    # Bottom left: confusion matrix
    sns.heatmap(confusion_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=heatmap_ax)
    heatmap_ax.set_title('Confusion Matrix')
    heatmap_ax.set_xlabel('Predicted Label')
    heatmap_ax.set_ylabel('True Label')

    # Bottom right: nothing to show, hide the empty axes
    unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig('./AloneCheck/results.png')
    plt.show()
- # 10. Main training function (for menu)
def train_new_model_main():
    """
    Main training function for a new model.

    Downloads and cleans the datasets, saves the cleaned dataset, splits it
    70/15/15 into train/validation/test, trains a new model for 10 epochs,
    evaluates it on the test split and plots the results.

    Raises:
        RuntimeError: If neither dataset could be downloaded.
    """
    # 1. Dataset download and preprocessing
    print("Starting dataset download and preprocessing...")
    dataset_manager = DatasetManager()

    # Download real datasets
    print("Downloading real datasets...")
    toxic_chat_data = dataset_manager.download_toxic_chat()
    asap_data = dataset_manager.download_asap()

    # There is no simulated-data fallback: abort when both downloads failed.
    # NOTE(review): assumes both download_* results support truth testing
    # (e.g. list/dict/None) - confirm against DatasetManager.
    if not toxic_chat_data and not asap_data:
        print("All real dataset downloads failed!!!")
        # An explicit exception replaces a bare `raise`, which has no active
        # exception here and only produced an opaque RuntimeError.
        raise RuntimeError("All real dataset downloads failed; cannot train without data.")

    # Merge real datasets (result is kept inside the manager)
    print("Merging real datasets...")
    dataset_manager.merge_datasets()
    cleaned_data = dataset_manager.clean_data()
    print(f"Data preprocessing completed, total {len(cleaned_data['texts'])} samples")

    # Save dataset for future use
    dataset_manager.save_dataset()

    # 2. Initialize tokenizer
    print("Initializing tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

    # 3. Split training, validation, and test sets
    print("Splitting datasets...")
    X_train, X_temp, y_train, y_temp = train_test_split(
        cleaned_data['texts'],
        cleaned_data['labels'],
        test_size=0.3,
        random_state=42
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp,
        y_temp,
        test_size=0.5,
        random_state=42
    )
    print(f"Training set size: {len(X_train)}")
    print(f"Validation set size: {len(X_val)}")
    print(f"Test set size: {len(X_test)}")

    # 4. Create datasets and data loaders
    print("Creating datasets and data loaders...")
    train_dataset = ConversationDataset(X_train, y_train, tokenizer)
    val_dataset = ConversationDataset(X_val, y_val, tokenizer)
    test_dataset = ConversationDataset(X_test, y_test, tokenizer)
    batch_size = 8
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # 5. Initialize model
    print("Initializing model...")
    # noinspection PyTypeChecker
    model = ConversationHostilityDetector(
        bert_model_name='bert-base-multilingual-cased',
        hidden_size=768,
        num_classes=3,
        num_attention_heads=8,
        num_filters=100,
        filter_sizes=[2, 3, 4, 5],
        num_patterns=5,
        dropout=0.1,
        temperature=1.0
    )

    # 6. Define optimizer and criterion
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

    # 7. Train model
    print("Starting model training...")
    # noinspection PyTypeChecker
    model, train_losses, val_losses, train_accs, val_accs = train_model(
        model, train_loader, val_loader, optimizer, criterion,
        num_epochs=10, device=device
    )

    # 8. Evaluate model
    print("Evaluating model...")
    # noinspection PyTypeChecker
    evaluation_results = evaluate_model(model, test_loader, device=device)
    print(f"Test accuracy: {evaluation_results['accuracy']:.4f}")
    print(f"Test F1 score: {evaluation_results['f1_score']:.4f}")

    # 9. Plot results
    print("Plotting results...")
    plot_results(
        train_losses, val_losses, train_accs, val_accs,
        evaluation_results['confusion_matrix'], evaluation_results['target_names']
    )
    print("Training completed!")
- # 11. Main function
def main(stdscr=None):
    """
    Entry point: run the curses menu, or a one-off API demo without it.

    Args:
        stdscr: Curses standard screen object. When given, the interactive
            menu is run; when None, the API is initialised and a single
            example prediction is printed.
    """
    # Menu mode
    if stdscr is not None:
        HostilityDetectorMenu(stdscr).run()
        return

    # No screen object: direct API usage
    print("Chat Context Hostility/Friendly Detection System")
    print("Initializing API...")
    try:
        api = HostilityDetectorAPI()
        print("API initialized successfully!")

        # Example usage
        print("\nExample prediction:")
        sample = "This is a friendly message, I hope you have a great day!"
        prediction = api.predict(sample)
        print(f"Text: {sample}")
        print(f"Prediction: {prediction['predicted_label']}")
        print(f"Confidence: {prediction['confidence']:.2%}")
    except FileNotFoundError as missing:
        # A missing file here means no trained model is available yet
        print(f"Error: {str(missing)}")
        print("Please train a model first using the menu interface.")
    except Exception as failure:
        print(f"Error: {str(failure)}")
if __name__ == "__main__":
    # The curses menu needs an interactive terminal on stdin; otherwise
    # fall back to the plain, menu-less run.
    interactive = sys.stdin.isatty()
    if not interactive:
        main()
    else:
        wrapper(main)
复制代码 |
评分
-
查看全部评分
|