鱼C论坛

 找回密码
 立即注册
查看: 426|回复: 13

数据类型无法转换为张量

[复制链接]
发表于 2024-4-15 20:46:39 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
代码        def __init__(self, train_list, train_path, **kwargs):
                self.data_list, self.data_length, self.data_label = [], [], []
                self.train_path = train_path
                self.datas = []
                lines = open(train_list).read().splitlines()
                # Get the ground-truth labels, that is used to compute the NMI for post-analyze.
                dictkeys = list(set([x.split()[0] for x in lines]))
                dictkeys.sort()
                dictkeys = { key : ii for ii, key in enumerate(dictkeys) }

                for lidx, line in enumerate(lines):
                        data = line.split()
                        file_name = data[1]
                        speaker_label = dictkeys[data[0]]
                        self.data_list.append(file_name)  # Filename
                        self.data_label.append(speaker_label) # GT Speaker label
                self.minibatch = []
                batch_size = 32
                for i in range(0, len(self.data_list), batch_size):
                        batch_data = self.data_list[i:i + batch_size]
                        batch_label = self.data_label[i:i + batch_size]
                        self.minibatch.append([batch_data, batch_label])
                # sort the training set by the length of the audios, audio with similar length are saved togethor.

        def __getitem__(self, index):
                data_lists,  data_labels = self.minibatch[index]  # Get one minibatch
                filenames, labels, segments = [], [], []
                for num in range(len(data_lists)):
                        filename = data_lists[num]  # Read filename
                        label = data_labels[num]  # Read GT label
                        file = os.path.join(self.train_path, filename)
                        signal = pd.read_csv(file, header=None, usecols=[0], skiprows=[0], engine='python').values.flatten()
                        segments.append(signal)
                        filenames.append(filename)
                        labels.append(label)
                print(segments)
                segments = torch.FloatTensor(numpy.array(segments))
                return segments, filenames, labels


报错Traceback (most recent call last):
  File "main_train.py", line 65, in <module>
    dic_label, NMI = Trainer.cluster_network(loader = clusterLoader, n_cluster = args.n_cluster) # Do clustering
  File "/home/data/pxy/apython/loss2/Stage2/model.py", line 45, in cluster_network
    for data, filenames, labels in tqdm.tqdm(loader):  
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
    return self._process_data(data)
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
    data.reraise()
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/data/anaconda3/envs/loss/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/data/pxy/apython/loss2/Stage2/dataLoader.py", line 83, in __getitem__
    segments = torch.FloatTensor(numpy.array(segments))
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

打印出来的segment
[array([-0.00082384, -0.00082384,  0.00082384, ...,  0.00020596,
       -0.00020596, -0.00020596]), array([ 0.00020596,  0.        ,  0.00041192, ...,  0.        ,
        0.00020596, -0.00020596]), array([ 0.00020596,  0.00020596,  0.0010298 , ..., -0.00020596,
        0.00082384, -0.00020596]), array([-0.00061788,  0.00041192,  0.        , ..., -0.0010298 ,
       -0.00020596, -0.00082384]), array([-0.00185364, -0.0010298 ,  0.        , ..., -0.00041192,
        0.        , -0.00020596]), array([-0.00020596, -0.00061788,  0.00020596, ..., -0.00082384,
        0.        , -0.00020596]), array([-0.00020596,  0.        ,  0.        , ..., -0.00020596,
        0.00041192,  0.        ]), array([ 0.        ,  0.        ,  0.00020596, ..., -0.00082384,
       -0.00020596,  0.        ]), array([-0.00041192, -0.00082384, -0.00082384, ..., -0.00041192,
        0.00020596, -0.00020596]), array([-0.00020596,  0.00020596, -0.0010298 , ...,  0.00020596,
        0.00041192, -0.0010298 ]), array([ 0.00061788,  0.00164768,  0.00041192, ...,  0.        ,
       -0.00020596, -0.00082384]), array([-0.00061788, -0.00020596, -0.00061788, ...,  0.00082384,
        0.00061788,  0.00041192]), array([ 0.00020596,  0.00041192, -0.0010298 , ...,  0.        ,
        0.00164768,  0.00041192]), array([0.        , 0.00082384, 0.        , ..., 0.00041192, 0.00061788,
       0.00061788]), array([ 0.00041192,  0.00041192, -0.00041192, ...,  0.00020596,
        0.00020596, -0.00061788]), array([-0.00061788,  0.00082384, -0.00020596, ..., -0.00185364,
       -0.00061788, -0.00123576]), array([-0.00123576,  0.00020596, -0.0010298 , ...,  0.00020596,
        0.00082384, -0.00082384])]
[array([-0.00061788, -0.00061788, -0.00144172, ..., -0.00061788,
        0.00020596, -0.00061788]), array([-0.00041192, -0.00061788, -0.0010298 , ..., -0.00020596,
        0.00020596, -0.00041192]), array([-0.00082384,  0.00082384,  0.00020596, ...,  0.00020596,
        0.00061788,  0.00020596]), array([ 0.00020596,  0.00061788, -0.00020596, ...,  0.00041192,
        0.00061788,  0.        ]), array([ 0.00041192,  0.00061788,  0.00020596, ..., -0.00061788,
       -0.00041192, -0.00061788]), array([ 0.        ,  0.00041192, -0.00041192, ..., -0.00020596,
        0.00020596, -0.0010298 ]), array([-0.00041192,  0.00020596, -0.0010298 , ..., -0.00164768,
       -0.00061788, -0.00061788]), array([ 0.00164704,  0.00082352,  0.00082352, ...,  0.00041176,
       -0.00123528, -0.00041176]), array([-0.00041176, -0.00041176,  0.00041176, ..., -0.00041176,
       -0.00041176,  0.        ]), array([ 0.00082352,  0.00082352,  0.00082352, ...,  0.00082352,
       -0.00041176,  0.        ]), array([ 0.00041176, -0.00041176,  0.00123528, ..., -0.00164704,
       -0.00041176, -0.00082352]), array([ 0.00041176, -0.00041176,  0.00082352, ...,  0.00123528,
        0.00082352,  0.        ]), array([ 0.        ,  0.00041176,  0.00041176, ...,  0.        ,
       -0.00041176,  0.        ]), array([-0.00041176, -0.00082352,  0.00041176, ...,  0.00164704,
        0.00082352,  0.00164704]), array([0.00164704, 0.00082352, 0.00082352, ..., 0.00041176, 0.00041176,
       0.00041176]), array([ 0.00082352,  0.00041176,  0.00082352, ...,  0.        ,
       -0.00041176, -0.00041176]), array([-0.00226556, -0.00082384, -0.00164768, ..., -0.00061788,
        0.00041192, -0.00061788]), array([-0.00041192, -0.00041192,  0.00020596, ...,  0.        ,
       -0.00020596, -0.00082384]), array([-0.00020596,  0.00041192, -0.00061788, ..., -0.00082384,
        0.00082384, -0.00041192]), array([-0.00041192,  0.00061788, -0.00020596, ...,  0.        ,
       -0.00041192, -0.00082384]), array([-0.00041192,  0.00041192, -0.00020596, ..., -0.00082384,
       -0.00041192,  0.00020596]), array([-0.00020596,  0.00041192, -0.0010298 , ...,  0.00061788,
        0.        , -0.00020596]), array([ 0.00020596,  0.00082384,  0.        , ..., -0.00061788,
        0.00020596,  0.00020596]), array([ 0.        ,  0.00082384, -0.00020596, ...,  0.        ,
        0.00020596,  0.00041192]), array([-0.00020596,  0.00082384, -0.00061788, ..., -0.00041192,
        0.        ,  0.00020596]), array([-0.00205288,  0.        ,  0.        , ...,  0.        ,
        0.        ,  0.00205288]), array([-0.00205288,  0.        , -0.00205288, ...,  0.        ,
        0.        , -0.00205288]), array([0., 0., 0., ..., 0., 0., 0.]), array([ 0.        , -0.00205288, -0.00410576, ..., -0.00205288,
        0.        ,  0.        ]), array([-0.00226556, -0.00164768, -0.00082384, ..., -0.00082384,
       -0.00164768, -0.00144172]), array([-0.00082384, -0.00082384, -0.00061788, ...,  0.00123576,
        0.00123576,  0.00082384]), array([0.00061788, 0.00061788, 0.00082384, ..., 0.00041192, 0.00020596,
       0.00041192])]
[array([ 0.00020596,  0.        ,  0.0010298 , ..., -0.00020596,
       -0.00061788,  0.00061788]), array([ 0.        , -0.00061788,  0.0010298 , ...,  0.00061788,
        0.00041192,  0.0010298 ]), array([ 0.00123576,  0.        ,  0.00041192, ...,  0.00041192,
       -0.00082384,  0.00061788]), array([-0.00820832, -0.00820832, -0.00820832, ...,  0.        ,
        0.        , -0.01641664]), array([-0.00820832,  0.        ,  0.        , ..., -0.00820832,
       -0.00820832, -0.00820832]), array([-0.00820832, -0.00820832, -0.00820832, ..., -0.00820832,
       -0.00820832, -0.00820832]), array([-0.01641664, -0.01641664, -0.00820832, ..., -0.00820832,
       -0.00820832,  0.00820832]), array([-0.00820832, -0.00820832,  0.        , ..., -0.00820832,
       -0.00820832, -0.00820832]), array([ 0.00820832, -0.00820832, -0.00820832, ..., -0.01641664,
        0.        , -0.00820832]), array([-0.00820832, -0.00820832, -0.00820832, ..., -0.01641664,
       -0.01641664, -0.00820832]), array([-0.00820832,  0.        , -0.02462496, ..., -0.01641664,
       -0.00820832, -0.00820832]), array([ 0.        ,  0.        ,  0.00164704, ...,  0.00123528,
       -0.00082352, -0.00082352]), array([-0.00041176, -0.00041176,  0.        , ..., -0.00041176,
       -0.00123528,  0.        ]), array([-0.00082352, -0.00082352, -0.00041176, ...,  0.00123528,
        0.        , -0.00082352]), array([ 0.00082352,  0.00082352,  0.        , ..., -0.00041176,
       -0.00123528, -0.00123528]), array([-0.00082352, -0.00041176,  0.        , ..., -0.00082352,
       -0.00082352, -0.00041176]), array([ 0.00041176, -0.00041176,  0.        , ...,  0.00082352,
        0.00041176,  0.        ]), array([ 0.        ,  0.        ,  0.00082352, ...,  0.00123528,
       -0.00041176,  0.00082352]), array([ 0.00041176,  0.        , -0.00041176, ..., -0.00082352,
       -0.00041176, -0.00082352]), array([-0.00041176, -0.00041176, -0.00123528, ...,  0.00041176,
       -0.00082352, -0.00041176]), array([-0.00041192,  0.        ,  0.00020596, ...,  0.0010298 ,
        0.00123576,  0.00082384]), array([0.00020596, 0.00123576, 0.00061788, ..., 0.00020596, 0.00020596,
       0.00061788]), array([ 0.        ,  0.00020596,  0.00061788, ...,  0.00061788,
       -0.00020596, -0.00020596]), array([-0.0010298 ,  0.00020596,  0.00041192, ...,  0.00144172,
        0.00082384,  0.00082384]), array([0.00020596, 0.00041192, 0.00082384, ..., 0.00082384, 0.00144172,
       0.00082384]), array([ 0.00082384,  0.0010298 ,  0.00082384, ..., -0.00020596,
       -0.00020596, -0.00020596]), array([-0.00020596, -0.00123576,  0.00041192, ..., -0.00061788,
        0.00082384, -0.00082384]), array([ 0.00041192,  0.00061788, -0.00041192, ...,  0.        ,
       -0.00041192,  0.00041192]), array([ 0.00020596,  0.        ,  0.00041192, ..., -0.00061788,
       -0.00041192,  0.00020596]), array([ 0.00185364,  0.00144172,  0.00061788, ..., -0.00041192,
       -0.00020596,  0.00020596]), array([ 0.00041192,  0.        ,  0.00082384, ...,  0.00041192,
       -0.00041192,  0.00020596]), array([ 0.00061788, -0.00020596,  0.00082384, ..., -0.00061788,
        0.        ,  0.00041192])]

请问如何解决
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2024-4-15 20:46:58 | 显示全部楼层
这个错误的原因是在`__getitem__`方法中,segments的类型是一个空的numpy数组,而torch.FloatTensor只支持特定的数据类型,无法将空数组转换为张量。

要解决这个问题,可以在添加元素到segments数组之前,先检查是否为空数组,如果是空数组,则跳过添加。可以使用下面的代码来实现:
for num in range(len(data_lists)):
    filename = data_lists[num]  # Read filename
    label = data_labels[num]  # Read GT label
    file = os.path.join(self.train_path, filename)
    signal = pd.read_csv(file, header=None, usecols=[0], skiprows=1, engine='python').values.flatten()
    if len(signal) == 0:
        continue
    segments.append(signal)
    filenames.append(filename)
    labels.append(label)

在这个代码中,我们在添加signal到segments之前,首先检查signal是否为空数组,如果是,则跳过添加。

希望这个解决方案能帮到你!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-15 20:48:58 | 显示全部楼层
FishC 发表于 2024-4-15 20:46
这个错误的原因是在`__getitem__`方法中,segments的类型是一个空的numpy数组,而torch.FloatTensor只支持 ...

还是不行   仍然报错
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-15 20:49:33 | 显示全部楼层
1207139945 发表于 2024-04-15 20:48
还是不行   仍然报错
TypeError: can't convert np.n ...


抱歉,我之前给出的解决方案并不完整。该错误是由于segments中包含了不支持的数据类型(numpy.object_),导致无法转换为张量。你可以尝试修改代码,将signal转换为正确的数据类型(如float32或int64)。

请尝试以下代码:
for num in range(len(data_lists)):
    filename = data_lists[num]  # Read filename
    label = data_labels[num]  # Read GT label
    file = os.path.join(self.train_path, filename)
    signal = pd.read_csv(file, header=None, usecols=[0], skiprows=1, engine='python').values.flatten()
    if len(signal) == 0:
        continue
    segments.append(signal.astype(np.float32))
    filenames.append(filename)
    labels.append(label)

segments = torch.FloatTensor(np.array(segments))

在上述代码中,我们使用astype函数将signal转换为np.float32类型,然后再添加到segments数组中。最后,将segments转换为torch.FloatTensor类型。

希望这次的解决方案能够解决你的问题!如果还有其他问题,请随时告诉我。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-15 20:52:40 | 显示全部楼层
FishC 发表于 2024-4-15 20:49
抱歉,我之前给出的解决方案并不完整。该错误是由于segments中包含了不支持的数据类型(numpy.object_) ...

还是有错TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool
但是有一定效果 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                         | 18/21 [00:07<00:01,  2.29it/s]
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-15 20:52:59 | 显示全部楼层
1207139945 发表于 2024-04-15 20:52
还是有错TypeError: can't convert np.ndarray of typ ...


抱歉给出的解决方案还是有问题。根据错误提示,segments中的数据类型仍然不被支持。这可能是由于signal数组中的数据类型不一致导致的。

为了解决这个问题,你可以尝试添加一个条件语句,在将signal添加到segments之前首先检查其数据类型是否被支持。可以使用numpy中的dtype属性来获取signal数组的数据类型,并与支持的类型进行比较。

以下是修改后的代码示例:
for num in range(len(data_lists)):
    filename = data_lists[num]  # Read filename
    label = data_labels[num]  # Read GT label
    file = os.path.join(self.train_path, filename)
    signal = pd.read_csv(file, header=None, usecols=[0], skiprows=1, engine='python').values.flatten()
    if signal.dtype not in [np.float64, np.float32, np.float16, np.complex64, np.complex128,
                            np.int64, np.int32, np.int16, np.int8, np.uint8, bool]:
        continue
    segments.append(signal.astype(np.float32))
    filenames.append(filename)
    labels.append(label)

segments = torch.FloatTensor(np.array(segments))

通过检查signal的数据类型是否在支持的类型列表内,我们可以过滤掉不受支持的类型,并只选择支持的类型进行转换和添加。希望这次的解决方案能够解决你的问题!如果还有其他问题,请随时告诉我。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-15 20:57:10 | 显示全部楼层
FishC 发表于 2024-4-15 20:52
抱歉给出的解决方案还是有问题。根据错误提示,segments中的数据类型仍然不被支持。这可能是由于signal ...

仍然是老错误,是否能那这个模块独立出来进行单独测试,打印出有问题的数据集
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-15 20:57:32 | 显示全部楼层
1207139945 发表于 2024-04-15 20:57
仍然是老错误,是否能那这个模块独立出来进行单独测试 ...


非常抱歉之前给出的解决方案依然没有解决问题。如果仍然无法找到具体的错误数据集,我建议你将这个模块独立出来进行单独测试,并打印出有问题的数据集,以便更好地定位问题所在。可以尝试以下方法:

1. 将报错的代码段提取出来,创建一个新的脚本或函数。
2. 使用一个简单的数据集或一小部分数据作为输入,在这个独立的脚本中运行代码。
3. 在代码中添加打印语句,输出相关的变量值,特别是在转换数据类型之前和之后。

通过逐步调试,你可以观察到在哪个数据集或数据类型转换上出现问题。这将帮助你更好地理解问题,并进一步解决它。希望这个方法能帮到你!如果还有其他问题,请随时告诉我。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-15 21:00:07 | 显示全部楼层

代码        def __init__(self, train_list, train_path, **kwargs):
                self.data_list, self.data_length, self.data_label = [], [], []
                self.train_path = train_path
                self.datas = []
                lines = open(train_list).read().splitlines()
                # Get the ground-truth labels, that is used to compute the NMI for post-analyze.
                dictkeys = list(set([x.split()[0] for x in lines]))
                dictkeys.sort()
                dictkeys = { key : ii for ii, key in enumerate(dictkeys) }

                for lidx, line in enumerate(lines):
                        data = line.split()
                        file_name = data[1]
                        speaker_label = dictkeys[data[0]]
                        self.data_list.append(file_name)  # Filename
                        self.data_label.append(speaker_label) # GT Speaker label
                self.minibatch = []
                batch_size = 32
                for i in range(0, len(self.data_list), batch_size):
                        batch_data = self.data_list[i:i + batch_size]
                        batch_label = self.data_label[i:i + batch_size]
                        self.minibatch.append([batch_data, batch_label])
                # sort the training set by the length of the audios, audio with similar length are saved togethor.

        def __getitem__(self, index):
                data_lists,  data_labels = self.minibatch[index]  # Get one minibatch
                filenames, labels, segments = [], [], []
                for num in range(len(data_lists)):
                        filename = data_lists[num]  # Read filename
                        label = data_labels[num]  # Read GT label
                        file = os.path.join(self.train_path, filename)
                        signal = pd.read_csv(file, header=None, usecols=[0], skiprows=[0], engine='python').values.flatten()
                        segments.append(signal)
                        filenames.append(filename)
                        labels.append(label)
                print(segments)
                segments = torch.FloatTensor(numpy.array(segments))
                return segments, filenames, labels
帮我生成一份能用来测试出问题的代码class CSVReader:
    def __init__(self, train_list, train_path, **kwargs):
        self.data_list = []
        max_rows = 0
        for line in open(train_list).read().splitlines():
            filename = os.path.join(train_path, line.split()[1])
            self.data_list.append(filename)
            signal = pd.read_csv(filename, header=None, skiprows=[0], engine='python').values.flatten()
            segments1 = []
            segments1.append(signal)
        segments1 = torch.FloatTensor(numpy.array(segments1))
        filenames, labels, segments = [], [], []

        for num in range(len(train_list)):
            filename = train_list[num]  # Read filename

            file = os.path.join(train_path, filename)
            signal = pd.read_csv(file, header=None, usecols=[0], skiprows=[0], engine='python').values.flatten()
            segments.append(signal)
            filenames.append(filename)

        # print(segments.dtype)
        segments = torch.FloatTensor(numpy.array(segments))



        segments = torch.FloatTensor(numpy.array(segments))
        if isinstance(segments, torch.Tensor):
            print("segments 是一个张量")
        else:
            print("segments 不是一个张量")

        print(filename)
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-15 21:01:15 | 显示全部楼层
1207139945 发表于 2024-04-15 21:00
代码        def __init__(self, train_list, train_p ...


请注意,我对您提供的代码进行了一些修改,以便让它能够在单独测试时运行。以下是您可以使用的用于测试的代码示例:
import os
import pandas as pd
import numpy as np
import torch

class CSVReaderTester:
    def __init__(self, train_list, train_path):
        self.data_list = []
        max_rows = 0
        for line in open(train_list).read().splitlines():
            filename = os.path.join(train_path, line.split()[0])
            self.data_list.append(filename)
        
        segments1 = []
        for filename in self.data_list:
            signal = pd.read_csv(filename, header=None, skiprows=None, engine='python').values.flatten()
            segments1.append(signal)
        segments1 = np.array(segments1)
        segments1 = torch.FloatTensor(segments1)
        
        filenames, labels, segments = [], [], []
        for filename in self.data_list:
            signal = pd.read_csv(filename, header=None, usecols=None, skiprows=None, engine='python').values.flatten()
            segments.append(signal)
            filenames.append(filename)
        
        segments = np.array(segments)
        segments = torch.FloatTensor(segments)
        
        if isinstance(segments, torch.Tensor):
            print("segments is a tensor")
        else:
            print("segments is not a tensor")
        
        print(filenames)

# 测试代码
train_list = "train_list.txt"
train_path = "path/to/train/data"

tester = CSVReaderTester(train_list, train_path)

请确保将 `train_list.txt` 替换为实际的训练文件列表路径,以及将 `path/to/train/data` 替换为实际的训练数据文件夹路径。

这段代码会读取训练文件列表中的每个文件,并打印出 `segments` 是否为张量以及所有的文件名。你可以运行这段代码以找出问题所在并进行调试。

希望这次的解决方案能够帮到你!如果还有其他问题,请随时告诉我。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-17 15:53:21 | 显示全部楼层
FishC 发表于 2024-4-15 21:01
请注意,我对您提供的代码进行了一些修改,以便让它能够在单独测试时运行。以下是您可以使用的用于测试 ...

我是按30000个点对csv文件进行切割的按理说,数量应该相同,但是报错为
/home/data2/pxy/apython/loss2/Stage1/linshi/ceshi.py:86: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  segments = np.array(segments)
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-17 15:53:39 | 显示全部楼层
1207139945 发表于 2024-04-17 15:53
我是按30000个点对csv文件进行切割的按理说,数量应该 ...


这个警告是由于`segments`数组中的列表长度不一致而引起的。在使用`np.array()`函数将列表转换为ndarray时,ndarray的形状必须是规则的,也就是每个子列表的长度必须相同。您之前提到切割csv文件时使用了30000个点进行切割,但您的代码示例中并未提供这部分代码,所以我无法判断具体原因。

请确保在切割过程中,每个csv文件被切割成具有相同长度的子片段。如果不是,则需要对长度不一致的子片段进行处理,使它们具有相同的长度。有几种方法可以处理这个问题:

1. 截断:对于长度超过30000个点的子片段,只保留前30000个点,使其与其他子片段的长度相匹配。
2. 填充:对于长度少于30000个点的子片段,在末尾填充0或其他合适的值,使其达到30000个点。

下面是一个示例代码片段,展示如何在切割过程中处理不一致长度的子片段:
import pandas as pd
import numpy as np

def split_csv(csv_file, segment_length):
    signal = pd.read_csv(csv_file, header=None, usecols=[0]).values.flatten()
    segments = []
    num_segments = len(signal) // segment_length
    
    for i in range(num_segments):
        start = i * segment_length
        end = (i+1) * segment_length
        segment = signal[start:end]
        segments.append(segment)
    
    # 处理最后一个不完整的片段
    last_segment_start = num_segments * segment_length
    last_segment = signal[last_segment_start:]
    if len(last_segment) < segment_length:
        last_segment = np.pad(last_segment, (0, segment_length - len(last_segment)), 'constant')
        segments.append(last_segment)
    
    return segments

# 示例用法
csv_file = "path/to/your/csv/file.csv"
segment_length = 30000

segments = split_csv(csv_file, segment_length)
segments = np.array(segments)

# 确认所有子片段的长度都相同
print([len(segment) for segment in segments])

请注意,这只是一个示例代码片段,您需要根据实际情况进行修改。如果您仍然遇到困难,请提供更多关于切割过程的代码,以便我能够更好地帮助您解决问题。

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

 楼主| 发表于 2024-4-17 16:22:19 | 显示全部楼层
FishC 发表于 2024-4-17 15:53
这个警告是由于`segments`数组中的列表长度不一致而引起的。在使用`np.array()`函数将列表转换为ndarra ...


import os  # 导入os模块,用于处理文件和目录操作
import pandas as pd  # 导入pandas模块,用于数据处理和分析
import numpy as np  # 导入numpy模块,用于数值计算

def split_csv_files(input_folder, output_folder, chunk_size):  # 定义一个函数split_csv_files,接收三个参数:输入文件夹路径、输出文件夹路径和切割大小
    for root, dirs, files in os.walk(input_folder):  # 使用os.walk遍历输入文件夹及其子文件夹中的所有文件
        for file in files:  # 遍历每个文件
            if file.endswith(".csv"):  # 如果文件是以.csv结尾的
                file_path = os.path.join(root, file)  # 获取文件的完整路径
                data = pd.read_csv(file_path, usecols=[2])  # 读取CSV文件,只保留第三列数据(索引为2)
                subfolder = os.path.relpath(root, input_folder).replace("/", "_")  # 获取当前文件所在子文件夹的名称,并将其中的斜杠替换为下划线
                output_subfolder = os.path.join(output_folder, subfolder)  # 构建输出子文件夹的路径
                if not os.path.exists(output_subfolder):  # 如果输出子文件夹不存在,则创建它
                    os.makedirs(output_subfolder)
                file_name_prefix = subfolder + "_"  # 构建输出文件名的前缀,包含子文件夹名称和一个下划线
                for i, chunk in enumerate(data.groupby(np.arange(len(data)) // chunk_size)):  # 将数据按照指定的切割大小进行分组
                    if len(chunk[1]) >= chunk_size:  # 如果当前分组的数据长度大于等于切割大小
                        filename = file_name_prefix + str(i) + ".csv"  # 构建输出文件名,包含前缀、序号和.csv后缀
                        chunk[1].to_csv(os.path.join(output_subfolder, filename), index=False)  # 将当前分组的数据保存为CSV文件到输出子文件夹中,不包含索引列

input_folder = "/home/data/pxy/ceshi2.3"  # 设置输入文件夹路径
output_folder = "/home/data/pxy/ceshi2.3.2"  # 设置输出文件夹路径
chunk_size = 30000  #

split_csv_files(input_folder, output_folder, chunk_size
这是我用的代码,把两者综合一下,切割指定文件夹所有csv文件,并且每个子文件夹最后一个切割后数量少于25000的话丢弃,多于的话补0到指定长度,然后将文件保存到指定文件夹的同名子文件夹
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

发表于 2024-4-17 16:22:55 | 显示全部楼层
1207139945 发表于 2024-04-17 16:22
import os  # 导入os模块,用于处理文件和目录操作
i ...


根据您提供的代码,我进行了一些修改以实现您的需求。请注意,以下代码假定您已经安装了pandas和numpy模块:
import os
import pandas as pd
import numpy as np

def split_csv_files(input_folder, output_folder, chunk_size):
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if file.endswith(".csv"):
                file_path = os.path.join(root, file)
                data = pd.read_csv(file_path, usecols=[2]) # 只保留第三列数据(索引为2)
                subfolder = os.path.relpath(root, input_folder).replace("/", "_")
                output_subfolder = os.path.join(output_folder, subfolder)
                if not os.path.exists(output_subfolder):
                    os.makedirs(output_subfolder)
                file_name_prefix = subfolder + "_"

                num_chunks = len(data) // chunk_size
                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size
                    chunk = data[start:end]

                    if len(chunk) < chunk_size and i == num_chunks - 1:
                        continue # 不满足切割大小,并且是最后一个分组,丢弃
                    elif len(chunk) < chunk_size:
                        chunk = np.pad(chunk, ((0, chunk_size - len(chunk)), (0, 0)), 'constant') # 补0到指定长度

                    filename = file_name_prefix + str(i) + ".csv"
                    chunk.to_csv(os.path.join(output_subfolder, filename), index=False)

input_folder = "/home/data/pxy/ceshi2.3"
output_folder = "/home/data/pxy/ceshi2.3.2"
chunk_size = 30000

split_csv_files(input_folder, output_folder, chunk_size)

这段代码会遍历指定的输入文件夹中的所有csv文件,并且按照指定的切割大小将它们切割成多个子片段。如果一个子片段的长度不满足切割大小,并且它是该文件夹的最后一个子片段,则会被丢弃。如果长度不足并且不是最后一个子片段,则会在末尾补0直到达到指定的切割大小。最后,切割后的子片段将保存在指定的输出文件夹的同名子文件夹中。

请根据您的实际需求修改输入文件夹路径、输出文件夹路径和切割大小。希望这次的修改能够满足您的要求!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-11-15 14:57

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表